Files
dbx/dbx.go
2026-01-27 17:23:06 -08:00

290 lines
8.0 KiB
Go

package dbx
import (
"context"
"database/sql"
"errors"
"gitea.auvem.com/go-toolkit/app"
"gitea.auvem.com/go-toolkit/dbx/internal/dbxshared"
"github.com/go-jet/jet/v2/qrm"
)
// Dialect is the SQL dialect used by the database connection. Its string value
// is also passed to sql.Open as the database/sql driver name.
type Dialect string

// String returns the dialect's underlying name, satisfying fmt.Stringer.
func (d Dialect) String() string { return string(d) }

// ModuleDBName is the name of the database module.
const ModuleDBName = "database"

// Supported SQL dialects.
const (
	// DialectPostgres is the PostgreSQL dialect.
	DialectPostgres Dialect = "postgres"
	// DialectMySQL is the MySQL dialect.
	DialectMySQL Dialect = "mysql"
)
type (
	// SQLOFunc is a function that returns a *sql.DB pointer (SQLO has this
	// signature).
	SQLOFunc = func() *sql.DB

	// Queryable interface is an SQL driver object that Jet can run row-returning
	// queries against: it satisfies qrm.Queryable and also exposes the plain
	// database/sql Query method.
	Queryable interface {
		Query(query string, args ...any) (*sql.Rows, error)
		qrm.Queryable
	}

	// Executable interface is an SQL driver object that Jet can run
	// non-row-returning statements against: it satisfies qrm.Executable and
	// also exposes the plain database/sql Exec method.
	Executable interface {
		Exec(query string, args ...any) (sql.Result, error)
		qrm.Executable
	}

	// ExecutableTx interface is an Executable that can additionally begin a
	// database/sql transaction.
	ExecutableTx interface {
		Begin() (*sql.Tx, error)
		Executable
	}

	// Statement is the common subset of Jet statements used by the helpers in
	// this package: query into a destination, or execute for a sql.Result,
	// each with and without a context.
	Statement interface {
		Query(db qrm.Queryable, destination any) error
		QueryContext(ctx context.Context, db qrm.Queryable, destination any) error
		Exec(db qrm.Executable) (sql.Result, error)
		ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error)
	}
)
// dbState groups the package-level database state: the configuration and
// dialect recorded by ModuleDB, and the handle opened by setupDB and closed by
// teardownDB.
type dbState struct {
	// sqlDB is the current SQL database handle returned by SQLO.
	sqlDB *sql.DB
	// config is the database connection configuration passed to ModuleDB.
	config *DBConfig
	// dialect is the SQL dialect (and sql.Open driver name) passed to ModuleDB.
	dialect Dialect
	// debugLog reports whether Jet debug logging should be initialized on setup.
	debugLog bool
}

// ErrNoRows is returned when a query returns no rows. It is the same value as
// qrm.ErrNoRows, so errors.Is matches it under either name.
var ErrNoRows = qrm.ErrNoRows

