Files
dbx/dbxp/pxg.go
2025-06-05 14:21:59 -07:00

253 lines
6.2 KiB
Go

package dbxp
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"net/url"
	"slices"
	"strings"

	"gitea.auvem.com/go-toolkit/app"
	"gitea.auvem.com/go-toolkit/dbx"
	dbxgeneric "gitea.auvem.com/go-toolkit/dbx/generic"
	"github.com/fatih/color"
	"github.com/go-jet/jet/v2/postgres"
	"github.com/go-jet/jet/v2/qrm"
)
var (
	// ErrNoRows is the error reported when a query matches no rows. It is the
	// Jet QRM's qrm.ErrNoRows, re-exported so callers need not import qrm.
	ErrNoRows = qrm.ErrNoRows
)

// Package-level connection state, populated by ModuleDB and setupDB.
var (
	dbModule *app.Module  // singleton module created by ModuleDB
	config   dbx.DBConfig // connection settings captured by ModuleDB
	sqlDB    *sql.DB      // connection pool opened by setupDB
)
// SQLO returns the package's open SQL database handle.
//
// It panics if the database module has not been loaded, or if no handle has
// been opened yet.
func SQLO() *sql.DB {
	// Fail fast when the app has not loaded the database module.
	dbModule.RequireLoaded("dbxp.SQLO requires database module")

	if db := sqlDB; db != nil {
		return db
	}
	panic("SQL database not initialized")
}
// ModuleDB creates the singleton database module, wiring setupDB and
// teardownDB as its lifecycle hooks, and stores cfg for use by setupDB.
//
// When forceDebugLog is true, query debug logging is enabled regardless of
// cfg.DebugLog. It panics if called more than once.
func ModuleDB(cfg dbx.DBConfig, forceDebugLog bool) *app.Module {
	if dbModule != nil {
		panic("ModuleDB initialized multiple times")
	}

	// cfg is a local copy, so it can be adjusted before being stored.
	if forceDebugLog {
		cfg.DebugLog = true
	}
	config = cfg

	opts := app.ModuleOpts{
		Setup:    setupDB,
		Teardown: teardownDB,
	}
	dbModule = app.NewModule(dbx.ModuleDBName, opts)
	return dbModule
}
// setupDB opens a connection pool to the PostgreSQL database described by the
// package-level config and verifies it with a ping. It then applies
// config.MaxConn as the open-connection limit and, when config.DebugLog is set,
// installs the Jet query logger.
//
// If a handle already exists and answers a ping, it is kept and a warning is
// logged. A handle that no longer answers is closed before reconnecting. If
// opening or pinging the new pool fails, that pool is closed and the
// package-level handle is left nil. SQLO then reports the database as
// uninitialized instead of returning a broken pool.
//
// NOTE(review): sql.Open("postgres", ...) relies on a driver registered under
// the name "postgres" elsewhere (e.g. a blank import) — confirm.
func setupDB(_ *app.Module) error {
	if sqlDB != nil {
		if sqlDB.Ping() == nil {
			dbModule.Logger().Warn("Database connection already established")
			return nil
		}
		// The existing pool is unusable; release it instead of leaking it.
		// Best-effort: a close error here must not prevent reconnecting.
		_ = sqlDB.Close()
		sqlDB = nil
	}

	// Build the DSN with net/url so that user names, passwords and database
	// names containing reserved characters ('@', ':', '/', '#', ...) are
	// percent-encoded instead of corrupting the connection string.
	dsn := (&url.URL{
		Scheme:   "postgres",
		User:     url.UserPassword(config.User, config.Password),
		Host:     config.Path,
		Path:     "/" + config.Name,
		RawQuery: "sslmode=disable",
	}).String()

	db, err := sql.Open("postgres", dsn)
	if err != nil {
		dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err)
		return err
	}
	if err = db.Ping(); err != nil {
		dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err)
		_ = db.Close() // best-effort cleanup; the ping error is the one reported
		return err
	}
	db.SetMaxOpenConns(config.MaxConn)

	// Publish the handle only once it is known to be usable.
	sqlDB = db

	stats := sqlDB.Stats()
	dbModule.Logger().Info(
		"Connected to SQL database",
		"user", config.User,
		"name", config.Name,
		"path", config.Path,
		"maxConnections", stats.MaxOpenConnections,
		"currConnections", stats.OpenConnections,
	)

	if config.DebugLog {
		initDBLogger()
	}
	return nil
}
// teardownDB closes the database connection pool, if one is open, and clears
// the package-level handle. This stops SQLO, and any later setupDB, from
// reusing a closed pool.
//
// A close error is logged and returned.
func teardownDB(_ *app.Module) error {
	if sqlDB == nil {
		return nil
	}

	db := sqlDB
	// Clear the handle before closing: sql.DB.Close marks the pool closed even
	// when it reports an error, so the handle is unusable either way.
	sqlDB = nil

	if err := db.Close(); err != nil {
		dbModule.Logger().Error("Couldn't close database", "err", err)
		return err
	}
	dbModule.Logger().Info("Closed database connection")
	return nil
}
// MustQuery runs stmt against sqlo and scans the result into dest.
//
// A qrm.ErrNoRows result is translated into notFoundErr, which may itself be
// nil to treat "no rows" as success. Any other error is returned unchanged,
// and a successful query returns nil.
func MustQuery(sqlo dbx.Queryable, stmt postgres.Statement, dest any, notFoundErr error) error {
	switch err := stmt.Query(sqlo, dest); {
	case err == nil:
		return nil
	case errors.Is(err, qrm.ErrNoRows):
		return notFoundErr
	default:
		return err
	}
}
// MustInsert executes an INSERT statement, requires that it affected at least
// one row, and returns the inserted row's ID as reported by
// sql.Result.LastInsertId.
//
// It returns an error in any of these cases:
//   - execution fails
//   - no rows were inserted
//   - the affected-row count or last-insert ID cannot be read
//   - the reported ID is negative
//
// NOTE(review): PostgreSQL drivers such as lib/pq do not implement
// LastInsertId; they expect INSERT ... RETURNING instead. With such a driver,
// this function returns an error after the insert has already executed.
// Confirm which driver backs this package.
func MustInsert(sqlo dbx.Executable, stmt postgres.InsertStatement) (uint64, error) {
	res, err := stmt.Exec(sqlo)
	if err != nil {
		return 0, err
	}

	// Enforce the "at least one row" requirement, mirroring MustUpdate. For
	// example, an INSERT ... ON CONFLICT DO NOTHING can affect zero rows.
	rows, err := res.RowsAffected()
	if err != nil {
		return 0, err
	}
	if rows == 0 {
		return 0, errors.New("MustInsert: no rows inserted")
	}

	id, err := res.LastInsertId()
	if err != nil {
		return 0, err
	}
	if id < 0 {
		return 0, fmt.Errorf("MustInsert got invalid ID: %d", id)
	}
	return uint64(id), nil
}
// MustUpdate executes stmt and requires that it affected at least one row.
//
// Execution errors and affected-row-count errors are returned as-is. If the
// statement succeeded but touched no rows, notFoundErr is returned. Otherwise
// the result is nil.
func MustUpdate(sqlo dbx.Executable, stmt postgres.Statement, notFoundErr error) error {
	res, err := stmt.Exec(sqlo)
	if err != nil {
		return err
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case affected == 0:
		return notFoundErr
	}
	return nil
}
// MightUpdateMany executes stmt and reports how many rows it affected.
//
// It does not require any particular count; zero affected rows is not an
// error. An execution error yields (0, err). Otherwise the affected-row
// count is returned together with any error from reading it.
func MightUpdateMany(sqlo dbx.Executable, stmt postgres.Statement) (int, error) {
	res, execErr := stmt.Exec(sqlo)
	if execErr != nil {
		return 0, execErr
	}

	affected, countErr := res.RowsAffected()
	return int(affected), countErr
}
// SoftDelete marks the rows of tbl that match conds as deleted by setting their
// deleted_at column to NOW(). It returns the number of rows affected.
//
// conds is required: a nil condition is rejected so the whole table can never
// be soft-deleted by accident. The statement fails at execution time if tbl
// has no deleted_at column.
func SoftDelete(sqlo dbx.Executable, tbl postgres.Table, conds postgres.BoolExpression) (int64, error) {
	if conds == nil {
		return 0, errors.New("dbxp.SoftDelete: nil WHERE condition")
	}

	// Render "UPDATE <tbl> WHERE <conds>", then splice the SET clause in
	// directly after the UPDATE line.
	//
	// Sql() is used rather than DebugSql(). DebugSql inlines argument values
	// into the query text and is meant for debugging only. Sql() keeps them as
	// bind parameters, which are passed to Exec below.
	query, args := tbl.UPDATE().WHERE(conds).Sql()
	lines := strings.Split(query, "\n")
	updateLine := slices.IndexFunc(lines, func(line string) bool {
		return strings.HasPrefix(strings.TrimSpace(line), "UPDATE")
	})
	if updateLine < 0 {
		return 0, fmt.Errorf("dbxp.SoftDelete: no UPDATE line in generated SQL %q", query)
	}
	lines = slices.Insert(lines, updateLine+1, "SET deleted_at = NOW()")

	res, err := sqlo.Exec(strings.Join(lines, "\n"), args...)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// NormalCols returns a new column list containing cols in their original
// order, minus every column that satisfies postgres.ColumnTimestamp,
// postgres.ColumnTime, or postgres.ColumnDate.
//
// NOTE(review): timestamptz/timetz column types are not listed above. Confirm
// whether they satisfy one of the checked interfaces, and whether keeping them
// is intended.
func NormalCols(cols ...postgres.Column) postgres.ColumnList {
	kept := make(postgres.ColumnList, 0)
	for _, c := range cols {
		switch c.(type) {
		case postgres.ColumnTimestamp, postgres.ColumnTime, postgres.ColumnDate:
			// Temporal column: leave it out.
		default:
			kept = append(kept, c)
		}
	}
	return kept
}
// ContainsCol reports whether col appears in cols. Columns are compared with ==.
func ContainsCol(cols postgres.ColumnList, col postgres.Column) bool {
	return slices.Index(cols, col) >= 0
}
// DestName returns the type name of destTypeStruct, normalized for use as a
// Jet QRM destination name.
//
// It delegates to dbxgeneric.DestName, passing this package's database module
// and forwarding path unchanged.
func DestName(destTypeStruct any, path ...string) string {
	name := dbxgeneric.DestName(dbModule, destTypeStruct, path...)
	return name
}
// initDBLogger installs a Jet query logger that runs for every executed
// statement. It does two things:
//   - emits a Debug record through the database module's logger, with the
//     bind arguments, duration, rows processed and error;
//   - prints the statement's debug SQL to stdout, each line prefixed with a
//     cyan, zero-padded line number.
//
// The listing is assembled in full and written with a single call. This keeps
// the lines of statements logged concurrently from different goroutines from
// interleaving.
func initDBLogger() {
	postgres.SetQueryLogger(func(_ context.Context, info postgres.QueryInfo) {
		_, args := info.Statement.Sql()
		dbModule.Logger().Debug(
			"Executed SQL query",
			"args", args,
			"duration", info.Duration,
			"rows", info.RowsProcessed,
			"err", info.Err,
		)

		var listing strings.Builder
		for i, line := range strings.Split(info.Statement.DebugSql(), "\n") {
			fmt.Fprintf(&listing, "%s\t%s\n", color.CyanString(fmt.Sprintf("%03d", i)), line)
		}
		fmt.Print(listing.String())
	})
}