315 lines
7.7 KiB
Go
315 lines
7.7 KiB
Go
package dbx
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"slices"
|
|
"strings"
|
|
|
|
"gitea.auvem.com/go-toolkit/app"
|
|
"github.com/fatih/color"
|
|
"github.com/go-jet/jet/v2/mysql"
|
|
"github.com/go-jet/jet/v2/qrm"
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
const (
	// ModuleDBName is the name under which ModuleDB registers the database
	// module with app.NewModule.
	ModuleDBName = "database"
)
|
|
|
|
var (
|
|
// ErrNoRows is returned when a query returns no rows.
|
|
ErrNoRows = qrm.ErrNoRows
|
|
|
|
// sqlDB stores the current SQL database handle.
|
|
sqlDB *sql.DB
|
|
|
|
// config stores the database connection configuration.
|
|
config DBConfig
|
|
|
|
// dbModule is the singleton module instance for the database connection.
|
|
dbModule *app.Module
|
|
)
|
|
|
|
// SQLO returns the current SQL database handle.
|
|
func SQLO() *sql.DB {
|
|
dbModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database
|
|
if sqlDB == nil {
|
|
panic("SQL database not initialized")
|
|
}
|
|
return sqlDB
|
|
}
|
|
|
|
// ModuleDB returns the database module with the provided configuration.
|
|
func ModuleDB(cfg DBConfig, forceDebugLog bool) *app.Module {
|
|
if dbModule != nil {
|
|
panic("ModuleDB initialized multiple times")
|
|
}
|
|
|
|
config = cfg // store configuration at package level
|
|
|
|
if forceDebugLog {
|
|
config.DebugLog = true // force debug logging if requested
|
|
}
|
|
|
|
dbModule = app.NewModule(ModuleDBName, app.ModuleOpts{
|
|
Setup: setupDB,
|
|
Teardown: teardownDB,
|
|
})
|
|
|
|
return dbModule
|
|
}
|
|
|
|
// setupDB connects to the MySQL DB and initializes the goose migration provider.
|
|
// If auto migrations are enabled in configuration, latest migrations are applied.
|
|
func setupDB(_ *app.Module) error {
|
|
if sqlDB != nil && sqlDB.Ping() == nil {
|
|
dbModule.Logger().Warn("Database connection already established")
|
|
return nil
|
|
}
|
|
|
|
var err error
|
|
db_path := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
|
config.User, config.Password, config.Path, config.Name)
|
|
|
|
sqlDB, err = sql.Open("mysql", db_path)
|
|
if err != nil {
|
|
dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err)
|
|
return err
|
|
}
|
|
|
|
if err := sqlDB.Ping(); err != nil {
|
|
dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err)
|
|
return err
|
|
}
|
|
|
|
sqlDB.SetMaxOpenConns(config.MaxConn)
|
|
|
|
stats := sqlDB.Stats()
|
|
dbModule.Logger().Info(
|
|
"Connected to SQL database",
|
|
"user", config.User,
|
|
"name", config.Name,
|
|
"path", config.Path,
|
|
"maxConnections", stats.MaxOpenConnections,
|
|
"currConnections", stats.OpenConnections,
|
|
)
|
|
|
|
if config.DebugLog {
|
|
initDBLogger()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// teardownDB closes the database connection.
|
|
func teardownDB(_ *app.Module) error {
|
|
if sqlDB == nil {
|
|
return nil
|
|
}
|
|
|
|
if err := sqlDB.Close(); err != nil {
|
|
dbModule.Logger().Error("Couldn't close database", "err", err)
|
|
return err
|
|
}
|
|
dbModule.Logger().Info("Closed database connection")
|
|
return nil
|
|
}
|
|
|
|
// Queryable interface is an SQL driver object that can execute SQL statements
|
|
// for Jet.
|
|
type Queryable interface {
|
|
qrm.Queryable
|
|
Query(string, ...any) (*sql.Rows, error)
|
|
}
|
|
|
|
// Executable interface is an SQL driver object that can execute SQL statements
|
|
// for Jet.
|
|
type Executable interface {
|
|
qrm.Executable
|
|
Exec(string, ...any) (sql.Result, error)
|
|
}
|
|
|
|
// ExecutableTx interface is an SQL driver object that implements the Executable
|
|
// interface and can also begin a transaction.
|
|
type ExecutableTx interface {
|
|
Executable
|
|
Begin() (*sql.Tx, error)
|
|
}
|
|
|
|
// MustQuery executes a query and returns an error if any. Filters errors for
|
|
// db.ErrNoRows (qrm.ErrNoRows) and returns nil or the error provided in that case.
|
|
func MustQuery(sqlo Queryable, stmt mysql.Statement, dest any, notFoundErr error) error {
|
|
err := stmt.Query(sqlo, dest)
|
|
if err != nil {
|
|
if errors.Is(err, qrm.ErrNoRows) {
|
|
return notFoundErr
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// MustInsert requires at least one row to be affected by a query and returns the inserted ID
|
|
func MustInsert(sqlo Executable, stmt mysql.InsertStatement) (uint64, error) {
|
|
res, err := stmt.Exec(sqlo)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
id, err := res.LastInsertId()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if id < 0 {
|
|
return 0, fmt.Errorf("MustInsert got invalid ID: %d", id)
|
|
}
|
|
|
|
return uint64(id), nil
|
|
}
|
|
|
|
// MustUpdate requires at least one row to be affected by a query
|
|
func MustUpdate(sqlo Executable, stmt mysql.Statement, notFoundErr error) error {
|
|
res, err := stmt.Exec(sqlo)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rows, err := res.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if rows == 0 {
|
|
return notFoundErr
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// MightUpdateMany expects multiple rows to be affected by a query and returns this number of rows
|
|
func MightUpdateMany(sqlo Executable, stmt mysql.Statement) (int, error) {
|
|
res, err := stmt.Exec(sqlo)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
rows, err := res.RowsAffected()
|
|
return int(rows), err
|
|
}
|
|
|
|
// SoftDelete sets the deleted_at column to the current time
|
|
func SoftDelete(sqlo Executable, tbl mysql.Table, conds mysql.BoolExpression) (int64, error) {
|
|
stmt := tbl.UPDATE().WHERE(conds)
|
|
query := stmt.DebugSql()
|
|
|
|
lines := strings.Split(query, "\n")
|
|
for i, line := range lines {
|
|
fmt.Println(i, line)
|
|
}
|
|
|
|
lines = slices.Insert(lines, 2, "SET deleted_at = NOW()")
|
|
query = strings.Join(lines, "\n")
|
|
|
|
res, err := sqlo.Exec(query)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return res.RowsAffected()
|
|
}
|
|
|
|
// NormalCols processes a list of columns and strips out any that implement any of
|
|
// mysql.ColumnTimestamp, mysql.ColumnTime, or mysql.ColumnDate
|
|
func NormalCols(cols ...mysql.Column) mysql.ColumnList {
|
|
res := make(mysql.ColumnList, 0)
|
|
|
|
for _, col := range cols {
|
|
_, ok := col.(mysql.ColumnTimestamp)
|
|
|
|
if !ok {
|
|
_, ok = col.(mysql.ColumnTime)
|
|
}
|
|
if !ok {
|
|
_, ok = col.(mysql.ColumnDate)
|
|
}
|
|
|
|
if !ok {
|
|
res = append(res, col)
|
|
}
|
|
}
|
|
|
|
return res
|
|
}
|
|
|
|
// ContainsCol checks if a column list contains a specific column.
|
|
func ContainsCol(cols mysql.ColumnList, col mysql.Column) bool {
|
|
return slices.Contains(cols, col)
|
|
}
|
|
|
|
// DestName returns the name of the type passed as `destTypeStruct` as a string,
|
|
// normalized for compatibility with the Jet QRM.
|
|
func DestName(destTypeStruct interface{}, path ...string) string {
|
|
v := reflect.ValueOf(destTypeStruct)
|
|
for v.Kind() == reflect.Ptr {
|
|
v = v.Elem()
|
|
}
|
|
|
|
destIdent := v.Type().String()
|
|
destIdent = destIdent[strings.LastIndex(destIdent, ".")+1:]
|
|
|
|
for i, p := range path {
|
|
if v.Kind() != reflect.Struct {
|
|
dbModule.Logger().Error("DestName: path parent is not a struct", "path", destIdent+"."+strings.Join(path[:i+1], "."))
|
|
return ""
|
|
}
|
|
|
|
v = v.FieldByName(p)
|
|
|
|
if !v.IsValid() {
|
|
dbModule.Logger().Error("DestName: field does not exist", "path", destIdent+"."+strings.Join(path[:i+1], "."))
|
|
return ""
|
|
}
|
|
|
|
destIdent += "." + p
|
|
}
|
|
|
|
return destIdent
|
|
}
|
|
|
|
// StringToFilter turns free-form search text into an SQL LIKE pattern. Any
// '%' characters at either end of str are stripped, every space becomes a '%'
// wildcard, and the result is wrapped in '%' on both sides, so "foo bar"
// becomes "%foo%bar%". Other LIKE metacharacters (such as '_' or an interior
// '%') are passed through untouched.
func StringToFilter(str string) string {
	core := strings.ReplaceAll(strings.Trim(str, "%"), " ", "%")

	var b strings.Builder
	b.Grow(len(core) + 2)
	b.WriteByte('%')
	b.WriteString(core)
	b.WriteByte('%')
	return b.String()
}
|
|
|
|
// initDBLogger initializes the database statement logger
|
|
func initDBLogger() {
|
|
mysql.SetQueryLogger(func(ctx context.Context, queryInfo mysql.QueryInfo) {
|
|
_, args := queryInfo.Statement.Sql()
|
|
dbModule.Logger().Debug(
|
|
"Executed SQL query",
|
|
"args", args,
|
|
"duration", queryInfo.Duration,
|
|
"rows", queryInfo.RowsProcessed,
|
|
"err", queryInfo.Err,
|
|
)
|
|
|
|
lines := strings.Split(queryInfo.Statement.DebugSql(), "\n")
|
|
for i, line := range lines {
|
|
fmt.Printf("%s\t%s\n", color.CyanString(fmt.Sprintf("%03d", i)), line)
|
|
}
|
|
})
|
|
}
|