From b1e7f42ab0625c298c70f3dcd1f04da979a2e8be Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Fri, 21 Oct 2022 18:18:56 +0200 Subject: [PATCH] fix(key): invert the control loop Instead of reading messages in an array and then sending them into a channel, this version of key.go writes to the channel directly. --- key.go | 38 ++++++++++++++++++------------------ key_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++------ tty.go | 26 +++++++------------------ 3 files changed, 76 insertions(+), 43 deletions(-) diff --git a/key.go b/key.go index bbd533c..74f85ec 100644 --- a/key.go +++ b/key.go @@ -1,12 +1,11 @@ package tea import ( + "context" "fmt" "io" "regexp" "unicode/utf8" - - "github.com/mattn/go-localereader" ) // KeyMsg contains information about a keypress. KeyMsgs are always sent to @@ -539,27 +538,30 @@ func (u unknownCSISequenceMsg) String() string { var spaceRunes = []rune{' '} -// readInputs reads keypress and mouse inputs from a TTY and returns messages +// readInputs reads keypress and mouse inputs from a TTY and produces messages // containing information about the key or mouse events accordingly. -func readInputs(input io.Reader) ([]Msg, error) { +func readInputs(ctx context.Context, msgs chan<- Msg, input io.Reader) error { var buf [256]byte - input = localereader.NewReader(input) + for { + // Read and block. + numBytes, err := input.Read(buf[:]) + if err != nil { + return err + } + b := buf[:numBytes] - // Read and block - numBytes, err := input.Read(buf[:]) - if err != nil { - return nil, err + var i, w int + for i, w = 0, 0; i < len(b); i += w { + var msg Msg + w, msg = detectOneMsg(b[i:]) + select { + case msgs <- msg: + case <-ctx.Done(): + return ctx.Err() + } + } } - b := buf[:numBytes] - - var msgs []Msg - for i, w := 0, 0; i < len(b); i += w { - var msg Msg - w, msg = detectOneMsg(b[i:]) - msgs = append(msgs, msg) - } - return msgs, nil } var unknownCSIRe = regexp.MustCompile(`^\x1b\[[\x30-\x3f]*[\x20-\x2f]*[\x40-\x7e]`) diff --git a/key_test.go b/key_test.go index 1af6a3e..d466b05 100644 --- a/key_test.go +++ b/key_test.go @@ -2,13 +2,17 @@ package tea import ( "bytes" + "context" + "errors" "flag" "fmt" + "io" "math/rand" "reflect" "runtime" "sort" "strings" + "sync" "testing" "time" ) @@ -442,12 +446,7 @@ func TestReadInput(t *testing.T) { for i, td := range testData { t.Run(fmt.Sprintf("%d: %s", i, td.keyname), func(t *testing.T) { - msgs, err := readInputs(bytes.NewReader(td.in)) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Compute the title for the event sequence. + msgs := testReadInputs(t, bytes.NewReader(td.in)) var buf strings.Builder for i, msg := range msgs { if i > 0 { @@ -459,6 +458,7 @@ func TestReadInput(t *testing.T) { fmt.Fprintf(&buf, "%#v:%T", msg, msg) } } + title := buf.String() if title != td.keyname { t.Errorf("expected message titles:\n %s\ngot:\n %s", td.keyname, title) @@ -475,6 +475,49 @@ func TestReadInput(t *testing.T) { } } +func testReadInputs(t *testing.T, input io.Reader) []Msg { + // We'll check that the input reader finishes at the end + // without error. + var wg sync.WaitGroup + var inputErr error + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + cancel() + wg.Wait() + if inputErr != nil && !errors.Is(inputErr, io.EOF) { + t.Fatalf("unexpected input error: %v", inputErr) + } + }() + + // The messages we're consuming. + msgsC := make(chan Msg) + + // Start the reader in the background. + wg.Add(1) + go func() { + defer wg.Done() + inputErr = readInputs(ctx, msgsC, input) + msgsC <- nil + }() + + var msgs []Msg +loop: + for { + select { + case msg := <-msgsC: + if msg == nil { + // end of input marker for the test. + break loop + } + msgs = append(msgs, msg) + case <-time.After(2 * time.Second): + t.Errorf("timeout waiting for input event") + break loop + } + } + return msgs +} + // randTest defines the test input and expected output for a sequence // of interleaved control sequences and control characters. type randTest struct { diff --git a/tty.go b/tty.go index 3ab6639..715d542 100644 --- a/tty.go +++ b/tty.go @@ -7,6 +7,7 @@ import ( "time" isatty "github.com/mattn/go-isatty" + localereader "github.com/mattn/go-localereader" "github.com/muesli/cancelreader" "golang.org/x/term" ) @@ -71,25 +72,12 @@ func (p *Program) initCancelReader() error { func (p *Program) readLoop() { defer close(p.readLoopDone) - for { - if p.ctx.Err() != nil { - return - } - - msgs, err := readInputs(p.cancelReader) - if err != nil { - if !errors.Is(err, io.EOF) && !errors.Is(err, cancelreader.ErrCanceled) { - select { - case <-p.ctx.Done(): - case p.errs <- err: - } - } - - return - } - - for _, msg := range msgs { - p.msgs <- msg + input := localereader.NewReader(p.cancelReader) + err := readInputs(p.ctx, p.msgs, input) + if !errors.Is(err, io.EOF) && !errors.Is(err, cancelreader.ErrCanceled) { + select { + case <-p.ctx.Done(): + case p.errs <- err: } } }