package dbx import ( "context" "database/sql" "errors" "gitea.auvem.com/go-toolkit/app" "gitea.auvem.com/go-toolkit/dbx/internal/dbxshared" "github.com/go-jet/jet/v2/qrm" ) const ( // ModuleDBName is the name of the database module. ModuleDBName = "database" // DialectPostgres is the PostgreSQL dialect. DialectPostgres Dialect = "postgres" // DialectMySQL is the MySQL dialect. DialectMySQL Dialect = "mysql" ) // Dialect is the SQL dialect used by the database connection. type Dialect string // String implements the Stringer interface for Dialect. func (d Dialect) String() string { return string(d) } // SQLOFunc is a function that returns a *sql.DB pointer. type SQLOFunc = func() *sql.DB // Queryable interface is an SQL driver object that can execute SQL statements // for Jet. type Queryable interface { qrm.Queryable Query(string, ...any) (*sql.Rows, error) } // Executable interface is an SQL driver object that can execute SQL statements // for Jet. type Executable interface { qrm.Executable Exec(string, ...any) (sql.Result, error) } // ExecutableTx interface is an SQL driver object that implements the Executable // interface and can also begin a transaction. type ExecutableTx interface { Executable Begin() (*sql.Tx, error) } // Statement is a common Jet statement for all SQL operations. type Statement interface { Query(db qrm.Queryable, destination any) error QueryContext(ctx context.Context, db qrm.Queryable, destination any) error Exec(db qrm.Executable) (sql.Result, error) ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error) } // SelectStatement is a Jet statement that can be executed to fetch rows from the database. type SelectStatement interface { Statement LIMIT(limit int64) SelectStatement } var ( // ErrNoRows is returned when a query returns no rows. ErrNoRows = qrm.ErrNoRows // sqlDB stores the current SQL database handle. sqlDB *sql.DB // config stores the database connection configuration. config *DBConfig // dialect is the SQL dialect used by the database connection. dialect Dialect // debugLog indicates whether debug logging is enabled. debugLog bool ) // SQLO returns the current SQL database handle. func SQLO() *sql.DB { dbxshared.DBModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database if sqlDB == nil { panic("SQL database not initialized") } return sqlDB } // ModuleDB returns the existing database module, or panics if it has not // been initialized yet. func ModuleDB() *app.Module { if dbxshared.DBModule == nil { panic("ModuleDB not initialized yet") } return dbxshared.DBModule } // InitModuleDB returns the database module with the provided configuration. func InitModuleDB(dialect Dialect, cfg *DBConfig, forceDebugLog bool) *app.Module { if dbxshared.DBModule != nil { panic("ModuleDB initialized multiple times") } if cfg == nil { panic("ModuleDB requires a non-nil DBConfig") } config = cfg // store configuration at package level debugLog = cfg.DebugLog || forceDebugLog // force debug logging if requested dbxshared.DBModule = app.NewModule(ModuleDBName, app.ModuleOpts{ Setup: setupDB, Teardown: teardownDB, }) return dbxshared.DBModule } // Fetch queries the database and returns the result as a slice. If the query // returns no rows, it returns an empty slice and no error. func Fetch[T any](sqlo Queryable, stmt SelectStatement) ([]*T, error) { var result []*T if err := stmt.Query(sqlo, &result); err != nil && !errors.Is(err, ErrNoRows) { return nil, err } return result, nil } // MustFetch queries the database and returns the result as a slice. If the query // returns no rows, it returns an empty slice and the desired error. func MustFetch[T any](sqlo Queryable, stmt SelectStatement, notFoundErr error) ([]*T, error) { result, err := Fetch[T](sqlo, stmt) if err != nil { return nil, err } if len(result) == 0 { return nil, notFoundErr // return the desired error if no rows found } return result, nil // return the fetched results } // FetchOne queries the database and returns a single result. If the query // returns no rows, it returns nil and no error. func FetchOne[T any](sqlo Queryable, stmt SelectStatement) (*T, error) { result, err := Fetch[T](sqlo, stmt.LIMIT(1)) if err != nil { return nil, err } if len(result) == 0 { return nil, nil // no rows found, return nil } return result[0], nil // return the first (and only) result } // MustFetchOne queries the database and returns a single result. If the query // returns no rows, it returns nil and the desired error. func MustFetchOne[T any](sqlo Queryable, stmt SelectStatement, notFoundErr error) (*T, error) { result, err := MustFetch[T](sqlo, stmt.LIMIT(1), notFoundErr) if err != nil { return nil, err } return result[0], nil // return the first (and only) result } // Insert executes an insert statement, returning the last inserted ID or an // error if the insert fails. func Insert(sqlo Executable, stmt Statement) (uint64, error) { res, err := stmt.Exec(sqlo) if err != nil { return 0, err } id, err := res.LastInsertId() if err != nil { return 0, err } if id < 1 { return 0, errors.New("inserted ID is less than 1") } return uint64(id), nil } // InsertReturning executes an insert statement that returns the inserted row. // The statement MUST be a Jet InsertStatement with a RETURNING clause. Returns // the inserted row object T or an error if the insert fails or no rows are returned. func InsertReturning[T any](sqlo Queryable, stmt Statement) (*T, error) { var result T err := stmt.Query(sqlo, &result) if err != nil { return nil, err } return &result, nil } // Update executes an update statement, returning an error if the update fails. func Update(sqlo Executable, stmt Statement) error { _, err := stmt.Exec(sqlo) return err } // UpdateAffected executes an update statement and returns the number of rows // affected and an error if any. func UpdateAffected(sqlo Executable, stmt Statement) (int64, error) { res, err := stmt.Exec(sqlo) if err != nil { return 0, err } rowsAffected, err := res.RowsAffected() if err != nil { return 0, err } return rowsAffected, nil } // UpdateReturning executes an update statement that returns the updated row. // The statement MUST be a Jet UpdateStatement with a RETURNING clause. Returns // the updated row object T or an error if the update fails or no rows are returned. func UpdateReturning[T any](sqlo Queryable, stmt Statement) (*T, error) { var result T err := stmt.Query(sqlo, &result) if err != nil { return nil, err } return &result, nil } // setupDB connects to the database. func setupDB(m *app.Module) error { if sqlDB != nil && sqlDB.Ping() == nil { m.Logger().Warn("Database connection already established") return nil } logArgs := []any{ "user", config.User, "name", config.Name, "uri", config.URI, "dialect", dialect, } var err error sqlDB, err = sql.Open(string(dialect), config.ConnectionString(dialect)) if err != nil { logArgs = append(logArgs, "err", err) m.Logger().Error("Couldn't open SQL database", logArgs...) return err } if err := sqlDB.Ping(); err != nil { logArgs = append(logArgs, "err", err) m.Logger().Error("Couldn't ping SQL database", logArgs...) return err } sqlDB.SetMaxOpenConns(config.MaxConn) stats := sqlDB.Stats() m.Logger().Info( "Connected to SQL database", "user", config.User, "name", config.Name, "uri", config.URI, "maxConnections", stats.MaxOpenConnections, "currConnections", stats.OpenConnections, ) if debugLog { dbxshared.InitLogger(dialect) // initialize the logger for the specified dialect } return nil } // teardownDB closes the database connection. func teardownDB(m *app.Module) error { if sqlDB == nil { return nil } if err := sqlDB.Close(); err != nil { m.Logger().Error("Couldn't close database", "err", err) return err } m.Logger().Info("Closed database connection") return nil }