package dbx import ( "context" "database/sql" "errors" "fmt" "reflect" "slices" "strings" "gitea.auvem.com/go-toolkit/app" "github.com/fatih/color" "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/qrm" _ "github.com/go-sql-driver/mysql" ) // ModuleDBName is the name of the database module. const ModuleDBName = "database" var ( // ErrNoRows is returned when a query returns no rows. ErrNoRows = qrm.ErrNoRows // SQLDB stores the current SQL database handle. SQLDB *sql.DB // config stores the database connection configuration. config DBConfig // dbModule is the singleton module instance for the database connection. dbModule *app.Module ) // ModuleDB returns the database module with the provided configuration. func ModuleDB(cfg DBConfig, forceDebugLog bool) *app.Module { if dbModule != nil { panic("ModuleDB initialized multiple times") } config = cfg // store configuration at package level if forceDebugLog { config.DebugLog = true // force debug logging if requested } dbModule = app.NewModule(ModuleDBName, app.ModuleOpts{ Setup: setupDB, Teardown: teardownDB, }) return dbModule } // setupDB connects to the MySQL DB and initializes the goose migration provider. // If auto migrations are enabled in configuration, latest migrations are applied. func setupDB() error { if SQLDB != nil && SQLDB.Ping() == nil { dbModule.Logger().Warn("Database connection already established") return nil } var err error db_path := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", config.User, config.Password, config.Path, config.Name) SQLDB, err = sql.Open("mysql", db_path) if err != nil { dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err) return err } if err := SQLDB.Ping(); err != nil { dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err) return err } SQLDB.SetMaxOpenConns(config.MaxConn) stats := SQLDB.Stats() dbModule.Logger().Info( "Connected to SQL database", "user", config.User, "name", config.Name, "path", config.Path, "maxConnections", stats.MaxOpenConnections, "currConnections", stats.OpenConnections, ) if config.DebugLog { initDBLogger() } return nil } // teardownDB closes the database connection. func teardownDB() { if SQLDB == nil { return } if err := SQLDB.Close(); err != nil { dbModule.Logger().Error("Couldn't close database", "err", err) } dbModule.Logger().Info("Closed database connection") } // Queryable interface is an SQL driver object that can execute SQL statements // for Jet. type Queryable interface { qrm.Queryable Query(string, ...any) (*sql.Rows, error) } // Executable interface is an SQL driver object that can execute SQL statements // for Jet. type Executable interface { qrm.Executable Exec(string, ...any) (sql.Result, error) } // ExecutableTx interface is an SQL driver object that implements the Executable // interface and can also begin a transaction. type ExecutableTx interface { Executable Begin() (*sql.Tx, error) } // MustQuery executes a query and returns an error if any. Filters errors for // db.ErrNoRows (qrm.ErrNoRows) and returns nil or the error provided in that case. func MustQuery(sqlo Queryable, stmt mysql.Statement, dest any, notFoundErr error) error { err := stmt.Query(sqlo, dest) if err != nil { if errors.Is(err, qrm.ErrNoRows) { return notFoundErr } return err } return nil } // MustInsert requires at least one row to be affected by a query and returns the inserted ID func MustInsert(sqlo Executable, stmt mysql.InsertStatement) (uint64, error) { res, err := stmt.Exec(sqlo) if err != nil { return 0, err } id, err := res.LastInsertId() if err != nil { return 0, err } if id < 0 { return 0, fmt.Errorf("MustInsert got invalid ID: %d", id) } return uint64(id), nil } // MustUpdate requires at least one row to be affected by a query func MustUpdate(sqlo Executable, stmt mysql.Statement, notFoundErr error) error { res, err := stmt.Exec(sqlo) if err != nil { return err } rows, err := res.RowsAffected() if err != nil { return err } if rows == 0 { return notFoundErr } return nil } // MightUpdateMany expects multiple rows to be affected by a query and returns this number of rows func MightUpdateMany(sqlo Executable, stmt mysql.Statement) (int, error) { res, err := stmt.Exec(sqlo) if err != nil { return 0, err } rows, err := res.RowsAffected() return int(rows), err } // SoftDelete sets the deleted_at column to the current time func SoftDelete(sqlo Executable, tbl mysql.Table, conds mysql.BoolExpression) (int64, error) { stmt := tbl.UPDATE().WHERE(conds) query := stmt.DebugSql() lines := strings.Split(query, "\n") for i, line := range lines { fmt.Println(i, line) } lines = slices.Insert(lines, 2, "SET deleted_at = NOW()") query = strings.Join(lines, "\n") res, err := sqlo.Exec(query) if err != nil { return 0, err } return res.RowsAffected() } // NormalCols processes a list of columns and strips out any that implement any of // mysql.ColumnTimestamp, mysql.ColumnTime, or mysql.ColumnDate func NormalCols(cols ...mysql.Column) mysql.ColumnList { res := make(mysql.ColumnList, 0) for _, col := range cols { _, ok := col.(mysql.ColumnTimestamp) if !ok { _, ok = col.(mysql.ColumnTime) } if !ok { _, ok = col.(mysql.ColumnDate) } if !ok { res = append(res, col) } } return res } // ContainsCol checks if a column list contains a specific column. func ContainsCol(cols mysql.ColumnList, col mysql.Column) bool { return slices.Contains(cols, col) } // DestName returns the name of the type passed as `destTypeStruct` as a string, // normalized for compatibility with the Jet QRM. func DestName(destTypeStruct interface{}, path ...string) string { v := reflect.ValueOf(destTypeStruct) for v.Kind() == reflect.Ptr { v = v.Elem() } destIdent := v.Type().String() destIdent = destIdent[strings.LastIndex(destIdent, ".")+1:] for i, p := range path { if v.Kind() != reflect.Struct { dbModule.Logger().Error("DestName: path parent is not a struct", "path", destIdent+"."+strings.Join(path[:i+1], ".")) return "" } v = v.FieldByName(p) if !v.IsValid() { dbModule.Logger().Error("DestName: field does not exist", "path", destIdent+"."+strings.Join(path[:i+1], ".")) return "" } destIdent += "." + p } return destIdent } // StringToFilter processes a string to be used as a filter in an SQL LIKE // statement. It replaces all spaces with % and adds % to the beginning and // end of the string. func StringToFilter(str string) string { // Remove any existing leading or trailing % characters str = strings.Trim(str, "%") // Replace all spaces with % and add % to the beginning and end of the string str = strings.ReplaceAll(str, " ", "%") str = "%" + str + "%" return str } // initDBLogger initializes the database statement logger func initDBLogger() { mysql.SetQueryLogger(func(ctx context.Context, queryInfo mysql.QueryInfo) { _, args := queryInfo.Statement.Sql() dbModule.Logger().Debug( "Executed SQL query", "args", args, "duration", queryInfo.Duration, "rows", queryInfo.RowsProcessed, "err", queryInfo.Err, ) lines := strings.Split(queryInfo.Statement.DebugSql(), "\n") for i, line := range lines { fmt.Printf("%s\t%s\n", color.CyanString(fmt.Sprintf("%03d", i)), line) } }) }