refactor with improved developer ux in mind
- Most functions now live directly in the dbx package - dbxp and dbxm are now ONLY the few functions that cannot be shared
This commit is contained in:
276
dbx.go
276
dbx.go
@@ -1,14 +1,32 @@
|
||||
package dbx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"errors"
|
||||
|
||||
"gitea.auvem.com/go-toolkit/app"
|
||||
"gitea.auvem.com/go-toolkit/dbx/internal/dbxshared"
|
||||
"github.com/go-jet/jet/v2/qrm"
|
||||
)
|
||||
|
||||
// ModuleDBName is the name of the database module.
|
||||
const ModuleDBName = "database"
|
||||
const (
|
||||
// ModuleDBName is the name of the database module.
|
||||
ModuleDBName = "database"
|
||||
|
||||
// DialectPostgres is the PostgreSQL dialect.
|
||||
DialectPostgres Dialect = "postgres"
|
||||
// DialectMySQL is the MySQL dialect.
|
||||
DialectMySQL Dialect = "mysql"
|
||||
)
|
||||
|
||||
// Dialect is the SQL dialect used by the database connection.
|
||||
type Dialect string
|
||||
|
||||
// String implements the Stringer interface for Dialect.
|
||||
func (d Dialect) String() string {
|
||||
return string(d)
|
||||
}
|
||||
|
||||
// SQLOFunc is a function that returns a *sql.DB pointer.
|
||||
type SQLOFunc = func() *sql.DB
|
||||
@@ -34,16 +52,244 @@ type ExecutableTx interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
}
|
||||
|
||||
// StringToFilter processes a string to be used as a filter in an SQL LIKE
|
||||
// statement. It replaces all spaces with % and adds % to the beginning and
|
||||
// end of the string.
|
||||
func StringToFilter(str string) string {
|
||||
// Remove any existing leading or trailing % characters
|
||||
str = strings.Trim(str, "%")
|
||||
|
||||
// Replace all spaces with % and add % to the beginning and end of the string
|
||||
str = strings.ReplaceAll(str, " ", "%")
|
||||
str = "%" + str + "%"
|
||||
|
||||
return str
|
||||
// Statement is a common Jet statement for all SQL operations.
|
||||
type Statement interface {
|
||||
Query(db qrm.Queryable, destination any) error
|
||||
QueryContext(ctx context.Context, db qrm.Queryable, destination any) error
|
||||
Exec(db qrm.Executable) (sql.Result, error)
|
||||
ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error)
|
||||
}
|
||||
|
||||
// SelectStatement is a Jet statement that can be executed to fetch rows from the database.
|
||||
type SelectStatement interface {
|
||||
Statement
|
||||
LIMIT(limit int64) SelectStatement
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrNoRows is returned when a query returns no rows.
|
||||
ErrNoRows = qrm.ErrNoRows
|
||||
|
||||
// sqlDB stores the current SQL database handle.
|
||||
sqlDB *sql.DB
|
||||
|
||||
// config stores the database connection configuration.
|
||||
config *DBConfig
|
||||
|
||||
// dialect is the SQL dialect used by the database connection.
|
||||
dialect Dialect
|
||||
|
||||
// debugLog indicates whether debug logging is enabled.
|
||||
debugLog bool
|
||||
)
|
||||
|
||||
// SQLO returns the current SQL database handle.
|
||||
func SQLO() *sql.DB {
|
||||
dbxshared.DBModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database
|
||||
if sqlDB == nil {
|
||||
panic("SQL database not initialized")
|
||||
}
|
||||
return sqlDB
|
||||
}
|
||||
|
||||
// ModuleDB returns the existing database module, or panics if it has not
|
||||
// been initialized yet.
|
||||
func ModuleDB() *app.Module {
|
||||
if dbxshared.DBModule == nil {
|
||||
panic("ModuleDB not initialized yet")
|
||||
}
|
||||
return dbxshared.DBModule
|
||||
}
|
||||
|
||||
// InitModuleDB returns the database module with the provided configuration.
|
||||
func InitModuleDB(dialect Dialect, cfg *DBConfig, forceDebugLog bool) *app.Module {
|
||||
if dbxshared.DBModule != nil {
|
||||
panic("ModuleDB initialized multiple times")
|
||||
}
|
||||
if cfg == nil {
|
||||
panic("ModuleDB requires a non-nil DBConfig")
|
||||
}
|
||||
|
||||
config = cfg // store configuration at package level
|
||||
debugLog = cfg.DebugLog || forceDebugLog // force debug logging if requested
|
||||
|
||||
dbxshared.DBModule = app.NewModule(ModuleDBName, app.ModuleOpts{
|
||||
Setup: setupDB,
|
||||
Teardown: teardownDB,
|
||||
})
|
||||
|
||||
return dbxshared.DBModule
|
||||
}
|
||||
|
||||
// Fetch queries the database and returns the result as a slice. If the query
|
||||
// returns no rows, it returns an empty slice and no error.
|
||||
func Fetch[T any](sqlo Queryable, stmt SelectStatement) ([]*T, error) {
|
||||
var result []*T
|
||||
if err := stmt.Query(sqlo, &result); err != nil && !errors.Is(err, ErrNoRows) {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MustFetch queries the database and returns the result as a slice. If the query
|
||||
// returns no rows, it returns an empty slice and the desired error.
|
||||
func MustFetch[T any](sqlo Queryable, stmt SelectStatement, notFoundErr error) ([]*T, error) {
|
||||
result, err := Fetch[T](sqlo, stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, notFoundErr // return the desired error if no rows found
|
||||
}
|
||||
return result, nil // return the fetched results
|
||||
}
|
||||
|
||||
// FetchOne queries the database and returns a single result. If the query
|
||||
// returns no rows, it returns nil and no error.
|
||||
func FetchOne[T any](sqlo Queryable, stmt SelectStatement) (*T, error) {
|
||||
result, err := Fetch[T](sqlo, stmt.LIMIT(1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, nil // no rows found, return nil
|
||||
}
|
||||
return result[0], nil // return the first (and only) result
|
||||
}
|
||||
|
||||
// MustFetchOne queries the database and returns a single result. If the query
|
||||
// returns no rows, it returns nil and the desired error.
|
||||
func MustFetchOne[T any](sqlo Queryable, stmt SelectStatement, notFoundErr error) (*T, error) {
|
||||
result, err := MustFetch[T](sqlo, stmt.LIMIT(1), notFoundErr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result[0], nil // return the first (and only) result
|
||||
}
|
||||
|
||||
// Insert executes an insert statement, returning the last inserted ID or an
|
||||
// error if the insert fails.
|
||||
func Insert(sqlo Executable, stmt Statement) (uint64, error) {
|
||||
res, err := stmt.Exec(sqlo)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if id < 1 {
|
||||
return 0, errors.New("inserted ID is less than 1")
|
||||
}
|
||||
|
||||
return uint64(id), nil
|
||||
}
|
||||
|
||||
// InsertReturning executes an insert statement that returns the inserted row.
|
||||
// The statement MUST be a Jet InsertStatement with a RETURNING clause. Returns
|
||||
// the inserted row object T or an error if the insert fails or no rows are returned.
|
||||
func InsertReturning[T any](sqlo Queryable, stmt Statement) (*T, error) {
|
||||
var result T
|
||||
err := stmt.Query(sqlo, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// Update executes an update statement, returning an error if the update fails.
|
||||
func Update(sqlo Executable, stmt Statement) error {
|
||||
_, err := stmt.Exec(sqlo)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateAffected executes an update statement and returns the number of rows
|
||||
// affected and an error if any.
|
||||
func UpdateAffected(sqlo Executable, stmt Statement) (int64, error) {
|
||||
res, err := stmt.Exec(sqlo)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
rowsAffected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return rowsAffected, nil
|
||||
}
|
||||
|
||||
// UpdateReturning executes an update statement that returns the updated row.
|
||||
// The statement MUST be a Jet UpdateStatement with a RETURNING clause. Returns
|
||||
// the updated row object T or an error if the update fails or no rows are returned.
|
||||
func UpdateReturning[T any](sqlo Queryable, stmt Statement) (*T, error) {
|
||||
var result T
|
||||
err := stmt.Query(sqlo, &result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// setupDB connects to the database.
|
||||
func setupDB(m *app.Module) error {
|
||||
if sqlDB != nil && sqlDB.Ping() == nil {
|
||||
m.Logger().Warn("Database connection already established")
|
||||
return nil
|
||||
}
|
||||
|
||||
logArgs := []any{
|
||||
"user", config.User,
|
||||
"name", config.Name,
|
||||
"uri", config.URI,
|
||||
"dialect", dialect,
|
||||
}
|
||||
|
||||
var err error
|
||||
sqlDB, err = sql.Open(string(dialect), config.ConnectionString(dialect))
|
||||
if err != nil {
|
||||
logArgs = append(logArgs, "err", err)
|
||||
m.Logger().Error("Couldn't open SQL database", logArgs...)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
logArgs = append(logArgs, "err", err)
|
||||
m.Logger().Error("Couldn't ping SQL database", logArgs...)
|
||||
return err
|
||||
}
|
||||
|
||||
sqlDB.SetMaxOpenConns(config.MaxConn)
|
||||
|
||||
stats := sqlDB.Stats()
|
||||
m.Logger().Info(
|
||||
"Connected to SQL database",
|
||||
"user", config.User,
|
||||
"name", config.Name,
|
||||
"uri", config.URI,
|
||||
"maxConnections", stats.MaxOpenConnections,
|
||||
"currConnections", stats.OpenConnections,
|
||||
)
|
||||
|
||||
if debugLog {
|
||||
dbxshared.InitLogger(dialect) // initialize the logger for the specified dialect
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// teardownDB closes the database connection.
|
||||
func teardownDB(m *app.Module) error {
|
||||
if sqlDB == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
m.Logger().Error("Couldn't close database", "err", err)
|
||||
return err
|
||||
}
|
||||
m.Logger().Info("Closed database connection")
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user