From 2a43af1c3ffe9094125a0bb37a9501f62253dc46 Mon Sep 17 00:00:00 2001 From: Elijah Duffy Date: Fri, 30 May 2025 16:54:23 -0700 Subject: [PATCH] protect DB object and check module loaded --- db.go | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/db.go b/db.go index 91be634..2f4e193 100644 --- a/db.go +++ b/db.go @@ -23,8 +23,8 @@ var ( // ErrNoRows is returned when a query returns no rows. ErrNoRows = qrm.ErrNoRows - // SQLDB stores the current SQL database handle. - SQLDB *sql.DB + // sqlDB stores the current SQL database handle. + sqlDB *sql.DB // config stores the database connection configuration. config DBConfig @@ -33,6 +33,15 @@ var ( dbModule *app.Module ) +// SQLO returns the current SQL database handle. +func SQLO() *sql.DB { + dbModule.RequireLoaded() // ensure the module is loaded before accessing the database + if sqlDB == nil { + panic("SQL database not initialized") + } + return sqlDB +} + // ModuleDB returns the database module with the provided configuration. func ModuleDB(cfg DBConfig, forceDebugLog bool) *app.Module { if dbModule != nil { @@ -56,7 +65,7 @@ func ModuleDB(cfg DBConfig, forceDebugLog bool) *app.Module { // setupDB connects to the MySQL DB and initializes the goose migration provider. // If auto migrations are enabled in configuration, latest migrations are applied. func setupDB() error { - if SQLDB != nil && SQLDB.Ping() == nil { + if sqlDB != nil && sqlDB.Ping() == nil { dbModule.Logger().Warn("Database connection already established") return nil } @@ -65,20 +74,20 @@ func setupDB() error { db_path := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", config.User, config.Password, config.Path, config.Name) - SQLDB, err = sql.Open("mysql", db_path) + sqlDB, err = sql.Open("mysql", db_path) if err != nil { dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err) return err } - if err := SQLDB.Ping(); err != nil { + if err := sqlDB.Ping(); err != nil { dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err) return err } - SQLDB.SetMaxOpenConns(config.MaxConn) + sqlDB.SetMaxOpenConns(config.MaxConn) - stats := SQLDB.Stats() + stats := sqlDB.Stats() dbModule.Logger().Info( "Connected to SQL database", "user", config.User, @@ -97,11 +106,11 @@ func setupDB() error { // teardownDB closes the database connection. func teardownDB() { - if SQLDB == nil { + if sqlDB == nil { return } - if err := SQLDB.Close(); err != nil { + if err := sqlDB.Close(); err != nil { dbModule.Logger().Error("Couldn't close database", "err", err) } dbModule.Logger().Info("Closed database connection")