diff --git a/options.go b/options.go index 59d0685..d9ce42c 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,7 @@ package tea import ( + "context" "io" "github.com/muesli/termenv" @@ -14,6 +15,15 @@ import ( // p := NewProgram(model, WithInput(someInput), WithOutput(someOutput)) type ProgramOption func(*Program) +// WithContext lets you specify a context in which to run the Program. This is +// useful if you want to cancel the execution from outside. When a Program gets +// cancelled it will exit with an error ErrProgramKilled. +func WithContext(ctx context.Context) ProgramOption { + return func(p *Program) { + p.ctx = ctx + } +} + // WithOutput sets the output which, by default, is stdout. In most cases you // won't need to use this. func WithOutput(output io.Writer) ProgramOption { diff --git a/tea.go b/tea.go index d2c19ff..6736d6a 100644 --- a/tea.go +++ b/tea.go @@ -141,8 +141,14 @@ func NewProgram(model Model, opts ...ProgramOption) *Program { msgs: make(chan Msg), } + // A context can be provided with a ProgramOption, but if none was provided + // we'll use the default background context. + if p.ctx == nil { + p.ctx = context.Background() + } + // Initialize context and teardown channel. - p.ctx, p.cancel = context.WithCancel(context.Background()) + p.ctx, p.cancel = context.WithCancel(p.ctx) // Apply all options to the program. for _, opt := range opts { diff --git a/tea_test.go b/tea_test.go index 6140677..142f42d 100644 --- a/tea_test.go +++ b/tea_test.go @@ -2,6 +2,7 @@ package tea import ( "bytes" + "context" "sync/atomic" "testing" "time" @@ -97,6 +98,28 @@ func TestTeaKill(t *testing.T) { } } +func TestTeaContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var buf bytes.Buffer + var in bytes.Buffer + + m := &testModel{} + p := NewProgram(m, WithContext(ctx), WithInput(&in), WithOutput(&buf)) + go func() { + for { + time.Sleep(time.Millisecond) + if m.executed.Load() != nil { + cancel() + return + } + } + }() + + if _, err := p.Run(); err != ErrProgramKilled { + t.Fatalf("Expected %v, got %v", ErrProgramKilled, err) + } +} + func TestTeaBatchMsg(t *testing.T) { var buf bytes.Buffer var in bytes.Buffer