From 7e35ba0f3ab13e0abe3fb33d287056c4526a5c48 Mon Sep 17 00:00:00 2001 From: Elijah Duffy Date: Wed, 4 Jun 2025 18:21:30 -0700 Subject: [PATCH] add postgresql support --- pgx/pxg.go | 252 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 pgx/pxg.go diff --git a/pgx/pxg.go b/pgx/pxg.go new file mode 100644 index 0000000..084382e --- /dev/null +++ b/pgx/pxg.go @@ -0,0 +1,252 @@ +package dbxp + +import ( + "context" + "database/sql" + "errors" + "fmt" + "slices" + "strings" + + "gitea.auvem.com/go-toolkit/app" + "gitea.auvem.com/go-toolkit/dbx" + dbxgeneric "gitea.auvem.com/go-toolkit/dbx/generic" + "github.com/fatih/color" + "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/qrm" +) + +var ( + // ErrNoRows is returned when a query returns no rows. + ErrNoRows = qrm.ErrNoRows + + // sqlDB stores the current SQL database handle. + sqlDB *sql.DB + + // config stores the database connection configuration. + config dbx.DBConfig + + // dbModule is the singleton module instance for the database connection. + dbModule *app.Module +) + +// SQLO returns the current SQL database handle. +func SQLO() *sql.DB { + dbModule.RequireLoaded("dbx.SQLO requires database module") // ensure the module is loaded before accessing the database + if sqlDB == nil { + panic("SQL database not initialized") + } + return sqlDB +} + +// ModuleDB returns the database module with the provided configuration. +func ModuleDB(cfg dbx.DBConfig, forceDebugLog bool) *app.Module { + if dbModule != nil { + panic("ModuleDB initialized multiple times") + } + + config = cfg // store configuration at package level + + if forceDebugLog { + config.DebugLog = true // force debug logging if requested + } + + dbModule = app.NewModule(dbx.ModuleDBName, app.ModuleOpts{ + Setup: setupDB, + Teardown: teardownDB, + }) + + return dbModule +} + +// setupDB connects to the PostgreSQL DB and initializes the goose migration provider. +// If auto migrations are enabled in configuration, latest migrations are applied. +func setupDB(_ *app.Module) error { + if sqlDB != nil && sqlDB.Ping() == nil { + dbModule.Logger().Warn("Database connection already established") + return nil + } + + var err error + db_path := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable", + config.User, config.Password, config.Path, config.Name) + + sqlDB, err = sql.Open("postgres", db_path) + if err != nil { + dbModule.Logger().Error("Couldn't open SQL database", "user", config.User, "name", config.Name, "err", err) + return err + } + + if err := sqlDB.Ping(); err != nil { + dbModule.Logger().Error("Couldn't ping SQL database", "user", config.User, "name", config.Name, "err", err) + return err + } + + sqlDB.SetMaxOpenConns(config.MaxConn) + + stats := sqlDB.Stats() + dbModule.Logger().Info( + "Connected to SQL database", + "user", config.User, + "name", config.Name, + "path", config.Path, + "maxConnections", stats.MaxOpenConnections, + "currConnections", stats.OpenConnections, + ) + + if config.DebugLog { + initDBLogger() + } + + return nil +} + +// teardownDB closes the database connection. +func teardownDB(_ *app.Module) error { + if sqlDB == nil { + return nil + } + + if err := sqlDB.Close(); err != nil { + dbModule.Logger().Error("Couldn't close database", "err", err) + return err + } + dbModule.Logger().Info("Closed database connection") + return nil +} + +// MustQuery executes a query and returns an error if any. Filters errors for +// db.ErrNoRows (qrm.ErrNoRows) and returns nil or the error provided in that case. +func MustQuery(sqlo dbx.Queryable, stmt postgres.Statement, dest any, notFoundErr error) error { + err := stmt.Query(sqlo, dest) + if err != nil { + if errors.Is(err, qrm.ErrNoRows) { + return notFoundErr + } + return err + } + return nil +} + +// MustInsert requires at least one row to be affected by a query and returns the inserted ID +func MustInsert(sqlo dbx.Executable, stmt postgres.InsertStatement) (uint64, error) { + res, err := stmt.Exec(sqlo) + if err != nil { + return 0, err + } + + id, err := res.LastInsertId() + if err != nil { + return 0, err + } + + if id < 0 { + return 0, fmt.Errorf("MustInsert got invalid ID: %d", id) + } + + return uint64(id), nil +} + +// MustUpdate requires at least one row to be affected by a query +func MustUpdate(sqlo dbx.Executable, stmt postgres.Statement, notFoundErr error) error { + res, err := stmt.Exec(sqlo) + if err != nil { + return err + } + + rows, err := res.RowsAffected() + if err != nil { + return err + } + + if rows == 0 { + return notFoundErr + } + + return nil +} + +// MightUpdateMany expects multiple rows to be affected by a query and returns this number of rows +func MightUpdateMany(sqlo dbx.Executable, stmt postgres.Statement) (int, error) { + res, err := stmt.Exec(sqlo) + if err != nil { + return 0, err + } + + rows, err := res.RowsAffected() + return int(rows), err +} + +// SoftDelete sets the deleted_at column to the current time +func SoftDelete(sqlo dbx.Executable, tbl postgres.Table, conds postgres.BoolExpression) (int64, error) { + stmt := tbl.UPDATE().WHERE(conds) + query := stmt.DebugSql() + + lines := strings.Split(query, "\n") + for i, line := range lines { + fmt.Println(i, line) + } + + lines = slices.Insert(lines, 2, "SET deleted_at = NOW()") + query = strings.Join(lines, "\n") + + res, err := sqlo.Exec(query) + if err != nil { + return 0, err + } + + return res.RowsAffected() +} + +// NormalCols processes a list of columns and strips out any that implement any of +// postgres.ColumnTimestamp, postgres.ColumnTime, or postgres.ColumnDate +func NormalCols(cols ...postgres.Column) postgres.ColumnList { + res := make(postgres.ColumnList, 0) + + for _, col := range cols { + _, ok := col.(postgres.ColumnTimestamp) + + if !ok { + _, ok = col.(postgres.ColumnTime) + } + if !ok { + _, ok = col.(postgres.ColumnDate) + } + + if !ok { + res = append(res, col) + } + } + + return res +} + +// ContainsCol checks if a column list contains a specific column. +func ContainsCol(cols postgres.ColumnList, col postgres.Column) bool { + return slices.Contains(cols, col) +} + +// DestName returns the name of the type passed as `destTypeStruct` as a string, +// normalized for compatibility with the Jet QRM. +func DestName(destTypeStruct any, path ...string) string { + return dbxgeneric.DestName(dbModule, destTypeStruct, path...) +} + +// initDBLogger initializes the database statement logger +func initDBLogger() { + postgres.SetQueryLogger(func(ctx context.Context, queryInfo postgres.QueryInfo) { + _, args := queryInfo.Statement.Sql() + dbModule.Logger().Debug( + "Executed SQL query", + "args", args, + "duration", queryInfo.Duration, + "rows", queryInfo.RowsProcessed, + "err", queryInfo.Err, + ) + + lines := strings.Split(queryInfo.Statement.DebugSql(), "\n") + for i, line := range lines { + fmt.Printf("%s\t%s\n", color.CyanString(fmt.Sprintf("%03d", i)), line) + } + }) +}