diff --git a/dbx.go b/dbx.go index 9b9e2e6..3ffb86b 100644 --- a/dbx.go +++ b/dbx.go @@ -2,40 +2,14 @@ package dbx import ( "database/sql" - "reflect" "strings" - "gitea.auvem.com/go-toolkit/app" "github.com/go-jet/jet/v2/qrm" - _ "github.com/go-sql-driver/mysql" ) // ModuleDBName is the name of the database module. const ModuleDBName = "database" -var ( - // ErrNoRows is returned when a query returns no rows. - ErrNoRows = qrm.ErrNoRows - - // sqlDB stores the current SQL database handle. - sqlDB *sql.DB - - // config stores the database connection configuration. - config DBConfig - - // dbModule is the singleton module instance for the database connection. - dbModule *app.Module -) - -// SQLO returns the current SQL database handle. -func SQLO() *sql.DB { - dbModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database - if sqlDB == nil { - panic("SQL database not initialized") - } - return sqlDB -} - // Queryable interface is an SQL driver object that can execute SQL statements // for Jet. type Queryable interface { @@ -57,36 +31,6 @@ type ExecutableTx interface { Begin() (*sql.Tx, error) } -// DestName returns the name of the type passed as `destTypeStruct` as a string, -// normalized for compatibility with the Jet QRM. -func DestName(destTypeStruct any, path ...string) string { - v := reflect.ValueOf(destTypeStruct) - for v.Kind() == reflect.Ptr { - v = v.Elem() - } - - destIdent := v.Type().String() - destIdent = destIdent[strings.LastIndex(destIdent, ".")+1:] - - for i, p := range path { - if v.Kind() != reflect.Struct { - dbModule.Logger().Error("DestName: path parent is not a struct", "path", destIdent+"."+strings.Join(path[:i+1], ".")) - return "" - } - - v = v.FieldByName(p) - - if !v.IsValid() { - dbModule.Logger().Error("DestName: field does not exist", "path", destIdent+"."+strings.Join(path[:i+1], ".")) - return "" - } - - destIdent += "." + p - } - - return destIdent -} - // StringToFilter processes a string to be used as a filter in an SQL LIKE // statement. It replaces all spaces with % and adds % to the beginning and // end of the string. diff --git a/generic/generic.go b/generic/generic.go new file mode 100644 index 0000000..ba5c261 --- /dev/null +++ b/generic/generic.go @@ -0,0 +1,38 @@ +package dbxgeneric + +import ( + "reflect" + "strings" + + "gitea.auvem.com/go-toolkit/app" +) + +// DestName returns the name of the type passed as `destTypeStruct` as a string, +// normalized for compatibility with the Jet QRM. +func DestName(mod *app.Module, destTypeStruct any, path ...string) string { + v := reflect.ValueOf(destTypeStruct) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + + destIdent := v.Type().String() + destIdent = destIdent[strings.LastIndex(destIdent, ".")+1:] + + for i, p := range path { + if v.Kind() != reflect.Struct { + mod.Logger().Error("DestName: path parent is not a struct", "path", destIdent+"."+strings.Join(path[:i+1], ".")) + return "" + } + + v = v.FieldByName(p) + + if !v.IsValid() { + mod.Logger().Error("DestName: field does not exist", "path", destIdent+"."+strings.Join(path[:i+1], ".")) + return "" + } + + destIdent += "." + p + } + + return destIdent +} diff --git a/mysql.go b/mysql/mysql.go similarity index 75% rename from mysql.go rename to mysql/mysql.go index ffec3f2..3eb53bb 100644 --- a/mysql.go +++ b/mysql/mysql.go @@ -1,4 +1,4 @@ -package dbx +package dbxm import ( "context" @@ -9,13 +9,38 @@ import ( "strings" "gitea.auvem.com/go-toolkit/app" + "gitea.auvem.com/go-toolkit/dbx" + dbxgeneric "gitea.auvem.com/go-toolkit/dbx/generic" "github.com/fatih/color" "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/qrm" ) +var ( + // ErrNoRows is returned when a query returns no rows. + ErrNoRows = qrm.ErrNoRows + + // sqlDB stores the current SQL database handle. + sqlDB *sql.DB + + // config stores the database connection configuration. + config dbx.DBConfig + + // dbModule is the singleton module instance for the database connection. + dbModule *app.Module +) + +// SQLO returns the current SQL database handle. +func SQLO() *sql.DB { + dbModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database + if sqlDB == nil { + panic("SQL database not initialized") + } + return sqlDB +} + // ModuleDB returns the database module with the provided configuration. -func ModuleDB(cfg DBConfig, forceDebugLog bool) *app.Module { +func ModuleDB(cfg dbx.DBConfig, forceDebugLog bool) *app.Module { if dbModule != nil { panic("ModuleDB initialized multiple times") } @@ -26,7 +51,7 @@ func ModuleDB(cfg DBConfig, forceDebugLog bool) *app.Module { config.DebugLog = true // force debug logging if requested } - dbModule = app.NewModule(ModuleDBName, app.ModuleOpts{ + dbModule = app.NewModule(dbx.ModuleDBName, app.ModuleOpts{ Setup: setupDB, Teardown: teardownDB, }) @@ -92,7 +117,7 @@ func teardownDB(_ *app.Module) error { // MustQuery executes a query and returns an error if any. Filters errors for // db.ErrNoRows (qrm.ErrNoRows) and returns nil or the error provided in that case. -func MustQuery(sqlo Queryable, stmt mysql.Statement, dest any, notFoundErr error) error { +func MustQuery(sqlo dbx.Queryable, stmt mysql.Statement, dest any, notFoundErr error) error { err := stmt.Query(sqlo, dest) if err != nil { if errors.Is(err, qrm.ErrNoRows) { @@ -104,7 +129,7 @@ func MustQuery(sqlo Queryable, stmt mysql.Statement, dest any, notFoundErr error } // MustInsert requires at least one row to be affected by a query and returns the inserted ID -func MustInsert(sqlo Executable, stmt mysql.InsertStatement) (uint64, error) { +func MustInsert(sqlo dbx.Executable, stmt mysql.InsertStatement) (uint64, error) { res, err := stmt.Exec(sqlo) if err != nil { return 0, err @@ -123,7 +148,7 @@ func MustInsert(sqlo Executable, stmt mysql.InsertStatement) (uint64, error) { } // MustUpdate requires at least one row to be affected by a query -func MustUpdate(sqlo Executable, stmt mysql.Statement, notFoundErr error) error { +func MustUpdate(sqlo dbx.Executable, stmt mysql.Statement, notFoundErr error) error { res, err := stmt.Exec(sqlo) if err != nil { return err @@ -142,7 +167,7 @@ func MustUpdate(sqlo Executable, stmt mysql.Statement, notFoundErr error) error } // MightUpdateMany expects multiple rows to be affected by a query and returns this number of rows -func MightUpdateMany(sqlo Executable, stmt mysql.Statement) (int, error) { +func MightUpdateMany(sqlo dbx.Executable, stmt mysql.Statement) (int, error) { res, err := stmt.Exec(sqlo) if err != nil { return 0, err @@ -153,7 +178,7 @@ func MightUpdateMany(sqlo Executable, stmt mysql.Statement) (int, error) { } // SoftDelete sets the deleted_at column to the current time -func SoftDelete(sqlo Executable, tbl mysql.Table, conds mysql.BoolExpression) (int64, error) { +func SoftDelete(sqlo dbx.Executable, tbl mysql.Table, conds mysql.BoolExpression) (int64, error) { stmt := tbl.UPDATE().WHERE(conds) query := stmt.DebugSql() @@ -201,6 +226,12 @@ func ContainsCol(cols mysql.ColumnList, col mysql.Column) bool { return slices.Contains(cols, col) } +// DestName returns the name of the type passed as `destTypeStruct` as a string, +// normalized for compatibility with the Jet QRM. +func DestName(destTypeStruct any, path ...string) string { + return dbxgeneric.DestName(dbModule, destTypeStruct, path...) +} + // initDBLogger initializes the database statement logger func initDBLogger() { mysql.SetQueryLogger(func(ctx context.Context, queryInfo mysql.QueryInfo) {