diff --git a/dbx.go b/dbx.go index 15d4e52..026c08a 100644 --- a/dbx.go +++ b/dbx.go @@ -60,6 +60,12 @@ type Statement interface { ExecContext(ctx context.Context, db qrm.Executable) (sql.Result, error) } +// SelectStatement is a Jet statement that can be executed to fetch rows from the database. +type SelectStatement interface { + Statement + LIMIT(limit int64) SelectStatement +} + var ( // ErrNoRows is returned when a query returns no rows. ErrNoRows = qrm.ErrNoRows @@ -108,7 +114,7 @@ func ModuleDB(dialect Dialect, cfg *DBConfig, forceDebugLog bool) *app.Module { // Fetch queries the database and returns the result as a slice. If the query // returns no rows, it returns an empty slice and no error. -func Fetch[T any](sqlo Queryable, stmt Statement) ([]*T, error) { +func Fetch[T any](sqlo Queryable, stmt SelectStatement) ([]*T, error) { var result []*T if err := stmt.Query(sqlo, &result); err != nil && !errors.Is(err, ErrNoRows) { return nil, err @@ -118,7 +124,7 @@ func Fetch[T any](sqlo Queryable, stmt Statement) ([]*T, error) { // MustFetch queries the database and returns the result as a slice. If the query // returns no rows, it returns an empty slice and the desired error. -func MustFetch[T any](sqlo Queryable, stmt Statement, notFoundErr error) ([]*T, error) { +func MustFetch[T any](sqlo Queryable, stmt SelectStatement, notFoundErr error) ([]*T, error) { result, err := Fetch[T](sqlo, stmt) if err != nil { return nil, err @@ -129,6 +135,29 @@ func MustFetch[T any](sqlo Queryable, stmt Statement, notFoundErr error) ([]*T, return result, nil // return the fetched results } +// FetchOne queries the database and returns a single result. If the query +// returns no rows, it returns nil and no error. +func FetchOne[T any](sqlo Queryable, stmt SelectStatement) (*T, error) { + result, err := Fetch[T](sqlo, stmt.LIMIT(1)) + if err != nil { + return nil, err + } + if len(result) == 0 { + return nil, nil // no rows found, return nil + } + return result[0], nil // return the first (and only) result +} + +// MustFetchOne queries the database and returns a single result. If the query +// returns no rows, it returns nil and the desired error. +func MustFetchOne[T any](sqlo Queryable, stmt SelectStatement, notFoundErr error) (*T, error) { + result, err := MustFetch[T](sqlo, stmt.LIMIT(1), notFoundErr) + if err != nil { + return nil, err + } + return result[0], nil // return the first (and only) result +} + // Insert executes an insert statement, returning the last inserted ID or an // error if the insert fails. func Insert(sqlo Executable, stmt Statement) (uint64, error) {