protect DB object and check module loaded

This commit is contained in:
Elijah Duffy
2025-05-30 16:54:23 -07:00
parent 4faae52a10
commit 2a43af1c3f

27
db.go
View File

@@ -23,8 +23,8 @@ var (
// ErrNoRows is returned when a query returns no rows. // ErrNoRows is returned when a query returns no rows.
ErrNoRows = qrm.ErrNoRows ErrNoRows = qrm.ErrNoRows
// SQLDB stores the current SQL database handle. // sqlDB stores the current SQL database handle.
SQLDB *sql.DB sqlDB *sql.DB
// config stores the database connection configuration. // config stores the database connection configuration.
config DBConfig config DBConfig
@@ -33,6 +33,15 @@ var (
dbModule *app.Module dbModule *app.Module
) )
// SQLO returns the current SQL database handle.
func SQLO() *sql.DB {
dbModule.RequireLoaded() // ensure the module is loaded before accessing the database
if sqlDB == nil {
panic("SQL database not initialized")
}
return sqlDB
}
// ModuleDB returns the database module with the provided configuration. // ModuleDB returns the database module with the provided configuration.
func ModuleDB(cfg DBConfig, forceDebugLog bool) *app.Module { func ModuleDB(cfg DBConfig, forceDebugLog bool) *app.Module {
if dbModule != nil { if dbModule != nil {
@@ -56,7 +65,7 @@ func ModuleDB(cfg DBConfig, forceDebugLog bool) *app.Module {
// setupDB connects to the MySQL DB and initializes the goose migration provider. // setupDB connects to the MySQL DB and initializes the goose migration provider.
// If auto migrations are enabled in configuration, latest migrations are applied. // If auto migrations are enabled in configuration, latest migrations are applied.
func setupDB() error { func setupDB() error {
if SQLDB != nil && SQLDB.Ping() == nil { if sqlDB != nil && sqlDB.Ping() == nil {
dbModule.Logger().Warn("Database connection already established") dbModule.Logger().Warn("Database connection already established")
return nil return nil
} }
@@ -65,20 +74,20 @@ func setupDB() error {
db_path := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", db_path := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
config.User, config.Password, config.Path, config.Name) config.User, config.Password, config.Path, config.Name)
SQLDB, err = sql.Open("mysql", db_path) sqlDB, err = sql.Open("mysql", db_path)
if err != nil { if err != nil {
dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err) dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err)
return err return err
} }
if err := SQLDB.Ping(); err != nil { if err := sqlDB.Ping(); err != nil {
dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err) dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err)
return err return err
} }
SQLDB.SetMaxOpenConns(config.MaxConn) sqlDB.SetMaxOpenConns(config.MaxConn)
stats := SQLDB.Stats() stats := sqlDB.Stats()
dbModule.Logger().Info( dbModule.Logger().Info(
"Connected to SQL database", "Connected to SQL database",
"user", config.User, "user", config.User,
@@ -97,11 +106,11 @@ func setupDB() error {
// teardownDB closes the database connection. // teardownDB closes the database connection.
func teardownDB() { func teardownDB() {
if SQLDB == nil { if sqlDB == nil {
return return
} }
if err := SQLDB.Close(); err != nil { if err := sqlDB.Close(); err != nil {
dbModule.Logger().Error("Couldn't close database", "err", err) dbModule.Logger().Error("Couldn't close database", "err", err)
} }
dbModule.Logger().Info("Closed database connection") dbModule.Logger().Info("Closed database connection")