refactor with improved developer ux in mind
- Most functions now live directly in the dbx package - dbxp and dbxm are now ONLY the few functions that cannot be shared
This commit is contained in:
274
dbx.go
274
dbx.go
@@ -1,14 +1,32 @@
|
|||||||
package dbx
|
package dbx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"strings"
|
"errors"
|
||||||
|
|
||||||
|
"gitea.auvem.com/go-toolkit/app"
|
||||||
|
"gitea.auvem.com/go-toolkit/dbx/internal/dbxshared"
|
||||||
"github.com/go-jet/jet/v2/qrm"
|
"github.com/go-jet/jet/v2/qrm"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
// ModuleDBName is the name of the database module.
|
// ModuleDBName is the name of the database module.
|
||||||
const ModuleDBName = "database"
|
ModuleDBName = "database"
|
||||||
|
|
||||||
|
// DialectPostgres is the PostgreSQL dialect.
|
||||||
|
DialectPostgres Dialect = "postgres"
|
||||||
|
// DialectMySQL is the MySQL dialect.
|
||||||
|
DialectMySQL Dialect = "mysql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dialect is the SQL dialect used by the database connection.
|
||||||
|
type Dialect string
|
||||||
|
|
||||||
|
// String implements the Stringer interface for Dialect.
|
||||||
|
func (d Dialect) String() string {
|
||||||
|
return string(d)
|
||||||
|
}
|
||||||
|
|
||||||
// SQLOFunc is a function that returns a *sql.DB pointer.
|
// SQLOFunc is a function that returns a *sql.DB pointer.
|
||||||
type SQLOFunc = func() *sql.DB
|
type SQLOFunc = func() *sql.DB
|
||||||
@@ -34,16 +52,244 @@ type ExecutableTx interface {
|
|||||||
Begin() (*sql.Tx, error)
|
Begin() (*sql.Tx, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringToFilter processes a string to be used as a filter in an SQL LIKE
|
// Statement is a common Jet statement for all SQL operations.
|
||||||
// statement. It replaces all spaces with % and adds % to the beginning and
|
type Statement interface {
|
||||||
// end of the string.
|
Query(db qrm.Queryable, destination any) error
|
||||||
func StringToFilter(str string) string {
|
QueryContext(ctx context.Context, db qrm.Queryable, destination any) error
|
||||||
// Remove any existing leading or trailing % characters
|
Exec(db qrm.Executable) (sql.Result, error)
|
||||||
str = strings.Trim(str, "%")
|
ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error)
|
||||||
|
}
|
||||||
// Replace all spaces with % and add % to the beginning and end of the string
|
|
||||||
str = strings.ReplaceAll(str, " ", "%")
|
// SelectStatement is a Jet statement that can be executed to fetch rows from the database.
|
||||||
str = "%" + str + "%"
|
type SelectStatement interface {
|
||||||
|
Statement
|
||||||
return str
|
LIMIT(limit int64) SelectStatement
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNoRows is returned when a query returns no rows.
|
||||||
|
ErrNoRows = qrm.ErrNoRows
|
||||||
|
|
||||||
|
// sqlDB stores the current SQL database handle.
|
||||||
|
sqlDB *sql.DB
|
||||||
|
|
||||||
|
// config stores the database connection configuration.
|
||||||
|
config *DBConfig
|
||||||
|
|
||||||
|
// dialect is the SQL dialect used by the database connection.
|
||||||
|
dialect Dialect
|
||||||
|
|
||||||
|
// debugLog indicates whether debug logging is enabled.
|
||||||
|
debugLog bool
|
||||||
|
)
|
||||||
|
|
||||||
|
// SQLO returns the current SQL database handle.
|
||||||
|
func SQLO() *sql.DB {
|
||||||
|
dbxshared.DBModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database
|
||||||
|
if sqlDB == nil {
|
||||||
|
panic("SQL database not initialized")
|
||||||
|
}
|
||||||
|
return sqlDB
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModuleDB returns the existing database module, or panics if it has not
|
||||||
|
// been initialized yet.
|
||||||
|
func ModuleDB() *app.Module {
|
||||||
|
if dbxshared.DBModule == nil {
|
||||||
|
panic("ModuleDB not initialized yet")
|
||||||
|
}
|
||||||
|
return dbxshared.DBModule
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitModuleDB returns the database module with the provided configuration.
|
||||||
|
func InitModuleDB(dialect Dialect, cfg *DBConfig, forceDebugLog bool) *app.Module {
|
||||||
|
if dbxshared.DBModule != nil {
|
||||||
|
panic("ModuleDB initialized multiple times")
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
panic("ModuleDB requires a non-nil DBConfig")
|
||||||
|
}
|
||||||
|
|
||||||
|
config = cfg // store configuration at package level
|
||||||
|
debugLog = cfg.DebugLog || forceDebugLog // force debug logging if requested
|
||||||
|
|
||||||
|
dbxshared.DBModule = app.NewModule(ModuleDBName, app.ModuleOpts{
|
||||||
|
Setup: setupDB,
|
||||||
|
Teardown: teardownDB,
|
||||||
|
})
|
||||||
|
|
||||||
|
return dbxshared.DBModule
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch queries the database and returns the result as a slice. If the query
|
||||||
|
// returns no rows, it returns an empty slice and no error.
|
||||||
|
func Fetch[T any](sqlo Queryable, stmt SelectStatement) ([]*T, error) {
|
||||||
|
var result []*T
|
||||||
|
if err := stmt.Query(sqlo, &result); err != nil && !errors.Is(err, ErrNoRows) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFetch queries the database and returns the result as a slice. If the query
|
||||||
|
// returns no rows, it returns an empty slice and the desired error.
|
||||||
|
func MustFetch[T any](sqlo Queryable, stmt SelectStatement, notFoundErr error) ([]*T, error) {
|
||||||
|
result, err := Fetch[T](sqlo, stmt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil, notFoundErr // return the desired error if no rows found
|
||||||
|
}
|
||||||
|
return result, nil // return the fetched results
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchOne queries the database and returns a single result. If the query
|
||||||
|
// returns no rows, it returns nil and no error.
|
||||||
|
func FetchOne[T any](sqlo Queryable, stmt SelectStatement) (*T, error) {
|
||||||
|
result, err := Fetch[T](sqlo, stmt.LIMIT(1))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil, nil // no rows found, return nil
|
||||||
|
}
|
||||||
|
return result[0], nil // return the first (and only) result
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFetchOne queries the database and returns a single result. If the query
|
||||||
|
// returns no rows, it returns nil and the desired error.
|
||||||
|
func MustFetchOne[T any](sqlo Queryable, stmt SelectStatement, notFoundErr error) (*T, error) {
|
||||||
|
result, err := MustFetch[T](sqlo, stmt.LIMIT(1), notFoundErr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result[0], nil // return the first (and only) result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert executes an insert statement, returning the last inserted ID or an
|
||||||
|
// error if the insert fails.
|
||||||
|
func Insert(sqlo Executable, stmt Statement) (uint64, error) {
|
||||||
|
res, err := stmt.Exec(sqlo)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := res.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if id < 1 {
|
||||||
|
return 0, errors.New("inserted ID is less than 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
return uint64(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertReturning executes an insert statement that returns the inserted row.
|
||||||
|
// The statement MUST be a Jet InsertStatement with a RETURNING clause. Returns
|
||||||
|
// the inserted row object T or an error if the insert fails or no rows are returned.
|
||||||
|
func InsertReturning[T any](sqlo Queryable, stmt Statement) (*T, error) {
|
||||||
|
var result T
|
||||||
|
err := stmt.Query(sqlo, &result)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update executes an update statement, returning an error if the update fails.
|
||||||
|
func Update(sqlo Executable, stmt Statement) error {
|
||||||
|
_, err := stmt.Exec(sqlo)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAffected executes an update statement and returns the number of rows
|
||||||
|
// affected and an error if any.
|
||||||
|
func UpdateAffected(sqlo Executable, stmt Statement) (int64, error) {
|
||||||
|
res, err := stmt.Exec(sqlo)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsAffected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return rowsAffected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateReturning executes an update statement that returns the updated row.
|
||||||
|
// The statement MUST be a Jet UpdateStatement with a RETURNING clause. Returns
|
||||||
|
// the updated row object T or an error if the update fails or no rows are returned.
|
||||||
|
func UpdateReturning[T any](sqlo Queryable, stmt Statement) (*T, error) {
|
||||||
|
var result T
|
||||||
|
err := stmt.Query(sqlo, &result)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupDB connects to the database.
|
||||||
|
func setupDB(m *app.Module) error {
|
||||||
|
if sqlDB != nil && sqlDB.Ping() == nil {
|
||||||
|
m.Logger().Warn("Database connection already established")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logArgs := []any{
|
||||||
|
"user", config.User,
|
||||||
|
"name", config.Name,
|
||||||
|
"uri", config.URI,
|
||||||
|
"dialect", dialect,
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
sqlDB, err = sql.Open(string(dialect), config.ConnectionString(dialect))
|
||||||
|
if err != nil {
|
||||||
|
logArgs = append(logArgs, "err", err)
|
||||||
|
m.Logger().Error("Couldn't open SQL database", logArgs...)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sqlDB.Ping(); err != nil {
|
||||||
|
logArgs = append(logArgs, "err", err)
|
||||||
|
m.Logger().Error("Couldn't ping SQL database", logArgs...)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB.SetMaxOpenConns(config.MaxConn)
|
||||||
|
|
||||||
|
stats := sqlDB.Stats()
|
||||||
|
m.Logger().Info(
|
||||||
|
"Connected to SQL database",
|
||||||
|
"user", config.User,
|
||||||
|
"name", config.Name,
|
||||||
|
"uri", config.URI,
|
||||||
|
"maxConnections", stats.MaxOpenConnections,
|
||||||
|
"currConnections", stats.OpenConnections,
|
||||||
|
)
|
||||||
|
|
||||||
|
if debugLog {
|
||||||
|
dbxshared.InitLogger(dialect) // initialize the logger for the specified dialect
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// teardownDB closes the database connection.
|
||||||
|
func teardownDB(m *app.Module) error {
|
||||||
|
if sqlDB == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sqlDB.Close(); err != nil {
|
||||||
|
m.Logger().Error("Couldn't close database", "err", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.Logger().Info("Closed database connection")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
176
dbxm/mysql.go
176
dbxm/mysql.go
@@ -2,179 +2,18 @@ package dbxm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gitea.auvem.com/go-toolkit/app"
|
|
||||||
"gitea.auvem.com/go-toolkit/dbx"
|
"gitea.auvem.com/go-toolkit/dbx"
|
||||||
dbxgeneric "gitea.auvem.com/go-toolkit/dbx/generic"
|
"gitea.auvem.com/go-toolkit/dbx/internal/dbxshared"
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/go-jet/jet/v2/mysql"
|
"github.com/go-jet/jet/v2/mysql"
|
||||||
"github.com/go-jet/jet/v2/qrm"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
func init() {
|
||||||
// ErrNoRows is returned when a query returns no rows.
|
dbxshared.RegisterLogger(dbx.DialectMySQL, &_mysqlLogger{})
|
||||||
ErrNoRows = qrm.ErrNoRows
|
|
||||||
|
|
||||||
// sqlDB stores the current SQL database handle.
|
|
||||||
sqlDB *sql.DB
|
|
||||||
|
|
||||||
// config stores the database connection configuration.
|
|
||||||
config dbx.DBConfig
|
|
||||||
|
|
||||||
// dbModule is the singleton module instance for the database connection.
|
|
||||||
dbModule *app.Module
|
|
||||||
)
|
|
||||||
|
|
||||||
// SQLO returns the current SQL database handle.
|
|
||||||
func SQLO() *sql.DB {
|
|
||||||
dbModule.RequireLoaded("dbxm.SQLO requires database module") // ensure the module is loaded before accessing the database
|
|
||||||
if sqlDB == nil {
|
|
||||||
panic("SQL database not initialized")
|
|
||||||
}
|
|
||||||
return sqlDB
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModuleDB returns the database module with the provided configuration.
|
|
||||||
func ModuleDB(cfg dbx.DBConfig, forceDebugLog bool) *app.Module {
|
|
||||||
if dbModule != nil {
|
|
||||||
panic("ModuleDB initialized multiple times")
|
|
||||||
}
|
|
||||||
|
|
||||||
config = cfg // store configuration at package level
|
|
||||||
|
|
||||||
if forceDebugLog {
|
|
||||||
config.DebugLog = true // force debug logging if requested
|
|
||||||
}
|
|
||||||
|
|
||||||
dbModule = app.NewModule(dbx.ModuleDBName, app.ModuleOpts{
|
|
||||||
Setup: setupDB,
|
|
||||||
Teardown: teardownDB,
|
|
||||||
})
|
|
||||||
|
|
||||||
return dbModule
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupDB connects to the MySQL DB and initializes the goose migration provider.
|
|
||||||
// If auto migrations are enabled in configuration, latest migrations are applied.
|
|
||||||
func setupDB(_ *app.Module) error {
|
|
||||||
if sqlDB != nil && sqlDB.Ping() == nil {
|
|
||||||
dbModule.Logger().Warn("Database connection already established")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
db_path := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
|
||||||
config.User, config.Password, config.Path, config.Name)
|
|
||||||
|
|
||||||
sqlDB, err = sql.Open("mysql", db_path)
|
|
||||||
if err != nil {
|
|
||||||
dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := sqlDB.Ping(); err != nil {
|
|
||||||
dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlDB.SetMaxOpenConns(config.MaxConn)
|
|
||||||
|
|
||||||
stats := sqlDB.Stats()
|
|
||||||
dbModule.Logger().Info(
|
|
||||||
"Connected to SQL database",
|
|
||||||
"user", config.User,
|
|
||||||
"name", config.Name,
|
|
||||||
"path", config.Path,
|
|
||||||
"maxConnections", stats.MaxOpenConnections,
|
|
||||||
"currConnections", stats.OpenConnections,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.DebugLog {
|
|
||||||
initDBLogger()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// teardownDB closes the database connection.
|
|
||||||
func teardownDB(_ *app.Module) error {
|
|
||||||
if sqlDB == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := sqlDB.Close(); err != nil {
|
|
||||||
dbModule.Logger().Error("Couldn't close database", "err", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
dbModule.Logger().Info("Closed database connection")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustQuery executes a query and returns an error if any. Filters errors for
|
|
||||||
// db.ErrNoRows (qrm.ErrNoRows) and returns nil or the error provided in that case.
|
|
||||||
func MustQuery(sqlo dbx.Queryable, stmt mysql.Statement, dest any, notFoundErr error) error {
|
|
||||||
err := stmt.Query(sqlo, dest)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, qrm.ErrNoRows) {
|
|
||||||
return notFoundErr
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustInsert requires at least one row to be affected by a query and returns the inserted ID
|
|
||||||
func MustInsert(sqlo dbx.Executable, stmt mysql.InsertStatement) (uint64, error) {
|
|
||||||
res, err := stmt.Exec(sqlo)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := res.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if id < 0 {
|
|
||||||
return 0, fmt.Errorf("MustInsert got invalid ID: %d", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
return uint64(id), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustUpdate requires at least one row to be affected by a query
|
|
||||||
func MustUpdate(sqlo dbx.Executable, stmt mysql.Statement, notFoundErr error) error {
|
|
||||||
res, err := stmt.Exec(sqlo)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := res.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows == 0 {
|
|
||||||
return notFoundErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MightUpdateMany expects multiple rows to be affected by a query and returns this number of rows
|
|
||||||
func MightUpdateMany(sqlo dbx.Executable, stmt mysql.Statement) (int, error) {
|
|
||||||
res, err := stmt.Exec(sqlo)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := res.RowsAffected()
|
|
||||||
return int(rows), err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoftDelete sets the deleted_at column to the current time
|
// SoftDelete sets the deleted_at column to the current time
|
||||||
@@ -229,14 +68,15 @@ func ContainsCol(cols mysql.ColumnList, col mysql.Column) bool {
|
|||||||
// DestName returns the name of the type passed as `destTypeStruct` as a string,
|
// DestName returns the name of the type passed as `destTypeStruct` as a string,
|
||||||
// normalized for compatibility with the Jet QRM.
|
// normalized for compatibility with the Jet QRM.
|
||||||
func DestName(destTypeStruct any, path ...string) string {
|
func DestName(destTypeStruct any, path ...string) string {
|
||||||
return dbxgeneric.DestName(dbModule, destTypeStruct, path...)
|
return dbxshared.DestName(destTypeStruct, path...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// initDBLogger initializes the database statement logger
|
type _mysqlLogger struct{}
|
||||||
func initDBLogger() {
|
|
||||||
|
func (_mysqlLogger) InitLogger() {
|
||||||
mysql.SetQueryLogger(func(ctx context.Context, queryInfo mysql.QueryInfo) {
|
mysql.SetQueryLogger(func(ctx context.Context, queryInfo mysql.QueryInfo) {
|
||||||
_, args := queryInfo.Statement.Sql()
|
_, args := queryInfo.Statement.Sql()
|
||||||
dbModule.Logger().Debug(
|
dbx.ModuleDB().Logger().Debug(
|
||||||
"Executed SQL query",
|
"Executed SQL query",
|
||||||
"args", args,
|
"args", args,
|
||||||
"duration", queryInfo.Duration,
|
"duration", queryInfo.Duration,
|
||||||
|
|||||||
176
dbxp/pxg.go
176
dbxp/pxg.go
@@ -2,179 +2,18 @@ package dbxp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gitea.auvem.com/go-toolkit/app"
|
|
||||||
"gitea.auvem.com/go-toolkit/dbx"
|
"gitea.auvem.com/go-toolkit/dbx"
|
||||||
dbxgeneric "gitea.auvem.com/go-toolkit/dbx/generic"
|
"gitea.auvem.com/go-toolkit/dbx/internal/dbxshared"
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/go-jet/jet/v2/postgres"
|
"github.com/go-jet/jet/v2/postgres"
|
||||||
"github.com/go-jet/jet/v2/qrm"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
func init() {
|
||||||
// ErrNoRows is returned when a query returns no rows.
|
dbxshared.RegisterLogger(dbx.DialectPostgres, &_postgresLogger{})
|
||||||
ErrNoRows = qrm.ErrNoRows
|
|
||||||
|
|
||||||
// sqlDB stores the current SQL database handle.
|
|
||||||
sqlDB *sql.DB
|
|
||||||
|
|
||||||
// config stores the database connection configuration.
|
|
||||||
config dbx.DBConfig
|
|
||||||
|
|
||||||
// dbModule is the singleton module instance for the database connection.
|
|
||||||
dbModule *app.Module
|
|
||||||
)
|
|
||||||
|
|
||||||
// SQLO returns the current SQL database handle.
|
|
||||||
func SQLO() *sql.DB {
|
|
||||||
dbModule.RequireLoaded("dbxp.SQLO requires database module") // ensure the module is loaded before accessing the database
|
|
||||||
if sqlDB == nil {
|
|
||||||
panic("SQL database not initialized")
|
|
||||||
}
|
|
||||||
return sqlDB
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModuleDB returns the database module with the provided configuration.
|
|
||||||
func ModuleDB(cfg dbx.DBConfig, forceDebugLog bool) *app.Module {
|
|
||||||
if dbModule != nil {
|
|
||||||
panic("ModuleDB initialized multiple times")
|
|
||||||
}
|
|
||||||
|
|
||||||
config = cfg // store configuration at package level
|
|
||||||
|
|
||||||
if forceDebugLog {
|
|
||||||
config.DebugLog = true // force debug logging if requested
|
|
||||||
}
|
|
||||||
|
|
||||||
dbModule = app.NewModule(dbx.ModuleDBName, app.ModuleOpts{
|
|
||||||
Setup: setupDB,
|
|
||||||
Teardown: teardownDB,
|
|
||||||
})
|
|
||||||
|
|
||||||
return dbModule
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupDB connects to the PostgreSQL DB and initializes the goose migration provider.
|
|
||||||
// If auto migrations are enabled in configuration, latest migrations are applied.
|
|
||||||
func setupDB(_ *app.Module) error {
|
|
||||||
if sqlDB != nil && sqlDB.Ping() == nil {
|
|
||||||
dbModule.Logger().Warn("Database connection already established")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
db_path := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable",
|
|
||||||
config.User, config.Password, config.Path, config.Name)
|
|
||||||
|
|
||||||
sqlDB, err = sql.Open("postgres", db_path)
|
|
||||||
if err != nil {
|
|
||||||
dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := sqlDB.Ping(); err != nil {
|
|
||||||
dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlDB.SetMaxOpenConns(config.MaxConn)
|
|
||||||
|
|
||||||
stats := sqlDB.Stats()
|
|
||||||
dbModule.Logger().Info(
|
|
||||||
"Connected to SQL database",
|
|
||||||
"user", config.User,
|
|
||||||
"name", config.Name,
|
|
||||||
"path", config.Path,
|
|
||||||
"maxConnections", stats.MaxOpenConnections,
|
|
||||||
"currConnections", stats.OpenConnections,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.DebugLog {
|
|
||||||
initDBLogger()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// teardownDB closes the database connection.
|
|
||||||
func teardownDB(_ *app.Module) error {
|
|
||||||
if sqlDB == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := sqlDB.Close(); err != nil {
|
|
||||||
dbModule.Logger().Error("Couldn't close database", "err", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
dbModule.Logger().Info("Closed database connection")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustQuery executes a query and returns an error if any. Filters errors for
|
|
||||||
// db.ErrNoRows (qrm.ErrNoRows) and returns nil or the error provided in that case.
|
|
||||||
func MustQuery(sqlo dbx.Queryable, stmt postgres.Statement, dest any, notFoundErr error) error {
|
|
||||||
err := stmt.Query(sqlo, dest)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, qrm.ErrNoRows) {
|
|
||||||
return notFoundErr
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustInsert requires at least one row to be affected by a query and returns the inserted ID
|
|
||||||
func MustInsert(sqlo dbx.Executable, stmt postgres.InsertStatement) (uint64, error) {
|
|
||||||
res, err := stmt.Exec(sqlo)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := res.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if id < 0 {
|
|
||||||
return 0, fmt.Errorf("MustInsert got invalid ID: %d", id)
|
|
||||||
}
|
|
||||||
|
|
||||||
return uint64(id), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustUpdate requires at least one row to be affected by a query
|
|
||||||
func MustUpdate(sqlo dbx.Executable, stmt postgres.Statement, notFoundErr error) error {
|
|
||||||
res, err := stmt.Exec(sqlo)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := res.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows == 0 {
|
|
||||||
return notFoundErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MightUpdateMany expects multiple rows to be affected by a query and returns this number of rows
|
|
||||||
func MightUpdateMany(sqlo dbx.Executable, stmt postgres.Statement) (int, error) {
|
|
||||||
res, err := stmt.Exec(sqlo)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := res.RowsAffected()
|
|
||||||
return int(rows), err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SoftDelete sets the deleted_at column to the current time
|
// SoftDelete sets the deleted_at column to the current time
|
||||||
@@ -229,14 +68,15 @@ func ContainsCol(cols postgres.ColumnList, col postgres.Column) bool {
|
|||||||
// DestName returns the name of the type passed as `destTypeStruct` as a string,
|
// DestName returns the name of the type passed as `destTypeStruct` as a string,
|
||||||
// normalized for compatibility with the Jet QRM.
|
// normalized for compatibility with the Jet QRM.
|
||||||
func DestName(destTypeStruct any, path ...string) string {
|
func DestName(destTypeStruct any, path ...string) string {
|
||||||
return dbxgeneric.DestName(dbModule, destTypeStruct, path...)
|
return dbxshared.DestName(destTypeStruct, path...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// initDBLogger initializes the database statement logger
|
type _postgresLogger struct{}
|
||||||
func initDBLogger() {
|
|
||||||
|
func (_postgresLogger) InitLogger() {
|
||||||
postgres.SetQueryLogger(func(ctx context.Context, queryInfo postgres.QueryInfo) {
|
postgres.SetQueryLogger(func(ctx context.Context, queryInfo postgres.QueryInfo) {
|
||||||
_, args := queryInfo.Statement.Sql()
|
_, args := queryInfo.Statement.Sql()
|
||||||
dbModule.Logger().Debug(
|
dbx.ModuleDB().Logger().Debug(
|
||||||
"Executed SQL query",
|
"Executed SQL query",
|
||||||
"args", args,
|
"args", args,
|
||||||
"duration", queryInfo.Duration,
|
"duration", queryInfo.Duration,
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
package dbxgeneric
|
package dbxshared
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gitea.auvem.com/go-toolkit/app"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// DestName returns the name of the type passed as `destTypeStruct` as a string,
|
// DestName returns the name of the type passed as `destTypeStruct` as a string,
|
||||||
// normalized for compatibility with the Jet QRM.
|
// normalized for compatibility with the Jet QRM.
|
||||||
func DestName(mod *app.Module, destTypeStruct any, path ...string) string {
|
func DestName(destTypeStruct any, path ...string) string {
|
||||||
v := reflect.ValueOf(destTypeStruct)
|
v := reflect.ValueOf(destTypeStruct)
|
||||||
for v.Kind() == reflect.Ptr {
|
for v.Kind() == reflect.Ptr {
|
||||||
v = v.Elem()
|
v = v.Elem()
|
||||||
@@ -20,14 +18,14 @@ func DestName(mod *app.Module, destTypeStruct any, path ...string) string {
|
|||||||
|
|
||||||
for i, p := range path {
|
for i, p := range path {
|
||||||
if v.Kind() != reflect.Struct {
|
if v.Kind() != reflect.Struct {
|
||||||
mod.Logger().Error("DestName: path parent is not a struct", "path", destIdent+"."+strings.Join(path[:i+1], "."))
|
DBModule.Logger().Error("DestName: path parent is not a struct", "path", destIdent+"."+strings.Join(path[:i+1], "."))
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
v = v.FieldByName(p)
|
v = v.FieldByName(p)
|
||||||
|
|
||||||
if !v.IsValid() {
|
if !v.IsValid() {
|
||||||
mod.Logger().Error("DestName: field does not exist", "path", destIdent+"."+strings.Join(path[:i+1], "."))
|
DBModule.Logger().Error("DestName: field does not exist", "path", destIdent+"."+strings.Join(path[:i+1], "."))
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
30
internal/dbxshared/logger.go
Normal file
30
internal/dbxshared/logger.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package dbxshared
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
InitLogger()
|
||||||
|
}
|
||||||
|
|
||||||
|
type dialectString interface {
|
||||||
|
String() string
|
||||||
|
}
|
||||||
|
|
||||||
|
var loggerRegistry = make(map[string]Logger)
|
||||||
|
|
||||||
|
// RegisterLogger allows dialect-specific loggers to be registered.
|
||||||
|
func RegisterLogger(dialect dialectString, logger Logger) {
|
||||||
|
dialectStr := dialect.String()
|
||||||
|
if _, exists := loggerRegistry[dialectStr]; exists {
|
||||||
|
panic("Logger for dialect already registered: " + dialectStr)
|
||||||
|
}
|
||||||
|
loggerRegistry[dialectStr] = logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitLogger initializes the logger for a specific dialect.
|
||||||
|
func InitLogger(dialect dialectString) {
|
||||||
|
dialectStr := dialect.String()
|
||||||
|
logger, exists := loggerRegistry[dialectStr]
|
||||||
|
if !exists {
|
||||||
|
panic("No logger registered for dialect: " + dialectStr)
|
||||||
|
}
|
||||||
|
logger.InitLogger()
|
||||||
|
}
|
||||||
5
internal/dbxshared/module.go
Normal file
5
internal/dbxshared/module.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package dbxshared
|
||||||
|
|
||||||
|
import "gitea.auvem.com/go-toolkit/app"
|
||||||
|
|
||||||
|
var DBModule *app.Module
|
||||||
16
utility.go
16
utility.go
@@ -1,9 +1,25 @@
|
|||||||
package dbx
|
package dbx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/go-jet/jet/v2/mysql"
|
"github.com/go-jet/jet/v2/mysql"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StringToFilter processes a string to be used as a filter in an SQL LIKE
|
||||||
|
// statement. It replaces all spaces with % and adds % to the beginning and
|
||||||
|
// end of the string.
|
||||||
|
func StringToFilter(str string) string {
|
||||||
|
// Remove any existing leading or trailing % characters
|
||||||
|
str = strings.Trim(str, "%")
|
||||||
|
|
||||||
|
// Replace all spaces with % and add % to the beginning and end of the string
|
||||||
|
str = strings.ReplaceAll(str, " ", "%")
|
||||||
|
str = "%" + str + "%"
|
||||||
|
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
// ExprID converts a list of uint64 values to a list of mysql.Expression values
|
// ExprID converts a list of uint64 values to a list of mysql.Expression values
|
||||||
func ExprID(ids []uint64) []mysql.Expression {
|
func ExprID(ids []uint64) []mysql.Expression {
|
||||||
expressions := make([]mysql.Expression, len(ids))
|
expressions := make([]mysql.Expression, len(ids))
|
||||||
|
|||||||
Reference in New Issue
Block a user