From 32567471e13c10b731b581f00aecb1f16ed58ab6 Mon Sep 17 00:00:00 2001 From: Elijah Duffy Date: Thu, 12 Jun 2025 17:44:44 -0700 Subject: [PATCH] refactor with improved developer ux in mind - Most functions now live directly in the dbx package - dbxp and dbxm are now ONLY the few functions that cannot be shared --- dbx.go | 276 +++++++++++++++++- dbxm/mysql.go | 176 +---------- dbxp/pxg.go | 176 +---------- .../dbxshared/dbxshared.go | 10 +- internal/dbxshared/logger.go | 30 ++ internal/dbxshared/module.go | 5 + utility.go | 16 + 7 files changed, 332 insertions(+), 357 deletions(-) rename generic/generic.go => internal/dbxshared/dbxshared.go (60%) create mode 100644 internal/dbxshared/logger.go create mode 100644 internal/dbxshared/module.go diff --git a/dbx.go b/dbx.go index 906c41a..d4c0da4 100644 --- a/dbx.go +++ b/dbx.go @@ -1,14 +1,32 @@ package dbx import ( + "context" "database/sql" - "strings" + "errors" + "gitea.auvem.com/go-toolkit/app" + "gitea.auvem.com/go-toolkit/dbx/internal/dbxshared" "github.com/go-jet/jet/v2/qrm" ) -// ModuleDBName is the name of the database module. -const ModuleDBName = "database" +const ( + // ModuleDBName is the name of the database module. + ModuleDBName = "database" + + // DialectPostgres is the PostgreSQL dialect. + DialectPostgres Dialect = "postgres" + // DialectMySQL is the MySQL dialect. + DialectMySQL Dialect = "mysql" +) + +// Dialect is the SQL dialect used by the database connection. +type Dialect string + +// String implements the Stringer interface for Dialect. +func (d Dialect) String() string { + return string(d) +} // SQLOFunc is a function that returns a *sql.DB pointer. type SQLOFunc = func() *sql.DB @@ -34,16 +52,244 @@ type ExecutableTx interface { Begin() (*sql.Tx, error) } -// StringToFilter processes a string to be used as a filter in an SQL LIKE -// statement. It replaces all spaces with % and adds % to the beginning and -// end of the string. -func StringToFilter(str string) string { - // Remove any existing leading or trailing % characters - str = strings.Trim(str, "%") - - // Replace all spaces with % and add % to the beginning and end of the string - str = strings.ReplaceAll(str, " ", "%") - str = "%" + str + "%" - - return str +// Statement is a common Jet statement for all SQL operations. +type Statement interface { + Query(db qrm.Queryable, destination any) error + QueryContext(ctx context.Context, db qrm.Queryable, destination any) error + Exec(db qrm.Executable) (sql.Result, error) + ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error) +} + +// SelectStatement is a Jet statement that can be executed to fetch rows from the database. +type SelectStatement interface { + Statement + LIMIT(limit int64) SelectStatement +} + +var ( + // ErrNoRows is returned when a query returns no rows. + ErrNoRows = qrm.ErrNoRows + + // sqlDB stores the current SQL database handle. + sqlDB *sql.DB + + // config stores the database connection configuration. + config *DBConfig + + // dialect is the SQL dialect used by the database connection. + dialect Dialect + + // debugLog indicates whether debug logging is enabled. + debugLog bool +) + +// SQLO returns the current SQL database handle. +func SQLO() *sql.DB { + dbxshared.DBModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database + if sqlDB == nil { + panic("SQL database not initialized") + } + return sqlDB +} + +// ModuleDB returns the existing database module, or panics if it has not +// been initialized yet. +func ModuleDB() *app.Module { + if dbxshared.DBModule == nil { + panic("ModuleDB not initialized yet") + } + return dbxshared.DBModule +} + +// InitModuleDB returns the database module with the provided configuration. +func InitModuleDB(dialect Dialect, cfg *DBConfig, forceDebugLog bool) *app.Module { + if dbxshared.DBModule != nil { + panic("ModuleDB initialized multiple times") + } + if cfg == nil { + panic("ModuleDB requires a non-nil DBConfig") + } + + config = cfg // store configuration at package level + debugLog = cfg.DebugLog || forceDebugLog // force debug logging if requested + + dbxshared.DBModule = app.NewModule(ModuleDBName, app.ModuleOpts{ + Setup: setupDB, + Teardown: teardownDB, + }) + + return dbxshared.DBModule +} + +// Fetch queries the database and returns the result as a slice. If the query +// returns no rows, it returns an empty slice and no error. +func Fetch[T any](sqlo Queryable, stmt SelectStatement) ([]*T, error) { + var result []*T + if err := stmt.Query(sqlo, &result); err != nil && !errors.Is(err, ErrNoRows) { + return nil, err + } + return result, nil +} + +// MustFetch queries the database and returns the result as a slice. If the query +// returns no rows, it returns an empty slice and the desired error. +func MustFetch[T any](sqlo Queryable, stmt SelectStatement, notFoundErr error) ([]*T, error) { + result, err := Fetch[T](sqlo, stmt) + if err != nil { + return nil, err + } + if len(result) == 0 { + return nil, notFoundErr // return the desired error if no rows found + } + return result, nil // return the fetched results +} + +// FetchOne queries the database and returns a single result. If the query +// returns no rows, it returns nil and no error. +func FetchOne[T any](sqlo Queryable, stmt SelectStatement) (*T, error) { + result, err := Fetch[T](sqlo, stmt.LIMIT(1)) + if err != nil { + return nil, err + } + if len(result) == 0 { + return nil, nil // no rows found, return nil + } + return result[0], nil // return the first (and only) result +} + +// MustFetchOne queries the database and returns a single result. If the query +// returns no rows, it returns nil and the desired error. +func MustFetchOne[T any](sqlo Queryable, stmt SelectStatement, notFoundErr error) (*T, error) { + result, err := MustFetch[T](sqlo, stmt.LIMIT(1), notFoundErr) + if err != nil { + return nil, err + } + return result[0], nil // return the first (and only) result +} + +// Insert executes an insert statement, returning the last inserted ID or an +// error if the insert fails. +func Insert(sqlo Executable, stmt Statement) (uint64, error) { + res, err := stmt.Exec(sqlo) + if err != nil { + return 0, err + } + + id, err := res.LastInsertId() + if err != nil { + return 0, err + } + + if id < 1 { + return 0, errors.New("inserted ID is less than 1") + } + + return uint64(id), nil +} + +// InsertReturning executes an insert statement that returns the inserted row. +// The statement MUST be a Jet InsertStatement with a RETURNING clause. Returns +// the inserted row object T or an error if the insert fails or no rows are returned. +func InsertReturning[T any](sqlo Queryable, stmt Statement) (*T, error) { + var result T + err := stmt.Query(sqlo, &result) + if err != nil { + return nil, err + } + return &result, nil +} + +// Update executes an update statement, returning an error if the update fails. +func Update(sqlo Executable, stmt Statement) error { + _, err := stmt.Exec(sqlo) + return err +} + +// UpdateAffected executes an update statement and returns the number of rows +// affected and an error if any. +func UpdateAffected(sqlo Executable, stmt Statement) (int64, error) { + res, err := stmt.Exec(sqlo) + if err != nil { + return 0, err + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return 0, err + } + + return rowsAffected, nil +} + +// UpdateReturning executes an update statement that returns the updated row. +// The statement MUST be a Jet UpdateStatement with a RETURNING clause. Returns +// the updated row object T or an error if the update fails or no rows are returned. +func UpdateReturning[T any](sqlo Queryable, stmt Statement) (*T, error) { + var result T + err := stmt.Query(sqlo, &result) + if err != nil { + return nil, err + } + return &result, nil +} + +// setupDB connects to the database. +func setupDB(m *app.Module) error { + if sqlDB != nil && sqlDB.Ping() == nil { + m.Logger().Warn("Database connection already established") + return nil + } + + logArgs := []any{ + "user", config.User, + "name", config.Name, + "uri", config.URI, + "dialect", dialect, + } + + var err error + sqlDB, err = sql.Open(string(dialect), config.ConnectionString(dialect)) + if err != nil { + logArgs = append(logArgs, "err", err) + m.Logger().Error("Couldn't open SQL database", logArgs...) + return err + } + + if err := sqlDB.Ping(); err != nil { + logArgs = append(logArgs, "err", err) + m.Logger().Error("Couldn't ping SQL database", logArgs...) + return err + } + + sqlDB.SetMaxOpenConns(config.MaxConn) + + stats := sqlDB.Stats() + m.Logger().Info( + "Connected to SQL database", + "user", config.User, + "name", config.Name, + "uri", config.URI, + "maxConnections", stats.MaxOpenConnections, + "currConnections", stats.OpenConnections, + ) + + if debugLog { + dbxshared.InitLogger(dialect) // initialize the logger for the specified dialect + } + + return nil +} + +// teardownDB closes the database connection. +func teardownDB(m *app.Module) error { + if sqlDB == nil { + return nil + } + + if err := sqlDB.Close(); err != nil { + m.Logger().Error("Couldn't close database", "err", err) + return err + } + m.Logger().Info("Closed database connection") + return nil } diff --git a/dbxm/mysql.go b/dbxm/mysql.go index b2dc62e..3344ab5 100644 --- a/dbxm/mysql.go +++ b/dbxm/mysql.go @@ -2,179 +2,18 @@ package dbxm import ( "context" - "database/sql" - "errors" "fmt" "slices" "strings" - "gitea.auvem.com/go-toolkit/app" "gitea.auvem.com/go-toolkit/dbx" - dbxgeneric "gitea.auvem.com/go-toolkit/dbx/generic" + "gitea.auvem.com/go-toolkit/dbx/internal/dbxshared" "github.com/fatih/color" "github.com/go-jet/jet/v2/mysql" - "github.com/go-jet/jet/v2/qrm" ) -var ( - // ErrNoRows is returned when a query returns no rows. - ErrNoRows = qrm.ErrNoRows - - // sqlDB stores the current SQL database handle. - sqlDB *sql.DB - - // config stores the database connection configuration. - config dbx.DBConfig - - // dbModule is the singleton module instance for the database connection. - dbModule *app.Module -) - -// SQLO returns the current SQL database handle. -func SQLO() *sql.DB { - dbModule.RequireLoaded("dbxm.SQLO requires database module") // ensure the module is loaded before accessing the database - if sqlDB == nil { - panic("SQL database not initialized") - } - return sqlDB -} - -// ModuleDB returns the database module with the provided configuration. -func ModuleDB(cfg dbx.DBConfig, forceDebugLog bool) *app.Module { - if dbModule != nil { - panic("ModuleDB initialized multiple times") - } - - config = cfg // store configuration at package level - - if forceDebugLog { - config.DebugLog = true // force debug logging if requested - } - - dbModule = app.NewModule(dbx.ModuleDBName, app.ModuleOpts{ - Setup: setupDB, - Teardown: teardownDB, - }) - - return dbModule -} - -// setupDB connects to the MySQL DB and initializes the goose migration provider. -// If auto migrations are enabled in configuration, latest migrations are applied. -func setupDB(_ *app.Module) error { - if sqlDB != nil && sqlDB.Ping() == nil { - dbModule.Logger().Warn("Database connection already established") - return nil - } - - var err error - db_path := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", - config.User, config.Password, config.Path, config.Name) - - sqlDB, err = sql.Open("mysql", db_path) - if err != nil { - dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err) - return err - } - - if err := sqlDB.Ping(); err != nil { - dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err) - return err - } - - sqlDB.SetMaxOpenConns(config.MaxConn) - - stats := sqlDB.Stats() - dbModule.Logger().Info( - "Connected to SQL database", - "user", config.User, - "name", config.Name, - "path", config.Path, - "maxConnections", stats.MaxOpenConnections, - "currConnections", stats.OpenConnections, - ) - - if config.DebugLog { - initDBLogger() - } - - return nil -} - -// teardownDB closes the database connection. -func teardownDB(_ *app.Module) error { - if sqlDB == nil { - return nil - } - - if err := sqlDB.Close(); err != nil { - dbModule.Logger().Error("Couldn't close database", "err", err) - return err - } - dbModule.Logger().Info("Closed database connection") - return nil -} - -// MustQuery executes a query and returns an error if any. Filters errors for -// db.ErrNoRows (qrm.ErrNoRows) and returns nil or the error provided in that case. -func MustQuery(sqlo dbx.Queryable, stmt mysql.Statement, dest any, notFoundErr error) error { - err := stmt.Query(sqlo, dest) - if err != nil { - if errors.Is(err, qrm.ErrNoRows) { - return notFoundErr - } - return err - } - return nil -} - -// MustInsert requires at least one row to be affected by a query and returns the inserted ID -func MustInsert(sqlo dbx.Executable, stmt mysql.InsertStatement) (uint64, error) { - res, err := stmt.Exec(sqlo) - if err != nil { - return 0, err - } - - id, err := res.LastInsertId() - if err != nil { - return 0, err - } - - if id < 0 { - return 0, fmt.Errorf("MustInsert got invalid ID: %d", id) - } - - return uint64(id), nil -} - -// MustUpdate requires at least one row to be affected by a query -func MustUpdate(sqlo dbx.Executable, stmt mysql.Statement, notFoundErr error) error { - res, err := stmt.Exec(sqlo) - if err != nil { - return err - } - - rows, err := res.RowsAffected() - if err != nil { - return err - } - - if rows == 0 { - return notFoundErr - } - - return nil -} - -// MightUpdateMany expects multiple rows to be affected by a query and returns this number of rows -func MightUpdateMany(sqlo dbx.Executable, stmt mysql.Statement) (int, error) { - res, err := stmt.Exec(sqlo) - if err != nil { - return 0, err - } - - rows, err := res.RowsAffected() - return int(rows), err +func init() { + dbxshared.RegisterLogger(dbx.DialectMySQL, &_mysqlLogger{}) } // SoftDelete sets the deleted_at column to the current time @@ -229,14 +68,15 @@ func ContainsCol(cols mysql.ColumnList, col mysql.Column) bool { // DestName returns the name of the type passed as `destTypeStruct` as a string, // normalized for compatibility with the Jet QRM. func DestName(destTypeStruct any, path ...string) string { - return dbxgeneric.DestName(dbModule, destTypeStruct, path...) + return dbxshared.DestName(destTypeStruct, path...) } -// initDBLogger initializes the database statement logger -func initDBLogger() { +type _mysqlLogger struct{} + +func (_mysqlLogger) InitLogger() { mysql.SetQueryLogger(func(ctx context.Context, queryInfo mysql.QueryInfo) { _, args := queryInfo.Statement.Sql() - dbModule.Logger().Debug( + dbx.ModuleDB().Logger().Debug( "Executed SQL query", "args", args, "duration", queryInfo.Duration, diff --git a/dbxp/pxg.go b/dbxp/pxg.go index 6e63e06..5efbf6e 100644 --- a/dbxp/pxg.go +++ b/dbxp/pxg.go @@ -2,179 +2,18 @@ package dbxp import ( "context" - "database/sql" - "errors" "fmt" "slices" "strings" - "gitea.auvem.com/go-toolkit/app" "gitea.auvem.com/go-toolkit/dbx" - dbxgeneric "gitea.auvem.com/go-toolkit/dbx/generic" + "gitea.auvem.com/go-toolkit/dbx/internal/dbxshared" "github.com/fatih/color" "github.com/go-jet/jet/v2/postgres" - "github.com/go-jet/jet/v2/qrm" ) -var ( - // ErrNoRows is returned when a query returns no rows. - ErrNoRows = qrm.ErrNoRows - - // sqlDB stores the current SQL database handle. - sqlDB *sql.DB - - // config stores the database connection configuration. - config dbx.DBConfig - - // dbModule is the singleton module instance for the database connection. - dbModule *app.Module -) - -// SQLO returns the current SQL database handle. -func SQLO() *sql.DB { - dbModule.RequireLoaded("dbxp.SQLO requires database module") // ensure the module is loaded before accessing the database - if sqlDB == nil { - panic("SQL database not initialized") - } - return sqlDB -} - -// ModuleDB returns the database module with the provided configuration. -func ModuleDB(cfg dbx.DBConfig, forceDebugLog bool) *app.Module { - if dbModule != nil { - panic("ModuleDB initialized multiple times") - } - - config = cfg // store configuration at package level - - if forceDebugLog { - config.DebugLog = true // force debug logging if requested - } - - dbModule = app.NewModule(dbx.ModuleDBName, app.ModuleOpts{ - Setup: setupDB, - Teardown: teardownDB, - }) - - return dbModule -} - -// setupDB connects to the PostgreSQL DB and initializes the goose migration provider. -// If auto migrations are enabled in configuration, latest migrations are applied. -func setupDB(_ *app.Module) error { - if sqlDB != nil && sqlDB.Ping() == nil { - dbModule.Logger().Warn("Database connection already established") - return nil - } - - var err error - db_path := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable", - config.User, config.Password, config.Path, config.Name) - - sqlDB, err = sql.Open("postgres", db_path) - if err != nil { - dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err) - return err - } - - if err := sqlDB.Ping(); err != nil { - dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err) - return err - } - - sqlDB.SetMaxOpenConns(config.MaxConn) - - stats := sqlDB.Stats() - dbModule.Logger().Info( - "Connected to SQL database", - "user", config.User, - "name", config.Name, - "path", config.Path, - "maxConnections", stats.MaxOpenConnections, - "currConnections", stats.OpenConnections, - ) - - if config.DebugLog { - initDBLogger() - } - - return nil -} - -// teardownDB closes the database connection. -func teardownDB(_ *app.Module) error { - if sqlDB == nil { - return nil - } - - if err := sqlDB.Close(); err != nil { - dbModule.Logger().Error("Couldn't close database", "err", err) - return err - } - dbModule.Logger().Info("Closed database connection") - return nil -} - -// MustQuery executes a query and returns an error if any. Filters errors for -// db.ErrNoRows (qrm.ErrNoRows) and returns nil or the error provided in that case. -func MustQuery(sqlo dbx.Queryable, stmt postgres.Statement, dest any, notFoundErr error) error { - err := stmt.Query(sqlo, dest) - if err != nil { - if errors.Is(err, qrm.ErrNoRows) { - return notFoundErr - } - return err - } - return nil -} - -// MustInsert requires at least one row to be affected by a query and returns the inserted ID -func MustInsert(sqlo dbx.Executable, stmt postgres.InsertStatement) (uint64, error) { - res, err := stmt.Exec(sqlo) - if err != nil { - return 0, err - } - - id, err := res.LastInsertId() - if err != nil { - return 0, err - } - - if id < 0 { - return 0, fmt.Errorf("MustInsert got invalid ID: %d", id) - } - - return uint64(id), nil -} - -// MustUpdate requires at least one row to be affected by a query -func MustUpdate(sqlo dbx.Executable, stmt postgres.Statement, notFoundErr error) error { - res, err := stmt.Exec(sqlo) - if err != nil { - return err - } - - rows, err := res.RowsAffected() - if err != nil { - return err - } - - if rows == 0 { - return notFoundErr - } - - return nil -} - -// MightUpdateMany expects multiple rows to be affected by a query and returns this number of rows -func MightUpdateMany(sqlo dbx.Executable, stmt postgres.Statement) (int, error) { - res, err := stmt.Exec(sqlo) - if err != nil { - return 0, err - } - - rows, err := res.RowsAffected() - return int(rows), err +func init() { + dbxshared.RegisterLogger(dbx.DialectPostgres, &_postgresLogger{}) } // SoftDelete sets the deleted_at column to the current time @@ -229,14 +68,15 @@ func ContainsCol(cols postgres.ColumnList, col postgres.Column) bool { // DestName returns the name of the type passed as `destTypeStruct` as a string, // normalized for compatibility with the Jet QRM. func DestName(destTypeStruct any, path ...string) string { - return dbxgeneric.DestName(dbModule, destTypeStruct, path...) + return dbxshared.DestName(destTypeStruct, path...) } -// initDBLogger initializes the database statement logger -func initDBLogger() { +type _postgresLogger struct{} + +func (_postgresLogger) InitLogger() { postgres.SetQueryLogger(func(ctx context.Context, queryInfo postgres.QueryInfo) { _, args := queryInfo.Statement.Sql() - dbModule.Logger().Debug( + dbx.ModuleDB().Logger().Debug( "Executed SQL query", "args", args, "duration", queryInfo.Duration, diff --git a/generic/generic.go b/internal/dbxshared/dbxshared.go similarity index 60% rename from generic/generic.go rename to internal/dbxshared/dbxshared.go index ba5c261..111d9a9 100644 --- a/generic/generic.go +++ b/internal/dbxshared/dbxshared.go @@ -1,15 +1,13 @@ -package dbxgeneric +package dbxshared import ( "reflect" "strings" - - "gitea.auvem.com/go-toolkit/app" ) // DestName returns the name of the type passed as `destTypeStruct` as a string, // normalized for compatibility with the Jet QRM. -func DestName(mod *app.Module, destTypeStruct any, path ...string) string { +func DestName(destTypeStruct any, path ...string) string { v := reflect.ValueOf(destTypeStruct) for v.Kind() == reflect.Ptr { v = v.Elem() @@ -20,14 +18,14 @@ func DestName(mod *app.Module, destTypeStruct any, path ...string) string { for i, p := range path { if v.Kind() != reflect.Struct { - mod.Logger().Error("DestName: path parent is not a struct", "path", destIdent+"."+strings.Join(path[:i+1], ".")) + DBModule.Logger().Error("DestName: path parent is not a struct", "path", destIdent+"."+strings.Join(path[:i+1], ".")) return "" } v = v.FieldByName(p) if !v.IsValid() { - mod.Logger().Error("DestName: field does not exist", "path", destIdent+"."+strings.Join(path[:i+1], ".")) + DBModule.Logger().Error("DestName: field does not exist", "path", destIdent+"."+strings.Join(path[:i+1], ".")) return "" } diff --git a/internal/dbxshared/logger.go b/internal/dbxshared/logger.go new file mode 100644 index 0000000..8c60e39 --- /dev/null +++ b/internal/dbxshared/logger.go @@ -0,0 +1,30 @@ +package dbxshared + +type Logger interface { + InitLogger() +} + +type dialectString interface { + String() string +} + +var loggerRegistry = make(map[string]Logger) + +// RegisterLogger allows dialect-specific loggers to be registered. +func RegisterLogger(dialect dialectString, logger Logger) { + dialectStr := dialect.String() + if _, exists := loggerRegistry[dialectStr]; exists { + panic("Logger for dialect already registered: " + dialectStr) + } + loggerRegistry[dialectStr] = logger +} + +// InitLogger initializes the logger for a specific dialect. +func InitLogger(dialect dialectString) { + dialectStr := dialect.String() + logger, exists := loggerRegistry[dialectStr] + if !exists { + panic("No logger registered for dialect: " + dialectStr) + } + logger.InitLogger() +} diff --git a/internal/dbxshared/module.go b/internal/dbxshared/module.go new file mode 100644 index 0000000..9956cf8 --- /dev/null +++ b/internal/dbxshared/module.go @@ -0,0 +1,5 @@ +package dbxshared + +import "gitea.auvem.com/go-toolkit/app" + +var DBModule *app.Module diff --git a/utility.go b/utility.go index 7f35e99..1c1ccec 100644 --- a/utility.go +++ b/utility.go @@ -1,9 +1,25 @@ package dbx import ( + "strings" + "github.com/go-jet/jet/v2/mysql" ) +// StringToFilter processes a string to be used as a filter in an SQL LIKE +// statement. It replaces all spaces with % and adds % to the beginning and +// end of the string. +func StringToFilter(str string) string { + // Remove any existing leading or trailing % characters + str = strings.Trim(str, "%") + + // Replace all spaces with % and add % to the beginning and end of the string + str = strings.ReplaceAll(str, " ", "%") + str = "%" + str + "%" + + return str +} + // ExprID converts a list of uint64 values to a list of mysql.Expression values func ExprID(ids []uint64) []mysql.Expression { expressions := make([]mysql.Expression, len(ids))