package dbxm import ( "context" "database/sql" "errors" "fmt" "slices" "strings" "gitea.auvem.com/go-toolkit/app" "gitea.auvem.com/go-toolkit/dbx" dbxgeneric "gitea.auvem.com/go-toolkit/dbx/generic" "github.com/fatih/color" "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/qrm" ) var ( // ErrNoRows is returned when a query returns no rows. ErrNoRows = qrm.ErrNoRows // sqlDB stores the current SQL database handle. sqlDB *sql.DB // config stores the database connection configuration. config dbx.DBConfig // dbModule is the singleton module instance for the database connection. dbModule *app.Module ) // SQLO returns the current SQL database handle. func SQLO() *sql.DB { dbModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database if sqlDB == nil { panic("SQL database not initialized") } return sqlDB } // ModuleDB returns the database module with the provided configuration. func ModuleDB(cfg dbx.DBConfig, forceDebugLog bool) *app.Module { if dbModule != nil { panic("ModuleDB initialized multiple times") } config = cfg // store configuration at package level if forceDebugLog { config.DebugLog = true // force debug logging if requested } dbModule = app.NewModule(dbx.ModuleDBName, app.ModuleOpts{ Setup: setupDB, Teardown: teardownDB, }) return dbModule } // setupDB connects to the MySQL DB and initializes the goose migration provider. // If auto migrations are enabled in configuration, latest migrations are applied. func setupDB(_ *app.Module) error { if sqlDB != nil && sqlDB.Ping() == nil { dbModule.Logger().Warn("Database connection already established") return nil } var err error db_path := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", config.User, config.Password, config.Path, config.Name) sqlDB, err = sql.Open("mysql", db_path) if err != nil { dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err) return err } if err := sqlDB.Ping(); err != nil { dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err) return err } sqlDB.SetMaxOpenConns(config.MaxConn) stats := sqlDB.Stats() dbModule.Logger().Info( "Connected to SQL database", "user", config.User, "name", config.Name, "path", config.Path, "maxConnections", stats.MaxOpenConnections, "currConnections", stats.OpenConnections, ) if config.DebugLog { initDBLogger() } return nil } // teardownDB closes the database connection. func teardownDB(_ *app.Module) error { if sqlDB == nil { return nil } if err := sqlDB.Close(); err != nil { dbModule.Logger().Error("Couldn't close database", "err", err) return err } dbModule.Logger().Info("Closed database connection") return nil } // MustQuery executes a query and returns an error if any. Filters errors for // db.ErrNoRows (qrm.ErrNoRows) and returns nil or the error provided in that case. func MustQuery(sqlo dbx.Queryable, stmt mysql.Statement, dest any, notFoundErr error) error { err := stmt.Query(sqlo, dest) if err != nil { if errors.Is(err, qrm.ErrNoRows) { return notFoundErr } return err } return nil } // MustInsert requires at least one row to be affected by a query and returns the inserted ID func MustInsert(sqlo dbx.Executable, stmt mysql.InsertStatement) (uint64, error) { res, err := stmt.Exec(sqlo) if err != nil { return 0, err } id, err := res.LastInsertId() if err != nil { return 0, err } if id < 0 { return 0, fmt.Errorf("MustInsert got invalid ID: %d", id) } return uint64(id), nil } // MustUpdate requires at least one row to be affected by a query func MustUpdate(sqlo dbx.Executable, stmt mysql.Statement, notFoundErr error) error { res, err := stmt.Exec(sqlo) if err != nil { return err } rows, err := res.RowsAffected() if err != nil { return err } if rows == 0 { return notFoundErr } return nil } // MightUpdateMany expects multiple rows to be affected by a query and returns this number of rows func MightUpdateMany(sqlo dbx.Executable, stmt mysql.Statement) (int, error) { res, err := stmt.Exec(sqlo) if err != nil { return 0, err } rows, err := res.RowsAffected() return int(rows), err } // SoftDelete sets the deleted_at column to the current time func SoftDelete(sqlo dbx.Executable, tbl mysql.Table, conds mysql.BoolExpression) (int64, error) { stmt := tbl.UPDATE().WHERE(conds) query := stmt.DebugSql() lines := strings.Split(query, "\n") for i, line := range lines { fmt.Println(i, line) } lines = slices.Insert(lines, 2, "SET deleted_at = NOW()") query = strings.Join(lines, "\n") res, err := sqlo.Exec(query) if err != nil { return 0, err } return res.RowsAffected() } // NormalCols processes a list of columns and strips out any that implement any of // mysql.ColumnTimestamp, mysql.ColumnTime, or mysql.ColumnDate func NormalCols(cols ...mysql.Column) mysql.ColumnList { res := make(mysql.ColumnList, 0) for _, col := range cols { _, ok := col.(mysql.ColumnTimestamp) if !ok { _, ok = col.(mysql.ColumnTime) } if !ok { _, ok = col.(mysql.ColumnDate) } if !ok { res = append(res, col) } } return res } // ContainsCol checks if a column list contains a specific column. func ContainsCol(cols mysql.ColumnList, col mysql.Column) bool { return slices.Contains(cols, col) } // DestName returns the name of the type passed as `destTypeStruct` as a string, // normalized for compatibility with the Jet QRM. func DestName(destTypeStruct any, path ...string) string { return dbxgeneric.DestName(dbModule, destTypeStruct, path...) } // initDBLogger initializes the database statement logger func initDBLogger() { mysql.SetQueryLogger(func(ctx context.Context, queryInfo mysql.QueryInfo) { _, args := queryInfo.Statement.Sql() dbModule.Logger().Debug( "Executed SQL query", "args", args, "duration", queryInfo.Duration, "rows", queryInfo.RowsProcessed, "err", queryInfo.Err, ) lines := strings.Split(queryInfo.Statement.DebugSql(), "\n") for i, line := range lines { fmt.Printf("%s\t%s\n", color.CyanString(fmt.Sprintf("%03d", i)), line) } }) }