diff --git a/commands_test.go b/commands_test.go index 749c6a9..e0d5ff3 100644 --- a/commands_test.go +++ b/commands_test.go @@ -80,3 +80,28 @@ func TestSequentially(t *testing.T) { }) } } + +func TestBatch(t *testing.T) { + t.Run("nil cmd", func(t *testing.T) { + if b := Batch(nil); b != nil { + t.Fatalf("expected nil, got %+v", b) + } + }) + t.Run("empty cmd", func(t *testing.T) { + if b := Batch(); b != nil { + t.Fatalf("expected nil, got %+v", b) + } + }) + t.Run("single cmd", func(t *testing.T) { + b := Batch(Quit)() + if l := len(b.(batchMsg)); l != 1 { + t.Fatalf("expected a []Cmd with len 1, got %d", l) + } + }) + t.Run("mixed nil cmds", func(t *testing.T) { + b := Batch(nil, Quit, nil, Quit, nil, nil)() + if l := len(b.(batchMsg)); l != 2 { + t.Fatalf("expected a []Cmd with len 2, got %d", l) + } + }) +} diff --git a/tea.go b/tea.go index fdf29ee..ee1bd92 100644 --- a/tea.go +++ b/tea.go @@ -119,11 +119,18 @@ type Program struct { // } // func Batch(cmds ...Cmd) Cmd { - if len(cmds) == 0 { + var validCmds []Cmd + for _, c := range cmds { + if c == nil { + continue + } + validCmds = append(validCmds, c) + } + if len(validCmds) == 0 { return nil } return func() Msg { - return batchMsg(cmds) + return batchMsg(validCmds) } }