diff --git a/lifecycle.go b/lifecycle.go index 7591f65..1fb4320 100644 --- a/lifecycle.go +++ b/lifecycle.go @@ -54,14 +54,6 @@ func NewLifecycle(modules ...*Module) *Lifecycle { } } -// LifecycleToContext adds the Lifecycle to the context. -func LifecycleToContext(ctx context.Context, lifecycle *Lifecycle) context.Context { - if lifecycle == nil { - return ctx - } - return context.WithValue(ctx, lifecycleContextKey, lifecycle) -} - // LifecycleFromContext retrieves the Lifecycle from the context. Returns nil if not found. func LifecycleFromContext(ctx context.Context) *Lifecycle { if lifecycle, ok := ctx.Value(lifecycleContextKey).(*Lifecycle); ok { @@ -70,6 +62,14 @@ func LifecycleFromContext(ctx context.Context) *Lifecycle { return nil } +// Context adds the Lifecycle to a context and returns the new context. +func (app *Lifecycle) Context(ctx context.Context) context.Context { + if app == nil { + return ctx + } + return context.WithValue(ctx, lifecycleContextKey, app) +} + // WithOpts sets the options for the lifecycle. func (app *Lifecycle) WithOpts(opts LifecycleOpts) *Lifecycle { app.opts = opts @@ -101,6 +101,10 @@ func (app *Lifecycle) Logger() *slog.Logger { // if dependencies are not satisfied. Lifecycle.Teardown should always be run at the end // of the application lifecycle to ensure all resources are cleaned up properly. func (app *Lifecycle) Setup() error { + if app.setupCount > 0 { + return fmt.Errorf("lifecycle already set up, cannot set up again") + } + for _, mod := range app.modules { if err := app.setupSingle(nil, mod); err != nil { return err @@ -118,6 +122,10 @@ func (app *Lifecycle) Setup() error { // to ensure all resources are cleaned up properly. All module tear down // errors are returned as a single error (non-blocking). func (app *Lifecycle) Teardown() error { + if app.teardownCount > 0 { + return fmt.Errorf("lifecycle already torn down, cannot tear down again") + } + var err error for i := len(app.modules) - 1; i >= 0; i-- { if singleErr := app.teardownSingle(app.modules[i]); singleErr != nil { diff --git a/lifecycle_test.go b/lifecycle_test.go index 014be08..be2e22a 100644 --- a/lifecycle_test.go +++ b/lifecycle_test.go @@ -54,7 +54,7 @@ func TestLifecycleToFromContext(t *testing.T) { lc := NewLifecycle() // Add the Lifecycle to the context - ctx := LifecycleToContext(context.Background(), lc) + ctx := lc.Context(context.Background()) // Retrieve the Lifecycle from the context retrievedLC := LifecycleFromContext(ctx) @@ -63,7 +63,8 @@ func TestLifecycleToFromContext(t *testing.T) { assert.Equal(lc, retrievedLC, "expected retrieved Lifecycle to match original") // Test with nil Lifecycle - nilCtx := LifecycleToContext(context.Background(), nil) + var nilLifecycle *Lifecycle + nilCtx := nilLifecycle.Context(context.Background()) retrievedNilLC := LifecycleFromContext(nilCtx) assert.Nil(retrievedNilLC, "expected nil Lifecycle to return nil from context") } @@ -171,6 +172,22 @@ func TestLifecycle_Setup(t *testing.T) { } }) } + + // Test double setup + t.Run("double setup", func(t *testing.T) { + assert := assert.New(t) + + lc := NewLifecycle( + NewModule("module1", ModuleOpts{}), + ) + + err := lc.Setup() + assert.NoError(err, "expected first Setup to succeed") + + err = lc.Setup() + assert.Error(err, "expected second Setup to fail") + assert.Contains(err.Error(), "already set up", "expected error message to indicate already set up") + }) } func TestLifecycle_Teardown(t *testing.T) { @@ -254,6 +271,26 @@ func TestLifecycle_Teardown(t *testing.T) { } }) } + + // Test double teardown + t.Run("double teardown", func(t *testing.T) { + assert := assert.New(t) + + lc := NewLifecycle( + NewModule("module1", ModuleOpts{}), + ) + + // Fake setup for the module + lc.modules[0].loaded = true + lc.setupTracker[lc.modules[0].name] = 0 + + err := lc.Teardown() + assert.NoError(err, "expected first Teardown to succeed") + + err = lc.Teardown() + assert.Error(err, "expected second Teardown to fail") + assert.Contains(err.Error(), "already torn down", "expected error message to indicate already torn down") + }) } func TestLifecycle_Require(t *testing.T) {