// state is the single package-wide dbState instance; its zero value means
// "not configured, not connected".
var state dbState
// SQLO returns the current SQL database handle.
//
// It first calls RequireLoaded on the database module (expected to panic when
// the module has not been loaded — see app.Module) and then panics if no handle
// has been opened yet.
func SQLO() *sql.DB {
	// The module must be loaded before the handle is touched.
	dbxshared.DBModule.RequireLoaded("dbx.SQLO requires database module")

	db := state.sqlDB
	if db == nil {
		panic("SQL database not initialized")
	}
	return db
}
// ModuleDB creates and returns the database module.
//
// dialect selects the SQL dialect (e.g. DialectPostgres, DialectMySQL), config
// supplies the connection settings, and forceDebugLog enables debug logging even
// when config.DebugLog is false. The module's Setup and Teardown hooks open and
// close the shared connection.
//
// It panics when called more than once or when config is nil.
func ModuleDB(dialect Dialect, config *DBConfig, forceDebugLog bool) *app.Module {
	switch {
	case dbxshared.DBModule != nil:
		panic("ModuleDB initialized multiple times")
	case config == nil:
		panic("ModuleDB requires a non-nil DBConfig")
	}

	// Record the configuration at package level for setupDB/teardownDB.
	state.dialect = dialect
	state.config = config
	state.debugLog = forceDebugLog || config.DebugLog

	mod := app.NewModule(ModuleDBName, app.ModuleOpts{
		Setup:    setupDB,
		Teardown: teardownDB,
	})
	dbxshared.DBModule = mod
	return mod
}
// Fetch runs stmt against sqlo and scans every row into a slice of *T.
// An ErrNoRows result is not treated as a failure: the (possibly empty) slice
// is returned with a nil error. Any other error yields (nil, err).
func Fetch[T any](sqlo Queryable, stmt Statement) ([]*T, error) {
	var rows []*T
	err := stmt.Query(sqlo, &rows)
	switch {
	case err == nil, errors.Is(err, ErrNoRows):
		return rows, nil
	default:
		return nil, err
	}
}
// MustFetch behaves like Fetch but treats an empty result as an error: when no
// rows are returned it yields (nil, notFoundErr). Query errors are returned
// unchanged as (nil, err).
func MustFetch[T any](sqlo Queryable, stmt Statement, notFoundErr error) ([]*T, error) {
	rows, err := Fetch[T](sqlo, stmt)
	switch {
	case err != nil:
		return nil, err
	case len(rows) == 0:
		return nil, notFoundErr
	}
	return rows, nil
}
// FetchOne runs stmt and returns the first row scanned into *T. If the query
// returns no rows it returns (nil, nil). Note that all matching rows are
// fetched; callers that expect a single row should constrain the statement
// (e.g. with LIMIT 1) themselves.
func FetchOne[T any](sqlo Queryable, stmt Statement) (*T, error) {
	rows, err := Fetch[T](sqlo, stmt)
	if err != nil || len(rows) == 0 {
		// Either a query failure (err set) or an empty result (err nil).
		return nil, err
	}
	return rows[0], nil
}
// MustFetchOne runs stmt and returns the first row scanned into *T. If the
// query returns no rows it returns (nil, notFoundErr); query errors are
// returned as (nil, err).
//
// notFoundErr may be nil, in which case an empty result yields (nil, nil)
// rather than panicking.
func MustFetchOne[T any](sqlo Queryable, stmt Statement, notFoundErr error) (*T, error) {
	result, err := MustFetch[T](sqlo, stmt, notFoundErr)
	if err != nil {
		return nil, err
	}
	if len(result) == 0 {
		// Only reachable when notFoundErr is nil: MustFetch then reports an
		// empty result as (nil, nil), and indexing result[0] would panic.
		return nil, notFoundErr
	}
	return result[0], nil // first matching row
}
// Insert executes an insert statement and returns the ID reported by
// sql.Result.LastInsertId.
//
// It returns an error when execution fails, when LastInsertId fails (not every
// database/sql driver supports it), or when the reported ID is less than 1.
func Insert(sqlo Executable, stmt Statement) (uint64, error) {
	res, err := stmt.Exec(sqlo)
	if err != nil {
		return 0, err
	}

	lastID, err := res.LastInsertId()
	switch {
	case err != nil:
		return 0, err
	case lastID < 1:
		return 0, errors.New("inserted ID is less than 1")
	default:
		return uint64(lastID), nil
	}
}
// InsertReturning executes an insert statement that returns the inserted row.
// The statement MUST be a Jet InsertStatement with a RETURNING clause. The row
// is scanned into a freshly allocated T; any error from the query (including
// Jet's no-rows error) is returned as (nil, err).
func InsertReturning[T any](sqlo Queryable, stmt Statement) (*T, error) {
	row := new(T)
	if err := stmt.Query(sqlo, row); err != nil {
		return nil, err
	}
	return row, nil
}
// Update executes an update statement and reports only whether it failed; the
// sql.Result is discarded (use UpdateAffected to inspect row counts).
func Update(sqlo Executable, stmt Statement) error {
	if _, err := stmt.Exec(sqlo); err != nil {
		return err
	}
	return nil
}
// UpdateAffected executes an update statement and returns the number of rows
// affected as reported by sql.Result.RowsAffected. On any error, from either
// execution or RowsAffected, the count is 0.
func UpdateAffected(sqlo Executable, stmt Statement) (int64, error) {
	res, err := stmt.Exec(sqlo)
	if err == nil {
		var affected int64
		if affected, err = res.RowsAffected(); err == nil {
			return affected, nil
		}
	}
	return 0, err
}
// UpdateReturning executes an update statement that returns the updated row.
// The statement MUST be a Jet UpdateStatement with a RETURNING clause. The row
// is scanned into a T; any error from the query (including Jet's no-rows error)
// is returned as (nil, err).
func UpdateReturning[T any](sqlo Queryable, stmt Statement) (*T, error) {
	var row T
	err := stmt.Query(sqlo, &row)
	if err == nil {
		return &row, nil
	}
	return nil, err
}
// setupDB is the database module's Setup hook (registered by ModuleDB). It
// opens a handle for state.dialect using state.config, verifies it with Ping,
// applies the configured connection limit and, when debug logging is enabled,
// initializes the Jet debug logger.
//
// If a handle already exists and answers Ping, it is reused and a warning is
// logged. A stale handle that fails Ping is closed before a new one is opened.
// On failure the newly opened handle is closed and state.sqlDB is left nil, so
// SQLO never hands out a broken handle.
func setupDB(m *app.Module) error {
	if state.sqlDB != nil {
		if state.sqlDB.Ping() == nil {
			m.Logger().Warn("Database connection already established")
			return nil
		}
		// The existing handle is unusable; release its pool before replacing it.
		// Best effort: the handle is discarded whether or not Close succeeds.
		_ = state.sqlDB.Close()
		state.sqlDB = nil
	}

	logArgs := []any{
		"user", state.config.User,
		"name", state.config.Name,
		"uri", state.config.URI,
		"dialect", state.dialect,
	}

	db, err := sql.Open(string(state.dialect), state.config.ConnectionString(state.dialect))
	if err != nil {
		m.Logger().Error("Couldn't open SQL database", append(logArgs, "err", err)...)
		return err
	}
	if err := db.Ping(); err != nil {
		// Release the pool; the ping error is the one worth reporting.
		_ = db.Close()
		m.Logger().Error("Couldn't ping SQL database", append(logArgs, "err", err)...)
		return err
	}

	db.SetMaxOpenConns(state.config.MaxConn)
	state.sqlDB = db // publish only after the handle is verified

	stats := db.Stats()
	m.Logger().Info(
		"Connected to SQL database",
		"user", state.config.User,
		"name", state.config.Name,
		"uri", state.config.URI,
		"maxConnections", stats.MaxOpenConnections,
		"currConnections", stats.OpenConnections,
	)

	if state.debugLog {
		dbxshared.InitLogger(state.dialect) // initialize the logger for the specified dialect
	}
	return nil
}
// teardownDB is the database module's Teardown hook (registered by ModuleDB).
// It closes the current handle, if any, and clears state.sqlDB so that SQLO
// reports "not initialized" instead of returning a closed handle. A Close error
// is logged and returned.
func teardownDB(m *app.Module) error {
	if state.sqlDB == nil {
		return nil
	}

	db := state.sqlDB
	// Clear the package reference first: sql.DB.Close marks the handle closed
	// even when it reports an error, so it must not be handed out afterwards.
	state.sqlDB = nil

	if err := db.Close(); err != nil {
		m.Logger().Error("Couldn't close database", "err", err)
		return err
	}
	m.Logger().Info("Closed database connection")
	return nil
}