split mysql specific code into separate package

This commit is contained in:
Elijah Duffy
2025-06-04 18:17:39 -07:00
parent 757483a574
commit d773164227
3 changed files with 77 additions and 64 deletions

56
dbx.go
View File

@@ -2,40 +2,14 @@ package dbx
import ( import (
"database/sql" "database/sql"
"reflect"
"strings" "strings"
"gitea.auvem.com/go-toolkit/app"
"github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/qrm"
_ "github.com/go-sql-driver/mysql"
) )
// ModuleDBName is the name of the database module. // ModuleDBName is the name of the database module.
const ModuleDBName = "database" const ModuleDBName = "database"
var (
// ErrNoRows is returned when a query returns no rows.
ErrNoRows = qrm.ErrNoRows
// sqlDB stores the current SQL database handle.
sqlDB *sql.DB
// config stores the database connection configuration.
config DBConfig
// dbModule is the singleton module instance for the database connection.
dbModule *app.Module
)
// SQLO returns the current SQL database handle.
func SQLO() *sql.DB {
dbModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database
if sqlDB == nil {
panic("SQL database not initialized")
}
return sqlDB
}
// Queryable interface is an SQL driver object that can execute SQL statements // Queryable interface is an SQL driver object that can execute SQL statements
// for Jet. // for Jet.
type Queryable interface { type Queryable interface {
@@ -57,36 +31,6 @@ type ExecutableTx interface {
Begin() (*sql.Tx, error) Begin() (*sql.Tx, error)
} }
// DestName returns the name of the type passed as `destTypeStruct` as a string,
// normalized for compatibility with the Jet QRM.
func DestName(destTypeStruct any, path ...string) string {
v := reflect.ValueOf(destTypeStruct)
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
destIdent := v.Type().String()
destIdent = destIdent[strings.LastIndex(destIdent, ".")+1:]
for i, p := range path {
if v.Kind() != reflect.Struct {
dbModule.Logger().Error("DestName: path parent is not a struct", "path", destIdent+"."+strings.Join(path[:i+1], "."))
return ""
}
v = v.FieldByName(p)
if !v.IsValid() {
dbModule.Logger().Error("DestName: field does not exist", "path", destIdent+"."+strings.Join(path[:i+1], "."))
return ""
}
destIdent += "." + p
}
return destIdent
}
// StringToFilter processes a string to be used as a filter in an SQL LIKE // StringToFilter processes a string to be used as a filter in an SQL LIKE
// statement. It replaces all spaces with % and adds % to the beginning and // statement. It replaces all spaces with % and adds % to the beginning and
// end of the string. // end of the string.

38
generic/generic.go Normal file
View File

@@ -0,0 +1,38 @@
package dbxgeneric
import (
"reflect"
"strings"
"gitea.auvem.com/go-toolkit/app"
)
// DestName returns the name of the type passed as `destTypeStruct` as a string,
// normalized for compatibility with the Jet QRM.
func DestName(mod *app.Module, destTypeStruct any, path ...string) string {
v := reflect.ValueOf(destTypeStruct)
for v.Kind() == reflect.Ptr {
v = v.Elem()
}
destIdent := v.Type().String()
destIdent = destIdent[strings.LastIndex(destIdent, ".")+1:]
for i, p := range path {
if v.Kind() != reflect.Struct {
mod.Logger().Error("DestName: path parent is not a struct", "path", destIdent+"."+strings.Join(path[:i+1], "."))
return ""
}
v = v.FieldByName(p)
if !v.IsValid() {
mod.Logger().Error("DestName: field does not exist", "path", destIdent+"."+strings.Join(path[:i+1], "."))
return ""
}
destIdent += "." + p
}
return destIdent
}

View File

@@ -1,4 +1,4 @@
package dbx package dbxm
import ( import (
"context" "context"
@@ -9,13 +9,38 @@ import (
"strings" "strings"
"gitea.auvem.com/go-toolkit/app" "gitea.auvem.com/go-toolkit/app"
"gitea.auvem.com/go-toolkit/dbx"
dbxgeneric "gitea.auvem.com/go-toolkit/dbx/generic"
"github.com/fatih/color" "github.com/fatih/color"
"github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/qrm"
) )
var (
// ErrNoRows is returned when a query returns no rows.
ErrNoRows = qrm.ErrNoRows
// sqlDB stores the current SQL database handle.
sqlDB *sql.DB
// config stores the database connection configuration.
config dbx.DBConfig
// dbModule is the singleton module instance for the database connection.
dbModule *app.Module
)
// SQLO returns the current SQL database handle.
func SQLO() *sql.DB {
dbModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database
if sqlDB == nil {
panic("SQL database not initialized")
}
return sqlDB
}
// ModuleDB returns the database module with the provided configuration. // ModuleDB returns the database module with the provided configuration.
func ModuleDB(cfg DBConfig, forceDebugLog bool) *app.Module { func ModuleDB(cfg dbx.DBConfig, forceDebugLog bool) *app.Module {
if dbModule != nil { if dbModule != nil {
panic("ModuleDB initialized multiple times") panic("ModuleDB initialized multiple times")
} }
@@ -26,7 +51,7 @@ func ModuleDB(cfg DBConfig, forceDebugLog bool) *app.Module {
config.DebugLog = true // force debug logging if requested config.DebugLog = true // force debug logging if requested
} }
dbModule = app.NewModule(ModuleDBName, app.ModuleOpts{ dbModule = app.NewModule(dbx.ModuleDBName, app.ModuleOpts{
Setup: setupDB, Setup: setupDB,
Teardown: teardownDB, Teardown: teardownDB,
}) })
@@ -92,7 +117,7 @@ func teardownDB(_ *app.Module) error {
// MustQuery executes a query and returns an error if any. Filters errors for // MustQuery executes a query and returns an error if any. Filters errors for
// db.ErrNoRows (qrm.ErrNoRows) and returns nil or the error provided in that case. // db.ErrNoRows (qrm.ErrNoRows) and returns nil or the error provided in that case.
func MustQuery(sqlo Queryable, stmt mysql.Statement, dest any, notFoundErr error) error { func MustQuery(sqlo dbx.Queryable, stmt mysql.Statement, dest any, notFoundErr error) error {
err := stmt.Query(sqlo, dest) err := stmt.Query(sqlo, dest)
if err != nil { if err != nil {
if errors.Is(err, qrm.ErrNoRows) { if errors.Is(err, qrm.ErrNoRows) {
@@ -104,7 +129,7 @@ func MustQuery(sqlo Queryable, stmt mysql.Statement, dest any, notFoundErr error
} }
// MustInsert requires at least one row to be affected by a query and returns the inserted ID // MustInsert requires at least one row to be affected by a query and returns the inserted ID
func MustInsert(sqlo Executable, stmt mysql.InsertStatement) (uint64, error) { func MustInsert(sqlo dbx.Executable, stmt mysql.InsertStatement) (uint64, error) {
res, err := stmt.Exec(sqlo) res, err := stmt.Exec(sqlo)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -123,7 +148,7 @@ func MustInsert(sqlo Executable, stmt mysql.InsertStatement) (uint64, error) {
} }
// MustUpdate requires at least one row to be affected by a query // MustUpdate requires at least one row to be affected by a query
func MustUpdate(sqlo Executable, stmt mysql.Statement, notFoundErr error) error { func MustUpdate(sqlo dbx.Executable, stmt mysql.Statement, notFoundErr error) error {
res, err := stmt.Exec(sqlo) res, err := stmt.Exec(sqlo)
if err != nil { if err != nil {
return err return err
@@ -142,7 +167,7 @@ func MustUpdate(sqlo Executable, stmt mysql.Statement, notFoundErr error) error
} }
// MightUpdateMany expects multiple rows to be affected by a query and returns this number of rows // MightUpdateMany expects multiple rows to be affected by a query and returns this number of rows
func MightUpdateMany(sqlo Executable, stmt mysql.Statement) (int, error) { func MightUpdateMany(sqlo dbx.Executable, stmt mysql.Statement) (int, error) {
res, err := stmt.Exec(sqlo) res, err := stmt.Exec(sqlo)
if err != nil { if err != nil {
return 0, err return 0, err
@@ -153,7 +178,7 @@ func MightUpdateMany(sqlo Executable, stmt mysql.Statement) (int, error) {
} }
// SoftDelete sets the deleted_at column to the current time // SoftDelete sets the deleted_at column to the current time
func SoftDelete(sqlo Executable, tbl mysql.Table, conds mysql.BoolExpression) (int64, error) { func SoftDelete(sqlo dbx.Executable, tbl mysql.Table, conds mysql.BoolExpression) (int64, error) {
stmt := tbl.UPDATE().WHERE(conds) stmt := tbl.UPDATE().WHERE(conds)
query := stmt.DebugSql() query := stmt.DebugSql()
@@ -201,6 +226,12 @@ func ContainsCol(cols mysql.ColumnList, col mysql.Column) bool {
return slices.Contains(cols, col) return slices.Contains(cols, col)
} }
// DestName returns the name of the type passed as `destTypeStruct` as a string,
// normalized for compatibility with the Jet QRM.
func DestName(destTypeStruct any, path ...string) string {
return dbxgeneric.DestName(dbModule, destTypeStruct, path...)
}
// initDBLogger initializes the database statement logger // initDBLogger initializes the database statement logger
func initDBLogger() { func initDBLogger() {
mysql.SetQueryLogger(func(ctx context.Context, queryInfo mysql.QueryInfo) { mysql.SetQueryLogger(func(ctx context.Context, queryInfo mysql.QueryInfo) {