From 6ab100cc46c2249dbd14aade42d8dab1831113e4 Mon Sep 17 00:00:00 2001 From: Elijah Duffy Date: Wed, 28 Jan 2026 11:18:11 -0800 Subject: [PATCH] upgrade to go 1.25 & port helpers Removed ExprID and ExprEnum in favour of generics. --- dbx.go | 17 +++ go.mod | 3 +- go.sum | 2 + internal/dbxshared/dbxshared.go | 2 +- utility.go | 189 ++++++++++++++++++++++++++++++-- utility_test.go | 182 ++++++++++++++++++++++++++++++ 6 files changed, 383 insertions(+), 12 deletions(-) create mode 100644 utility_test.go diff --git a/dbx.go b/dbx.go index 65b7b69..0813968 100644 --- a/dbx.go +++ b/dbx.go @@ -52,6 +52,20 @@ type ExecutableTx interface { Begin() (*sql.Tx, error) } +// QueryExec interface is an SQL driver object that can execute SQL statements for Jet +// and query results. +type QueryExec interface { + Queryable + Executable +} + +// QueryExecTx interface is an SQL driver object that can execute SQL statements for +// Jet, query results, and begin a transaction. +type QueryExecTx interface { + Queryable + ExecutableTx +} + // Statement is a common Jet statement for all SQL operations. type Statement interface { Query(db qrm.Queryable, destination any) error @@ -79,6 +93,9 @@ var ( // ErrNoRows is returned when a query returns no rows. ErrNoRows = qrm.ErrNoRows + // ErrValueIsZero is returned when an expected value is missing. + ErrValueIsZero = errors.New("value is zero-value for type") + state = dbState{} ) diff --git a/go.mod b/go.mod index 52436d2..d53be75 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,13 @@ module gitea.auvem.com/go-toolkit/dbx -go 1.24.0 +go 1.25.0 require ( gitea.auvem.com/go-toolkit/app v0.0.0-20250530181559-231561c92698 github.com/fatih/color v1.18.0 github.com/go-jet/jet/v2 v2.13.0 github.com/segmentio/ksuid v1.0.4 + golang.org/x/exp v0.0.0-20260112195511-716be5621a96 ) require ( diff --git a/go.sum b/go.sum index 6529fe6..39d7fa1 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/segmentio/ksuid v1.0.4 h1:sBo2BdShXjmcugAMwjugoGUdUV0pcxY5mW4xKRn3v4c github.com/segmentio/ksuid v1.0.4/go.mod h1:/XUiZBD3kVx5SmUOl55voK5yeAbBNNIed+2O73XgrPE= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= diff --git a/internal/dbxshared/dbxshared.go b/internal/dbxshared/dbxshared.go index 111d9a9..ddaab69 100644 --- a/internal/dbxshared/dbxshared.go +++ b/internal/dbxshared/dbxshared.go @@ -9,7 +9,7 @@ import ( // normalized for compatibility with the Jet QRM. func DestName(destTypeStruct any, path ...string) string { v := reflect.ValueOf(destTypeStruct) - for v.Kind() == reflect.Ptr { + for v.Kind() == reflect.Pointer { v = v.Elem() } diff --git a/utility.go b/utility.go index 1c1ccec..40687d4 100644 --- a/utility.go +++ b/utility.go @@ -1,9 +1,13 @@ package dbx import ( + "fmt" + "reflect" "strings" + "time" "github.com/go-jet/jet/v2/mysql" + "golang.org/x/exp/constraints" ) // StringToFilter processes a string to be used as a filter in an SQL LIKE @@ -20,20 +24,185 @@ func StringToFilter(str string) string { return str } -// ExprID converts a list of uint64 values to a list of mysql.Expression values -func ExprID(ids []uint64) []mysql.Expression { - expressions := make([]mysql.Expression, len(ids)) - for i, id := range ids { - expressions[i] = mysql.Uint64(id) +// ExprValues converts a list of values to a list of mysql.Expression values using +// function f to transform the values (mysql.String for strings, mysql.Uint64, etc). +func ExprValues[T any](values []T, f func(T) mysql.Expression) []mysql.Expression { + expressions := make([]mysql.Expression, len(values)) + for i, v := range values { + expressions[i] = f(v) } return expressions } -// ExprEnum converts a list of uint8 values to a list of mysql.Expression values -func ExprEnum(enums []uint8) []mysql.Expression { - expressions := make([]mysql.Expression, len(enums)) - for i, enum := range enums { - expressions[i] = mysql.Uint8(enum) +// ExprStringers converts a list of fmt.Stringers to a list of mysql.Expression values. +func ExprStringers(values []fmt.Stringer) []mysql.Expression { + expressions := make([]mysql.Expression, len(values)) + for i, v := range values { + expressions[i] = mysql.String(v.String()) } return expressions } + +// NowPtr returns a pointer to the current time. +func NowPtr() *time.Time { + now := time.Now() + return &now +} + +// Ptr returns a pointer to the given value of any scalar type. Returns nil if the value is a zero value. +func Ptr[T any](val T) *T { + if reflect.ValueOf(val).IsZero() { + return nil + } + return &val +} + +// Val returns the value of the pointer to a scalar type, or the zero value if the pointer is nil. +func Val[T any](ptr *T) T { + if ptr == nil { + var zero T + return zero + } + return *ptr +} + +// TrimPtr trims the whitespace from a pointer to a string and returns nil only if the pointer is nil. +func TrimPtr(s *string) *string { + if s == nil { + return nil + } + trimmed := strings.TrimSpace(*s) + return &trimmed +} + +// TrimPtrToNil trims the whitespace from a pointer to a string and returns nil +// if the resulting string is empty or if the pointer is nil. +func TrimPtrToNil(s *string) *string { + if s == nil { + return nil + } + trimmed := strings.TrimSpace(*s) + if trimmed == "" { + return nil + } + return &trimmed +} + +// IsZero checks if a pointer references the zero value of a given type and +// returns an error if this condition is met, otherwise returns nil if the +// pointer is nil or the value is not zero. +func IsZero[T any](ptr *T) error { + if ptr == nil { + return nil + } + if reflect.ValueOf(*ptr).IsZero() { + return ErrValueIsZero + } + return nil +} + +// ApplyPtr compares the existing value with a new value and returns the updated value if they differ. +// If the new value is nil, the existing value is retained. If the new value is a zero-value, the +// existing value is NOT retained, it will be set to nil. If the value is changed, targetColumn is pushed +// to updatedColumns. +func ApplyPtr[T constraints.Float | constraints.Integer | string | bool]( + existing *T, + newVal *T, + updatedColumns *mysql.ColumnList, + targetColumn mysql.Column, +) *T { + if newVal == nil { + return existing + } + if reflect.ValueOf(*newVal).IsZero() { + newVal = nil + } + if newVal == nil && existing == nil || newVal != nil && existing != nil && *existing == *newVal { + return existing + } + *updatedColumns = append(*updatedColumns, targetColumn) + return newVal +} + +// ApplyComplexPtr compares the existing value with a new value and returns the updated value if they differ. +// The new value may be of a different type (e.g. existing is uint16 and new is uint64), but it will be +// converted to match the current type resulting in potential loss of data. If the new value is nil, the +// existing value is retained. If the new value is a zero-value, the existing value is NOT retained, it +// will be set to nil. If the value is changed, targetColumn is pushed to updatedColumns. +func ApplyComplexPtr[ + Existing constraints.Float | constraints.Integer, + New constraints.Float | constraints.Integer, +]( + existing *Existing, + newVal *New, + updatedColumns *mysql.ColumnList, + targetColumn mysql.Column, +) *Existing { + if newVal == nil { + return existing + } + cast := Existing(*newVal) // Convert new value to existing type + if existing != nil && *existing == cast { + return existing + } + if reflect.ValueOf(cast).IsZero() { + if existing != nil { + *updatedColumns = append(*updatedColumns, targetColumn) + } + return nil + } else { + *updatedColumns = append(*updatedColumns, targetColumn) + return &cast + } +} + +type ApplyInterface[T any] interface { + Equal(T) bool + IsZero() bool +} + +// ApplyInterfacePtr compares the existing value with a new value and returns the updated value if +// they differ. Comparable types must have IsZero and Equal methods. If the new value is nil, the +// existing value is retained. If the new value is a zero-value, the existing value is NOT retained, +// it will be set to nil. If the value is changed, targetColumn is pushed to updatedColumns. +func ApplyInterfacePtr[T ApplyInterface[T]]( + existing *T, + newVal *T, + updatedColumns *mysql.ColumnList, + targetColumn mysql.Column, +) *T { + if newVal == nil { + return existing + } + if existing != nil && (*existing).Equal(*newVal) { + return existing + } + if (*newVal).IsZero() { + if existing != nil { + *updatedColumns = append(*updatedColumns, targetColumn) + } + return nil + } else { + *updatedColumns = append(*updatedColumns, targetColumn) + return newVal + } +} + +// ApplyVal compares the existing value with a pointer to a new value and returns the updated value if they +// differ. If the new value is nil, the existing value is retained. If the value is changed, targetColumn +// is pushed to updatedColumns +func ApplyVal[T constraints.Float | constraints.Integer | string | bool]( + existing T, + newVal *T, + updatedColumns *mysql.ColumnList, + targetColumn mysql.Column, +) T { + if newVal == nil { + return existing + } + if existing == *newVal { + return existing + } + *updatedColumns = append(*updatedColumns, targetColumn) + return *newVal +} diff --git a/utility_test.go b/utility_test.go new file mode 100644 index 0000000..61b4d82 --- /dev/null +++ b/utility_test.go @@ -0,0 +1,182 @@ +package dbx + +import ( + "testing" + "time" + + "github.com/go-jet/jet/v2/mysql" + "github.com/stretchr/testify/assert" + "golang.org/x/exp/constraints" +) + +func Test_ApplyPtr(t *testing.T) { + targetCol := mysql.StringColumn("test") + var empty string + + cases := []struct { + name string + current *string + new *string + expected *string + length int + }{ + {"apply updated value", Ptr("hello"), Ptr("world"), Ptr("world"), 1}, + {"apply same value", Ptr("hello"), Ptr("hello"), Ptr("hello"), 0}, + {"apply new value to nil", nil, Ptr("hello"), Ptr("hello"), 1}, + {"apply nil to existing value", Ptr("hello"), nil, Ptr("hello"), 0}, + {"apply empty string to existing value", Ptr("hello"), &empty, nil, 1}, + {"apply empty string to nil", nil, &empty, nil, 0}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + modified := make(mysql.ColumnList, 0) + updated := ApplyPtr(c.current, c.new, &modified, targetCol) + assert.Len(t, modified, c.length) + + if c.expected == nil { + assert.Nil(t, updated, "Expected updated value to be nil") + } else { + assert.NotNil(t, updated, "Expected updated value to be not nil") + if updated != nil { + assert.Equal(t, *c.expected, *updated, "Expected updated value to be equal to new value") + } + } + }) + } +} + +func applyComplexPtr_tc[ + Current constraints.Float | constraints.Integer, + New constraints.Float | constraints.Integer, +]( + t *testing.T, + name string, + current *Current, + new *New, + expected *Current, + length int, +) { + targetCol := mysql.IntegerColumn("test") + t.Run(name, func(t *testing.T) { + modified := make(mysql.ColumnList, 0) + updated := ApplyComplexPtr(current, new, &modified, targetCol) + assert.Len(t, modified, length) + + if expected == nil { + assert.Nil(t, updated, "Expected updated value to be nil") + } else { + assert.NotNil(t, updated, "Expected updated value to be not nil") + if updated != nil { + assert.Equal(t, *expected, *updated, "Expected updated value to be equal to new value") + } + } + }) +} + +func Test_ApplyComplexPtr(t *testing.T) { + var empty int + applyComplexPtr_tc(t, "apply updated value", Ptr(1), Ptr(2), Ptr(2), 1) + applyComplexPtr_tc(t, "apply same value", Ptr(1), Ptr(1), Ptr(1), 0) + applyComplexPtr_tc(t, "apply new value to nil", nil, Ptr(2), Ptr(2), 1) + applyComplexPtr_tc[int, int](t, "apply nil to existing value", Ptr(1), nil, Ptr(1), 0) + applyComplexPtr_tc(t, "apply zero value to existing value", Ptr(1), &empty, nil, 1) + applyComplexPtr_tc(t, "apply different type to existing value", Ptr[uint64](1), Ptr[uint16](2), Ptr[uint64](2), 1) + applyComplexPtr_tc[int](t, "apply zero value to nil", nil, &empty, nil, 0) +} + +func applyInterfacePtr_tc[ + T ApplyInterface[T], +]( + t *testing.T, + name string, + current *T, + new *T, + expected *T, + length int, +) { + targetCol := mysql.StringColumn("test") + t.Run(name, func(t *testing.T) { + modified := make(mysql.ColumnList, 0) + updated := ApplyInterfacePtr(current, new, &modified, targetCol) + assert.Len(t, modified, length) + + if expected == nil { + assert.Nil(t, updated, "Expected updated value to be nil") + } else { + assert.NotNil(t, updated, "Expected updated value to be not nil") + if updated != nil { + assert.Equal(t, *expected, *updated, "Expected updated value to be equal to new value") + } + } + }) +} + +func Test_ApplyInterfacePtr(t *testing.T) { + var zero time.Time + applyInterfacePtr_tc(t, "apply updated value", + Ptr(time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC)), + Ptr(time.Date(2023, 10, 2, 0, 0, 0, 0, time.UTC)), + Ptr(time.Date(2023, 10, 2, 0, 0, 0, 0, time.UTC)), + 1, + ) + applyInterfacePtr_tc(t, "apply same value", + Ptr(time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC)), + Ptr(time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC)), + Ptr(time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC)), + 0, + ) + applyInterfacePtr_tc(t, "apply new value to nil", + nil, + Ptr(time.Date(2023, 10, 2, 0, 0, 0, 0, time.UTC)), + Ptr(time.Date(2023, 10, 2, 0, 0, 0, 0, time.UTC)), + 1, + ) + applyInterfacePtr_tc(t, "apply nil to existing value", + Ptr(time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC)), + nil, + Ptr(time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC)), + 0, + ) + applyInterfacePtr_tc(t, "apply zero value to existing value", + Ptr(time.Date(2023, 10, 1, 0, 0, 0, 0, time.UTC)), + &zero, + nil, + 1, + ) + applyInterfacePtr_tc(t, "apply zero value to nil", + nil, + &zero, + nil, + 0, + ) +} + +func Test_ApplyVal(t *testing.T) { + targetCol := mysql.StringColumn("test") + var empty string + + cases := []struct { + name string + current string + new *string + expected string + length int + }{ + {"apply updated value", "hello", Ptr("world"), "world", 1}, + {"apply same value", "hello", Ptr("hello"), "hello", 0}, + {"apply new value to empty value", "", Ptr("hello"), "hello", 1}, + {"apply nil to existing value", "hello", nil, "hello", 0}, + {"apply empty string to existing value", "hello", &empty, "", 1}, + {"apply empty string to empty value", "", &empty, "", 0}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + modified := make(mysql.ColumnList, 0) + updated := ApplyVal(c.current, c.new, &modified, targetCol) + assert.Equal(t, c.expected, updated, "Expected updated value to be equal to new value") + assert.Len(t, modified, c.length) + }) + } +}