diff --git a/cancelreader.go b/cancelreader.go new file mode 100644 index 0000000..90b22f0 --- /dev/null +++ b/cancelreader.go @@ -0,0 +1,72 @@ +package tea + +import ( + "fmt" + "io" + "sync" +) + +var errCanceled = fmt.Errorf("read cancelled") + +// cancelReader is a io.Reader whose Read() calls can be cancelled without data +// being consumed. The cancelReader has to be closed. +type cancelReader interface { + io.ReadCloser + + // Cancel cancels ongoing and future reads an returns true if it succeeded. + Cancel() bool +} + +// fallbackCancelReader implements cancelReader but does not actually support +// cancelation during an ongoing Read() call. Thus, Cancel() always returns +// false. However, after calling Cancel(), new Read() calls immediately return +// errCanceled and don't consume any data anymore. +type fallbackCancelReader struct { + r io.Reader + cancelled bool +} + +// newFallbackCancelReader is a fallback for newCancelReader that cannot +// actually cancel an ongoing read but will immediately return on future reads +// if it has been cancelled. +func newFallbackCancelReader(reader io.Reader) (cancelReader, error) { + return &fallbackCancelReader{r: reader}, nil +} + +func (r *fallbackCancelReader) Read(data []byte) (int, error) { + if r.cancelled { + return 0, errCanceled + } + + return r.r.Read(data) +} + +func (r *fallbackCancelReader) Cancel() bool { + r.cancelled = true + + return false +} + +func (r *fallbackCancelReader) Close() error { + return nil +} + +// cancelMixin represents a goroutine-safe cancelation status. +type cancelMixin struct { + unsafeCancelled bool + lock sync.Mutex +} + +func (c *cancelMixin) isCancelled() bool { + c.lock.Lock() + defer c.lock.Unlock() + + return c.unsafeCancelled +} + +func (c *cancelMixin) setCancelled() { + c.lock.Lock() + defer c.lock.Unlock() + + c.unsafeCancelled = true +} diff --git a/cancelreader_bsd.go b/cancelreader_bsd.go new file mode 100644 index 0000000..0f6653a --- /dev/null +++ b/cancelreader_bsd.go @@ -0,0 +1,144 @@ +// +build darwin freebsd netbsd openbsd + +// nolint:revive +package tea + +import ( + "errors" + "fmt" + "io" + "os" + "strings" + + "golang.org/x/sys/unix" +) + +// newkqueueCancelReader returns a reader and a cancel function. If the input reader +// is an *os.File, the cancel function can be used to interrupt a blocking call +// read call. In this case, the cancel function returns true if the call was +// cancelled successfully. If the input reader is not a *os.File, the cancel +// function does nothing and always returns false. The BSD and macOS +// implementation is based on the kqueue mechanism. +func newCancelReader(reader io.Reader) (cancelReader, error) { + file, ok := reader.(*os.File) + if !ok { + return newFallbackCancelReader(reader) + } + + // kqueue returns instantly when polling /dev/tty so fallback to select + if file.Name() == "/dev/tty" { + return newSelectCancelReader(reader) + } + + kQueue, err := unix.Kqueue() + if err != nil { + return nil, fmt.Errorf("create kqueue: %w", err) + } + + r := &kqueueCancelReader{ + file: file, + kQueue: kQueue, + } + + r.cancelSignalReader, r.cancelSignalWriter, err = os.Pipe() + if err != nil { + return nil, err + } + + unix.SetKevent(&r.kQueueEvents[0], int(file.Fd()), unix.EVFILT_READ, unix.EV_ADD) + unix.SetKevent(&r.kQueueEvents[1], int(r.cancelSignalReader.Fd()), unix.EVFILT_READ, unix.EV_ADD) + + return r, nil +} + +type kqueueCancelReader struct { + file *os.File + cancelSignalReader *os.File + cancelSignalWriter *os.File + cancelMixin + kQueue int + kQueueEvents [2]unix.Kevent_t +} + +func (r *kqueueCancelReader) Read(data []byte) (int, error) { + if r.isCancelled() { + return 0, errCanceled + } + + err := r.wait() + if err != nil { + if errors.Is(err, errCanceled) { + // remove signal from pipe + var b [1]byte + _, errRead := r.cancelSignalReader.Read(b[:]) + if errRead != nil { + return 0, fmt.Errorf("reading cancel signal: %w", errRead) + } + } + + return 0, err + } + + return r.file.Read(data) +} + +func (r *kqueueCancelReader) Cancel() bool { + r.setCancelled() + + // send cancel signal + _, err := r.cancelSignalWriter.Write([]byte{'c'}) + return err == nil +} + +func (r *kqueueCancelReader) Close() error { + var errMsgs []string + + // close kqueue + err := unix.Close(r.kQueue) + if err != nil { + errMsgs = append(errMsgs, fmt.Sprintf("closing kqueue: %v", err)) + } + + // close pipe + err = r.cancelSignalWriter.Close() + if err != nil { + errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal writer: %v", err)) + } + + err = r.cancelSignalReader.Close() + if err != nil { + errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal reader: %v", err)) + } + + if len(errMsgs) > 0 { + return fmt.Errorf(strings.Join(errMsgs, ", ")) + } + + return nil +} + +func (r *kqueueCancelReader) wait() error { + events := make([]unix.Kevent_t, 1) + + for { + _, err := unix.Kevent(r.kQueue, r.kQueueEvents[:], events, nil) + if errors.Is(err, unix.EINTR) { + continue // try again if the syscall was interrupted + } + + if err != nil { + return fmt.Errorf("kevent: %w", err) + } + + break + } + + switch events[0].Ident { + case uint64(r.file.Fd()): + return nil + case uint64(r.cancelSignalReader.Fd()): + return errCanceled + } + + return fmt.Errorf("unknown error") +} diff --git a/cancelreader_default.go b/cancelreader_default.go new file mode 100644 index 0000000..3d06887 --- /dev/null +++ b/cancelreader_default.go @@ -0,0 +1,14 @@ +//go:build !darwin && !windows && !linux && !solaris && !freebsd && !netbsd && !openbsd +// +build !darwin,!windows,!linux,!solaris,!freebsd,!netbsd,!openbsd + +package tea + +import ( + "io" +) + +// newCancelReader returns a fallbackCancelReader that satisfies the +// cancelReader but does not actually support cancelation. +func newCancelReader(reader io.Reader) (cancelReader, error) { + return newFallbackCancelReader(reader) +} diff --git a/cancelreader_linux.go b/cancelreader_linux.go new file mode 100644 index 0000000..343f7d1 --- /dev/null +++ b/cancelreader_linux.go @@ -0,0 +1,152 @@ +//go:build linux +// +build linux + +// nolint:revive +package tea + +import ( + "errors" + "fmt" + "io" + "os" + "strings" + + "golang.org/x/sys/unix" +) + +// newCancelReader returns a reader and a cancel function. If the input reader +// is an *os.File, the cancel function can be used to interrupt a blocking call +// read call. In this case, the cancel function returns true if the call was +// cancelled successfully. If the input reader is not a *os.File, the cancel +// function does nothing and always returns false. The linux implementation is +// based on the epoll mechanism. +func newCancelReader(reader io.Reader) (cancelReader, error) { + file, ok := reader.(*os.File) + if !ok { + return newFallbackCancelReader(reader) + } + + epoll, err := unix.EpollCreate1(0) + if err != nil { + return nil, fmt.Errorf("create epoll: %w", err) + } + + r := &epollCancelReader{ + file: file, + epoll: epoll, + } + + r.cancelSignalReader, r.cancelSignalWriter, err = os.Pipe() + if err != nil { + return nil, err + } + + err = unix.EpollCtl(epoll, unix.EPOLL_CTL_ADD, int(file.Fd()), &unix.EpollEvent{ + Events: unix.EPOLLIN, + Fd: int32(file.Fd()), + }) + if err != nil { + return nil, fmt.Errorf("add reader to epoll interrest list") + } + + err = unix.EpollCtl(epoll, unix.EPOLL_CTL_ADD, int(r.cancelSignalReader.Fd()), &unix.EpollEvent{ + Events: unix.EPOLLIN, + Fd: int32(r.cancelSignalReader.Fd()), + }) + if err != nil { + return nil, fmt.Errorf("add reader to epoll interrest list") + } + + return r, nil +} + +type epollCancelReader struct { + file *os.File + cancelSignalReader *os.File + cancelSignalWriter *os.File + cancelMixin + epoll int +} + +func (r *epollCancelReader) Read(data []byte) (int, error) { + if r.isCancelled() { + return 0, errCanceled + } + + err := r.wait() + if err != nil { + if errors.Is(err, errCanceled) { + // remove signal from pipe + var b [1]byte + _, readErr := r.cancelSignalReader.Read(b[:]) + if readErr != nil { + return 0, fmt.Errorf("reading cancel signal: %w", readErr) + } + } + + return 0, err + } + + return r.file.Read(data) +} + +func (r *epollCancelReader) Cancel() bool { + r.setCancelled() + + // send cancel signal + _, err := r.cancelSignalWriter.Write([]byte{'c'}) + return err == nil +} + +func (r *epollCancelReader) Close() error { + var errMsgs []string + + // close kqueue + err := unix.Close(r.epoll) + if err != nil { + errMsgs = append(errMsgs, fmt.Sprintf("closing epoll: %v", err)) + } + + // close pipe + err = r.cancelSignalWriter.Close() + if err != nil { + errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal writer: %v", err)) + } + + err = r.cancelSignalReader.Close() + if err != nil { + errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal reader: %v", err)) + } + + if len(errMsgs) > 0 { + return fmt.Errorf(strings.Join(errMsgs, ", ")) + } + + return nil +} + +func (r *epollCancelReader) wait() error { + events := make([]unix.EpollEvent, 1) + + for { + _, err := unix.EpollWait(r.epoll, events, -1) + if errors.Is(err, unix.EINTR) { + continue // try again if the syscall was interrupted + } + + if err != nil { + return fmt.Errorf("kevent: %w", err) + } + + break + } + + switch events[0].Fd { + case int32(r.file.Fd()): + return nil + case int32(r.cancelSignalReader.Fd()): + return errCanceled + } + + return fmt.Errorf("unknown error") +} diff --git a/cancelreader_select.go b/cancelreader_select.go new file mode 100644 index 0000000..0276c8d --- /dev/null +++ b/cancelreader_select.go @@ -0,0 +1,137 @@ +//go:build solaris || darwin +// +build solaris darwin + +// nolint:revive +package tea + +import ( + "errors" + "fmt" + "io" + "os" + "strings" + + "golang.org/x/sys/unix" +) + +// newSelectCancelReader returns a reader and a cancel function. If the input +// reader is an *os.File, the cancel function can be used to interrupt a +// blocking call read call. In this case, the cancel function returns true if +// the call was cancelled successfully. If the input reader is not a *os.File or +// the file descriptor is 1024 or larger, the cancel function does nothing and +// always returns false. The generic unix implementation is based on the posix +// select syscall. +func newSelectCancelReader(reader io.Reader) (cancelReader, error) { + file, ok := reader.(*os.File) + if !ok || file.Fd() >= unix.FD_SETSIZE { + return newFallbackCancelReader(reader) + } + r := &selectCancelReader{file: file} + + var err error + + r.cancelSignalReader, r.cancelSignalWriter, err = os.Pipe() + if err != nil { + return nil, err + } + + return r, nil +} + +type selectCancelReader struct { + file *os.File + cancelSignalReader *os.File + cancelSignalWriter *os.File + cancelMixin +} + +func (r *selectCancelReader) Read(data []byte) (int, error) { + if r.isCancelled() { + return 0, errCanceled + } + + for { + err := waitForRead(r.file, r.cancelSignalReader) + if err != nil { + if errors.Is(err, unix.EINTR) { + continue // try again if the syscall was interrupted + } + + if errors.Is(err, errCanceled) { + // remove signal from pipe + var b [1]byte + _, readErr := r.cancelSignalReader.Read(b[:]) + if readErr != nil { + return 0, fmt.Errorf("reading cancel signal: %w", readErr) + } + } + + return 0, err + } + + return r.file.Read(data) + } +} + +func (r *selectCancelReader) Cancel() bool { + r.setCancelled() + + // send cancel signal + _, err := r.cancelSignalWriter.Write([]byte{'c'}) + return err == nil +} + +func (r *selectCancelReader) Close() error { + var errMsgs []string + + // close pipe + err := r.cancelSignalWriter.Close() + if err != nil { + errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal writer: %v", err)) + } + + err = r.cancelSignalReader.Close() + if err != nil { + errMsgs = append(errMsgs, fmt.Sprintf("closing cancel signal reader: %v", err)) + } + + if len(errMsgs) > 0 { + return fmt.Errorf(strings.Join(errMsgs, ", ")) + } + + return nil +} + +func waitForRead(reader *os.File, abort *os.File) error { + readerFd := int(reader.Fd()) + abortFd := int(abort.Fd()) + + maxFd := readerFd + if abortFd > maxFd { + maxFd = abortFd + } + + // this is a limitation of the select syscall + if maxFd >= unix.FD_SETSIZE { + return fmt.Errorf("cannot select on file descriptor %d which is larger than 1024", maxFd) + } + + fdSet := &unix.FdSet{} + fdSet.Set(int(reader.Fd())) + fdSet.Set(int(abort.Fd())) + + _, err := unix.Select(maxFd+1, fdSet, nil, nil, nil) + if err != nil { + return fmt.Errorf("select: %w", err) + } + + if fdSet.IsSet(abortFd) { + return errCanceled + } + + if fdSet.IsSet(readerFd) { + return nil + } + + return fmt.Errorf("select returned without setting a file descriptor") +} diff --git a/cancelreader_unix.go b/cancelreader_unix.go new file mode 100644 index 0000000..c69db61 --- /dev/null +++ b/cancelreader_unix.go @@ -0,0 +1,20 @@ +//go:build solaris +// +build solaris + +// nolint:revive +package tea + +import ( + "io" +) + +// newCancelReader returns a reader and a cancel function. If the input reader +// is an *os.File, the cancel function can be used to interrupt a blocking call +// read call. In this case, the cancel function returns true if the call was +// cancelled successfully. If the input reader is not a *os.File or the file +// descriptor is 1024 or larger, the cancel function does nothing and always +// returns false. The generic unix implementation is based on the posix select +// syscall. +func newCancelReader(reader io.Reader) (cancelReader, error) { + return newSelectCancelReader(reader) +} diff --git a/cancelreader_windows.go b/cancelreader_windows.go new file mode 100644 index 0000000..0d449a7 --- /dev/null +++ b/cancelreader_windows.go @@ -0,0 +1,239 @@ +//go:build windows +// +build windows + +package tea + +import ( + "fmt" + "io" + "os" + "syscall" + "time" + "unicode/utf16" + + "golang.org/x/sys/windows" +) + +var fileShareValidFlags uint32 = 0x00000007 + +// newCancelReader returns a reader and a cancel function. If the input reader +// is an *os.File with the same file descriptor as os.Stdin, the cancel function +// can be used to interrupt a blocking call read call. In this case, the cancel +// function returns true if the call was cancelled successfully. If the input +// reader is not a *os.File with the same file descriptor as os.Stdin, the +// cancel function does nothing and always returns false. The Windows +// implementation is based on WaitForMultipleObject with overlapping reads from +// CONIN$. +func newCancelReader(reader io.Reader) (cancelReader, error) { + if f, ok := reader.(*os.File); !ok || f.Fd() != os.Stdin.Fd() { + return newFallbackCancelReader(reader) + } + + // it is neccessary to open CONIN$ (NOT windows.STD_INPUT_HANDLE) in + // overlapped mode to be able to use it with WaitForMultipleObjects. + conin, err := windows.CreateFile( + &(utf16.Encode([]rune("CONIN$\x00"))[0]), windows.GENERIC_READ|windows.GENERIC_WRITE, + fileShareValidFlags, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED, 0) + if err != nil { + return nil, fmt.Errorf("open CONIN$ in overlapping mode: %w", err) + } + + resetConsole, err := prepareConsole(conin) + if err != nil { + return nil, fmt.Errorf("prepare console: %w", err) + } + + // flush input, otherwise it can contain events which trigger + // WaitForMultipleObjects but which ReadFile cannot read, resulting in an + // un-cancelable read + err = flushConsoleInputBuffer(conin) + if err != nil { + return nil, fmt.Errorf("flush console input buffer: %w", err) + } + + cancelEvent, err := windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return nil, fmt.Errorf("create stop event: %w", err) + } + + return &winCancelReader{ + conin: conin, + cancelEvent: cancelEvent, + resetConsole: resetConsole, + blockingReadSignal: make(chan struct{}, 1), + }, nil +} + +type winCancelReader struct { + conin windows.Handle + cancelEvent windows.Handle + cancelMixin + + resetConsole func() error + blockingReadSignal chan struct{} +} + +func (r *winCancelReader) Read(data []byte) (int, error) { + if r.isCancelled() { + return 0, errCanceled + } + + err := r.wait() + if err != nil { + return 0, err + } + + if r.isCancelled() { + return 0, errCanceled + } + + // windows.Read does not work on overlapping windows.Handles + return r.readAsync(data) +} + +// Cancel cancels ongoing and future Read() calls and returns true if the +// cancelation of the ongoing Read() was successful. On Windows Terminal, +// WaitForMultipleObjects sometimes immediately returns without input being +// available. In this case, graceful cancelation is not possible and Cancel() +// returns false. +func (r *winCancelReader) Cancel() bool { + r.setCancelled() + + select { + case r.blockingReadSignal <- struct{}{}: + err := windows.SetEvent(r.cancelEvent) + if err != nil { + return false + } + <-r.blockingReadSignal + case <-time.After(100 * time.Millisecond): + // Read() hangs in a GetOverlappedResult which is likely due to + // WaitForMultipleObjects returning without input being available + // so we cannot cancel this ongoing read. + return false + } + + return true +} + +func (r *winCancelReader) Close() error { + err := windows.CloseHandle(r.cancelEvent) + if err != nil { + return fmt.Errorf("closing cancel event handle: %w", err) + } + + err = r.resetConsole() + if err != nil { + return err + } + + err = windows.Close(r.conin) + if err != nil { + return fmt.Errorf("closing CONIN$") + } + + return nil +} + +func (r *winCancelReader) wait() error { + event, err := windows.WaitForMultipleObjects([]windows.Handle{r.conin, r.cancelEvent}, false, windows.INFINITE) + switch { + case windows.WAIT_OBJECT_0 <= event && event < windows.WAIT_OBJECT_0+2: + if event == windows.WAIT_OBJECT_0+1 { + return errCanceled + } + + if event == windows.WAIT_OBJECT_0 { + return nil + } + + return fmt.Errorf("unexpected wait object is ready: %d", event-windows.WAIT_OBJECT_0) + case windows.WAIT_ABANDONED <= event && event < windows.WAIT_ABANDONED+2: + return fmt.Errorf("abandoned") + case event == uint32(windows.WAIT_TIMEOUT): + return fmt.Errorf("timeout") + case event == windows.WAIT_FAILED: + return fmt.Errorf("failed") + default: + return fmt.Errorf("unexpected error: %w", error(err)) + } +} + +// readAsync is neccessary to read from a windows.Handle in overlapping mode. +func (r *winCancelReader) readAsync(data []byte) (int, error) { + hevent, err := windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return 0, fmt.Errorf("create event: %w", err) + } + + overlapped := windows.Overlapped{ + HEvent: hevent, + } + + var n uint32 + + err = windows.ReadFile(r.conin, data, &n, &overlapped) + if err != nil && err != windows.ERROR_IO_PENDING { + return int(n), err + } + + r.blockingReadSignal <- struct{}{} + err = windows.GetOverlappedResult(r.conin, &overlapped, &n, true) + if err != nil { + return int(n), nil + } + <-r.blockingReadSignal + + return int(n), nil +} + +func prepareConsole(input windows.Handle) (reset func() error, err error) { + var originalMode uint32 + + err = windows.GetConsoleMode(input, &originalMode) + if err != nil { + return nil, fmt.Errorf("get console mode: %w", err) + } + + var newMode uint32 + newMode &^= windows.ENABLE_ECHO_INPUT + newMode &^= windows.ENABLE_LINE_INPUT + newMode &^= windows.ENABLE_MOUSE_INPUT + newMode &^= windows.ENABLE_WINDOW_INPUT + newMode &^= windows.ENABLE_PROCESSED_INPUT + + newMode |= windows.ENABLE_EXTENDED_FLAGS + newMode |= windows.ENABLE_INSERT_MODE + newMode |= windows.ENABLE_QUICK_EDIT_MODE + + newMode &^= windows.ENABLE_VIRTUAL_TERMINAL_INPUT + + err = windows.SetConsoleMode(input, newMode) + if err != nil { + return nil, fmt.Errorf("set console mode: %w", err) + } + + return func() error { + err := windows.SetConsoleMode(input, originalMode) + if err != nil { + return fmt.Errorf("reset console mode: %w", err) + } + + return nil + }, nil +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + procFlushConsoleInputBuffer = modkernel32.NewProc("FlushConsoleInputBuffer") +) + +func flushConsoleInputBuffer(consoleInput windows.Handle) error { + r, _, e := syscall.Syscall(procFlushConsoleInputBuffer.Addr(), 1, + uintptr(consoleInput), 0, 0) + if r == 0 { + return error(e) + } + + return nil +} diff --git a/signals_unix.go b/signals_unix.go index b0ada12..c7ca029 100644 --- a/signals_unix.go +++ b/signals_unix.go @@ -4,6 +4,7 @@ package tea import ( + "context" "os" "os/signal" "syscall" @@ -14,15 +15,31 @@ import ( // listenForResize sends messages (or errors) when the terminal resizes. // Argument output should be the file descriptor for the terminal; usually // os.Stdout. -func listenForResize(output *os.File, msgs chan Msg, errs chan error) { +func listenForResize(ctx context.Context, output *os.File, msgs chan Msg, errs chan error, done chan struct{}) { sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGWINCH) + + defer func() { + signal.Stop(sig) + close(done) + }() + for { - <-sig + select { + case <-ctx.Done(): + return + case <-sig: + } + w, h, err := term.GetSize(int(output.Fd())) if err != nil { errs <- err } - msgs <- WindowSizeMsg{w, h} + + select { + case <-ctx.Done(): + return + case msgs <- WindowSizeMsg{w, h}: + } } } diff --git a/signals_windows.go b/signals_windows.go index a71b12f..115c99a 100644 --- a/signals_windows.go +++ b/signals_windows.go @@ -2,8 +2,14 @@ package tea -import "os" +import ( + "context" + "os" +) // listenForResize is not available on windows because windows does not // implement syscall.SIGWINCH. -func listenForResize(_ *os.File, _ chan Msg, _ chan error) {} +func listenForResize(ctx context.Context, output *os.File, msgs chan Msg, + errs chan error, done chan struct{}) { + close(done) +} diff --git a/tea.go b/tea.go index 8f52f0b..129fe8f 100644 --- a/tea.go +++ b/tea.go @@ -11,6 +11,7 @@ package tea import ( "context" + "errors" "fmt" "io" "os" @@ -18,6 +19,7 @@ import ( "runtime/debug" "sync" "syscall" + "time" "github.com/containerd/console" isatty "github.com/mattn/go-isatty" @@ -78,8 +80,7 @@ type Program struct { // treated as bits. These options can be set via various ProgramOptions. startupOptions startupOptions - mtx *sync.Mutex - done chan struct{} + mtx *sync.Mutex msgs chan Msg @@ -250,15 +251,39 @@ func NewProgram(model Model, opts ...ProgramOption) *Program { // Start initializes the program. func (p *Program) Start() error { p.msgs = make(chan Msg) - p.done = make(chan struct{}) var ( cmds = make(chan Cmd) errs = make(chan error) ) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + // channels for managing goroutine lifecycles + var ( + readLoopDone = make(chan struct{}) + sigintLoopDone = make(chan struct{}) + cmdLoopDone = make(chan struct{}) + resizeLoopDone = make(chan struct{}) + initSignalDone = make(chan struct{}) + + waitForGoroutines = func(withReadLoop bool) { + if withReadLoop { + select { + case <-readLoopDone: + case <-time.After(500 * time.Millisecond): + // the read loop hangs, which means the input cancelReader's + // cancel function has returned true even though it was not + // able to cancel the read + } + } + <-cmdLoopDone + <-resizeLoopDone + <-sigintLoopDone + <-initSignalDone + } + ) + + ctx, cancelContext := context.WithCancel(context.Background()) + defer cancelContext() switch { case p.startupOptions.has(withInputTTY): @@ -267,6 +292,9 @@ func (p *Program) Start() error { if err != nil { return err } + + defer f.Close() // nolint:errcheck + p.input = f case !p.startupOptions.has(withCustomInput): @@ -287,6 +315,9 @@ func (p *Program) Start() error { if err != nil { return err } + + defer f.Close() // nolint:errcheck + p.input = f } @@ -297,7 +328,10 @@ func (p *Program) Start() error { go func() { sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT) - defer signal.Stop(sig) + defer func() { + signal.Stop(sig) + close(sigintLoopDone) + }() select { case <-ctx.Done(): @@ -342,8 +376,14 @@ func (p *Program) Start() error { model := p.initialModel if initCmd := model.Init(); initCmd != nil { go func() { - cmds <- initCmd + defer close(initSignalDone) + select { + case cmds <- initCmd: + case <-ctx.Done(): + } }() + } else { + close(initSignalDone) } // Start renderer @@ -353,21 +393,37 @@ func (p *Program) Start() error { // Render initial view p.renderer.write(model.View()) + cancelReader, err := newCancelReader(p.input) + if err != nil { + return err + } + + defer cancelReader.Close() // nolint:errcheck + // Subscribe to user input if p.input != nil { go func() { + defer close(readLoopDone) + for { - msg, err := readInput(p.input) - if err != nil { - // If we get EOF just stop listening for input - if err == io.EOF { - break - } - errs <- err + if ctx.Err() != nil { + return } + + msg, err := readInput(cancelReader) + if err != nil { + if !errors.Is(err, io.EOF) && !errors.Is(err, errCanceled) { + errs <- err + } + + return + } + p.msgs <- msg } }() + } else { + defer close(readLoopDone) } if f, ok := p.output.(*os.File); ok { @@ -377,25 +433,44 @@ func (p *Program) Start() error { if err != nil { errs <- err } - p.msgs <- WindowSizeMsg{w, h} + + select { + case <-ctx.Done(): + case p.msgs <- WindowSizeMsg{w, h}: + } }() // Listen for window resizes - go listenForResize(f, p.msgs, errs) + go listenForResize(ctx, f, p.msgs, errs, resizeLoopDone) + } else { + close(resizeLoopDone) } // Process commands go func() { + defer close(cmdLoopDone) + for { select { - case <-p.done: + case <-ctx.Done(): + return case cmd := <-cmds: - if cmd != nil { - go func() { - p.msgs <- cmd() - }() + if cmd == nil { + continue } + + // Don't wait on these goroutines, otherwise the shutdown + // latency would get too large as a Cmd can run for some time + // (e.g. tick commands that sleep for half a second). It's not + // possible to cancel them so we'll have to leak the goroutine + // until Cmd returns. + go func() { + select { + case p.msgs <- cmd(): + case <-ctx.Done(): + } + }() } } }() @@ -404,13 +479,18 @@ func (p *Program) Start() error { for { select { case err := <-errs: + cancelContext() + waitForGoroutines(cancelReader.Cancel()) p.shutdown(false) + return err case msg := <-p.msgs: // Handle special internal messages switch msg := msg.(type) { case quitMsg: + cancelContext() + waitForGoroutines(cancelReader.Cancel()) p.shutdown(false) return nil @@ -479,7 +559,6 @@ func (p *Program) shutdown(kill bool) { } else { p.renderer.stop() } - close(p.done) p.ExitAltScreen() p.DisableMouseCellMotion() p.DisableMouseAllMotion() diff --git a/tty_unix.go b/tty_unix.go index 7bf0b45..974b9b5 100644 --- a/tty_unix.go +++ b/tty_unix.go @@ -28,7 +28,7 @@ func (p *Program) initInput() error { // program exits. func (p *Program) restoreInput() error { if p.console != nil { - return p.console.Close() + return p.console.Reset() } return nil }