diff --git a/dbx.go b/dbx.go index 689d526..65b7b69 100644 --- a/dbx.go +++ b/dbx.go @@ -60,10 +60,8 @@ type Statement interface { ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error) } -var ( - // ErrNoRows is returned when a query returns no rows. - ErrNoRows = qrm.ErrNoRows - +// dbState stores package-level state for the database connection. +type dbState struct { // sqlDB stores the current SQL database handle. sqlDB *sql.DB @@ -75,28 +73,39 @@ var ( // debugLog indicates whether debug logging is enabled. debugLog bool +} + +var ( + // ErrNoRows is returned when a query returns no rows. + ErrNoRows = qrm.ErrNoRows + + state = dbState{} ) // SQLO returns the current SQL database handle. func SQLO() *sql.DB { dbxshared.DBModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database - if sqlDB == nil { + if state.sqlDB == nil { panic("SQL database not initialized") } - return sqlDB + return state.sqlDB } // ModuleDB returns the database module with the provided configuration. -func ModuleDB(dialect Dialect, cfg *DBConfig, forceDebugLog bool) *app.Module { +// dialect specifies the SQL dialect to use (e.g., DialectPostgres, DialectMySQL). +// config specifies the database connection configuration. +// forceDebugLog forces debug logging to be enabled regardless of the config setting. +func ModuleDB(dialect Dialect, config *DBConfig, forceDebugLog bool) *app.Module { if dbxshared.DBModule != nil { panic("ModuleDB initialized multiple times") } - if cfg == nil { + if config == nil { panic("ModuleDB requires a non-nil DBConfig") } - config = cfg // store configuration at package level - debugLog = cfg.DebugLog || forceDebugLog // force debug logging if requested + state.config = config // store configuration at package level + state.dialect = dialect // store dialect at package level + state.debugLog = config.DebugLog || forceDebugLog // force debug logging if requested dbxshared.DBModule = app.NewModule(ModuleDBName, app.ModuleOpts{ Setup: setupDB, @@ -220,46 +229,46 @@ func UpdateReturning[T any](sqlo Queryable, stmt Statement) (*T, error) { // setupDB connects to the database. func setupDB(m *app.Module) error { - if sqlDB != nil && sqlDB.Ping() == nil { + if state.sqlDB != nil && state.sqlDB.Ping() == nil { m.Logger().Warn("Database connection already established") return nil } logArgs := []any{ - "user", config.User, - "name", config.Name, - "uri", config.URI, - "dialect", dialect, + "user", state.config.User, + "name", state.config.Name, + "uri", state.config.URI, + "dialect", state.dialect, } var err error - sqlDB, err = sql.Open(string(dialect), config.ConnectionString(dialect)) + state.sqlDB, err = sql.Open(string(state.dialect), state.config.ConnectionString(state.dialect)) if err != nil { logArgs = append(logArgs, "err", err) m.Logger().Error("Couldn't open SQL database", logArgs...) return err } - if err := sqlDB.Ping(); err != nil { + if err := state.sqlDB.Ping(); err != nil { logArgs = append(logArgs, "err", err) m.Logger().Error("Couldn't ping SQL database", logArgs...) return err } - sqlDB.SetMaxOpenConns(config.MaxConn) + state.sqlDB.SetMaxOpenConns(state.config.MaxConn) - stats := sqlDB.Stats() + stats := state.sqlDB.Stats() m.Logger().Info( "Connected to SQL database", - "user", config.User, - "name", config.Name, - "uri", config.URI, + "user", state.config.User, + "name", state.config.Name, + "uri", state.config.URI, "maxConnections", stats.MaxOpenConnections, "currConnections", stats.OpenConnections, ) - if debugLog { - dbxshared.InitLogger(dialect) // initialize the logger for the specified dialect + if state.debugLog { + dbxshared.InitLogger(state.dialect) // initialize the logger for the specified dialect } return nil @@ -267,11 +276,11 @@ func setupDB(m *app.Module) error { // teardownDB closes the database connection. func teardownDB(m *app.Module) error { - if sqlDB == nil { + if state.sqlDB == nil { return nil } - if err := sqlDB.Close(); err != nil { + if err := state.sqlDB.Close(); err != nil { m.Logger().Error("Couldn't close database", "err", err) return err }