diff --git a/internal/pspy/pspy.go b/internal/pspy/pspy.go index 70cd7c6..233ec58 100644 --- a/internal/pspy/pspy.go +++ b/internal/pspy/pspy.go @@ -38,9 +38,16 @@ type chans struct { func Start(cfg *config.Config, b *Bindings, sigCh chan os.Signal) chan struct{} { b.Logger.Infof("Config: %+v", cfg) + abort := make(chan struct{}, 1) + abort <- struct{}{} - initFSW(b.FSW, cfg.RDirs, cfg.Dirs, b.Logger) - triggerCh, fsEventCh := startFSW(b.FSW, b.Logger, cfg.DrainFor) + if !initFSW(b.FSW, cfg.RDirs, cfg.Dirs, b.Logger, sigCh) { + return abort + } + triggerCh, fsEventCh, ok := startFSW(b.FSW, b.Logger, cfg.DrainFor, sigCh) + if !ok { + return abort + } psEventCh := startPSS(b.PSS, b.Logger, triggerCh) @@ -83,27 +90,29 @@ func printOutput(cfg *config.Config, b *Bindings, chans *chans) chan struct{} { return exit } -func initFSW(fsw FSWatcher, rdirs, dirs []string, logger Logger) { +func initFSW(fsw FSWatcher, rdirs, dirs []string, logger Logger, sigCh <-chan os.Signal) bool { errCh, doneCh := fsw.Init(rdirs, dirs) for { select { + case <-sigCh: + return false case <-doneCh: - return + return true case err := <-errCh: logger.Errorf(true, "initializing fs watcher: %v", err) } } } -func startFSW(fsw FSWatcher, logger Logger, drainFor time.Duration) (triggerCh chan struct{}, fsEventCh chan string) { +func startFSW(fsw FSWatcher, logger Logger, drainFor time.Duration, sigCh <-chan os.Signal) (triggerCh chan struct{}, fsEventCh chan string, ok bool) { triggerCh, fsEventCh, errCh := fsw.Run() go logErrors(errCh, logger) // ignore all file system events created on startup logger.Infof("Draining file system events due to startup...") - drainEventsFor(triggerCh, fsEventCh, drainFor) + ok = drainEventsFor(triggerCh, fsEventCh, drainFor, sigCh) logger.Infof("done") - return triggerCh, fsEventCh + return } func startPSS(pss PSScanner, logger Logger, triggerCh chan struct{}) (psEventCh chan psscanner.PSEvent) { @@ -128,13 +137,15 @@ func logErrors(errCh chan error, logger Logger) { } } -func drainEventsFor(triggerCh chan struct{}, eventCh chan string, d time.Duration) { +func drainEventsFor(triggerCh chan struct{}, eventCh chan string, d time.Duration, sigCh <-chan os.Signal) bool { for { select { + case <-sigCh: + return false case <-triggerCh: case <-eventCh: case <-time.After(d): - return + return true } } } diff --git a/internal/pspy/pspy_test.go b/internal/pspy/pspy_test.go index 7141d6c..b6080e5 100644 --- a/internal/pspy/pspy_test.go +++ b/internal/pspy/pspy_test.go @@ -17,24 +17,55 @@ func TestInitFSW(t *testing.T) { fsw := newMockFSWatcher() rdirs := make([]string, 0) dirs := make([]string, 0) + sigCh := make(chan os.Signal) go func() { fsw.initErrCh <- errors.New("error1") fsw.initErrCh <- errors.New("error2") close(fsw.initDoneCh) }() - initFSW(fsw, rdirs, dirs, l) + if !initFSW(fsw, rdirs, dirs, l, sigCh) { + t.Error("unexpected return value") + } expectMessage(t, l.Error, "initializing fs watcher: error1") expectMessage(t, l.Error, "initializing fs watcher: error2") expectClosed(t, fsw.initDoneCh) } +func TestInitFSWInterrupt(t *testing.T) { + l := newMockLogger() + fsw := newMockFSWatcher() + rdirs := make([]string, 0) + dirs := make([]string, 0) + sigCh := make(chan os.Signal, 0) + done := make(chan struct{}) + + go func() { + <-time.After(100 * time.Millisecond) + sigCh <- os.Interrupt + }() + + go func() { + if initFSW(fsw, rdirs, dirs, l, sigCh) { + t.Error("unexpected return value") + } + done <- struct{}{} + }() + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Error("timout") + } +} + // very flaky test... refactor code! func TestStartFSW(t *testing.T) { l := newMockLogger() fsw := newMockFSWatcher() drainFor := 100 * time.Millisecond + sigCh := make(chan os.Signal) go func() { fsw.runTriggerCh <- struct{}{} // trigger sent while draining @@ -47,7 +78,10 @@ func TestStartFSW(t *testing.T) { }() // sends no events and triggers from the drain phase - triggerCh, fsEventCh := startFSW(fsw, l, drainFor) + triggerCh, fsEventCh, ok := startFSW(fsw, l, drainFor, sigCh) + if !ok { + t.Error("unexpected return value") + } expectMessage(t, l.Info, "Draining file system events due to startup...") expectMessage(t, l.Error, "ERROR: error sent while draining") expectMessage(t, l.Info, "done") @@ -55,6 +89,32 @@ func TestStartFSW(t *testing.T) { expectMessage(t, fsEventCh, "event sent after draining") } +func TestStartFSWInterrupt(t *testing.T) { + l := newMockLogger() + fsw := newMockFSWatcher() + drainFor := 500 * time.Millisecond + sigCh := make(chan os.Signal) + done := make(chan struct{}) + + go func() { + <-time.After(100 * time.Millisecond) + sigCh <- os.Interrupt + }() + + go func() { + if _, _, ok := startFSW(fsw, l, drainFor, sigCh); ok { + t.Error("unexpected return value") + } + done <- struct{}{} + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Error("timout") + } +} + func TestStartPSS(t *testing.T) { pss := newMockPSScanner() l := newMockLogger()