diff options
author | Nick White <git@njw.name> | 2022-01-31 14:11:21 +0000 |
---|---|---|
committer | Nick White <git@njw.name> | 2022-01-31 14:11:21 +0000 |
commit | 550752fa2ab493fb6d10aa9d963fc45996c0d100 (patch) | |
tree | 279d2c7c7d062f6232f363d1462539738b7e4cc8 | |
parent | 57a3dc6da88e08951060e2e6e11605eb807f54ac (diff) |
Make pipeline context-aware, so the rescribe tool can cancel jobs
-rw-r--r-- | cmd/bookpipeline/main.go | 11 | ||||
-rw-r--r-- | cmd/booktopipeline/main.go | 7 | ||||
-rw-r--r-- | cmd/rescribe/gui.go | 37 | ||||
-rw-r--r-- | cmd/rescribe/main.go | 31 | ||||
-rw-r--r-- | internal/pipeline/pipeline.go | 178 | ||||
-rw-r--r-- | internal/pipeline/put.go | 15 |
6 files changed, 234 insertions, 45 deletions
diff --git a/cmd/bookpipeline/main.go b/cmd/bookpipeline/main.go index 65c9b79..4de9ea9 100644 --- a/cmd/bookpipeline/main.go +++ b/cmd/bookpipeline/main.go @@ -9,6 +9,7 @@ package main import ( "bytes" + "context" "flag" "fmt" "log" @@ -118,6 +119,8 @@ func main() { wipePattern := regexp.MustCompile(`[0-9]{4,6}(.bin)?.png$`) ocredPattern := regexp.MustCompile(`.hocr$`) + var ctx context.Context + var conn Pipeliner switch *conntype { case "aws": @@ -190,7 +193,7 @@ func main() { } conn.Log("Message received on preprocess queue, processing", msg.Body) stopTimer(stopIfQuiet) - err = pipeline.ProcessBook(msg, conn, pipeline.Preprocess([]float64{0.1, 0.2, 0.4, 0.5}), origPattern, conn.PreQueueId(), conn.OCRPageQueueId()) + err = pipeline.ProcessBook(ctx, msg, conn, pipeline.Preprocess([]float64{0.1, 0.2, 0.4, 0.5}), origPattern, conn.PreQueueId(), conn.OCRPageQueueId()) resetTimer(stopIfQuiet, quietTime) if err != nil { conn.Log("Error during preprocess", err) @@ -208,7 +211,7 @@ func main() { } stopTimer(stopIfQuiet) conn.Log("Message received on wipeonly queue, processing", msg.Body) - err = pipeline.ProcessBook(msg, conn, pipeline.Wipe, wipePattern, conn.WipeQueueId(), conn.OCRPageQueueId()) + err = pipeline.ProcessBook(ctx, msg, conn, pipeline.Wipe, wipePattern, conn.WipeQueueId(), conn.OCRPageQueueId()) resetTimer(stopIfQuiet, quietTime) if err != nil { conn.Log("Error during wipe", err) @@ -228,7 +231,7 @@ func main() { checkOCRPageQueue = time.After(0) stopTimer(stopIfQuiet) conn.Log("Message received on OCR Page queue, processing", msg.Body) - err = pipeline.OcrPage(msg, conn, pipeline.Ocr(*training, ""), conn.OCRPageQueueId(), conn.AnalyseQueueId()) + err = pipeline.OcrPage(ctx, msg, conn, pipeline.Ocr(*training, ""), conn.OCRPageQueueId(), conn.AnalyseQueueId()) resetTimer(stopIfQuiet, quietTime) if err != nil { conn.Log("Error during OCR Page process", err) @@ -246,7 +249,7 @@ func main() { } stopTimer(stopIfQuiet) conn.Log("Message received on analyse queue, processing", msg.Body) - err = pipeline.ProcessBook(msg, conn, pipeline.Analyse(conn), ocredPattern, conn.AnalyseQueueId(), "") + err = pipeline.ProcessBook(ctx, msg, conn, pipeline.Analyse(conn), ocredPattern, conn.AnalyseQueueId(), "") resetTimer(stopIfQuiet, quietTime) if err != nil { conn.Log("Error during analysis", err) diff --git a/cmd/booktopipeline/main.go b/cmd/booktopipeline/main.go index b4f4d99..bf088a0 100644 --- a/cmd/booktopipeline/main.go +++ b/cmd/booktopipeline/main.go @@ -7,6 +7,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -65,6 +66,8 @@ func main() { bookname = filepath.Base(bookdir) } + var ctx context.Context + if *verbose { verboselog = log.New(os.Stdout, "", log.LstdFlags) } else { @@ -97,7 +100,7 @@ func main() { } verboselog.Println("Checking that all images are valid in", bookdir) - err = pipeline.CheckImages(bookdir) + err = pipeline.CheckImages(ctx, bookdir) if err != nil { log.Fatalln(err) } @@ -112,7 +115,7 @@ func main() { } verboselog.Println("Uploading all images are valid in", bookdir) - err = pipeline.UploadImages(bookdir, bookname, conn) + err = pipeline.UploadImages(ctx, bookdir, bookname, conn) if err != nil { log.Fatalln(err) } diff --git a/cmd/rescribe/gui.go b/cmd/rescribe/gui.go index 06e6ddd..bdcc16c 100644 --- a/cmd/rescribe/gui.go +++ b/cmd/rescribe/gui.go @@ -6,6 +6,7 @@ package main import ( "bufio" + "context" "errors" "fmt" "io" @@ -221,7 +222,7 @@ func startGui(log log.Logger, cmd string, training string, tessdir string) error myWindow.Resize(fyne.NewSize(800, 400)) - var gobtn *widget.Button + var abortbtn, gobtn *widget.Button var fullContent *fyne.Container dir := widget.NewLabel("") @@ -272,6 +273,23 @@ func startGui(log log.Logger, cmd string, training string, tessdir string) error detail := widget.NewAccordion(widget.NewAccordionItem("Log", logarea)) + var ctx context.Context + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(context.Background()) + + abortbtn = widget.NewButtonWithIcon("Abort", theme.CancelIcon(), func() { + fmt.Printf("\nAbort\n") + cancel() + progressBar.SetValue(0.0) + gobtn.SetText("Process OCR") + for _, v := range []fyne.Disableable{folderBtn, pdfBtn, gbookBtn, trainingOpts, gobtn} { + v.Enable() + } + abortbtn.Disable() + ctx, cancel = context.WithCancel(context.Background()) + }) + abortbtn.Disable() + gobtn = widget.NewButtonWithIcon("Start OCR", theme.UploadIcon(), func() { if dir.Text == "" { return @@ -347,6 +365,7 @@ func startGui(log log.Logger, cmd string, training string, tessdir string) error for _, v := range []fyne.Disableable{folderBtn, pdfBtn, gbookBtn, trainingOpts, gobtn} { v.Enable() } + abortbtn.Disable() return } @@ -356,6 +375,8 @@ func startGui(log log.Logger, cmd string, training string, tessdir string) error v.Disable() } + abortbtn.Enable() + progressBar.SetValue(0.1) if strings.HasSuffix(dir.Text, ".pdf") && !f.IsDir() { @@ -370,6 +391,7 @@ func startGui(log log.Logger, cmd string, training string, tessdir string) error for _, v := range []fyne.Disableable{folderBtn, pdfBtn, gbookBtn, trainingOpts, gobtn} { v.Enable() } + abortbtn.Disable() return } @@ -385,6 +407,7 @@ func startGui(log log.Logger, cmd string, training string, tessdir string) error for _, v := range []fyne.Disableable{folderBtn, pdfBtn, gbookBtn, trainingOpts, gobtn} { v.Enable() } + abortbtn.Disable() return } @@ -399,7 +422,11 @@ func startGui(log log.Logger, cmd string, training string, tessdir string) error training = training[start:end] } - err = startProcess(log, cmd, bookdir, bookname, training, savedir, tessdir) + err = startProcess(ctx, log, cmd, bookdir, bookname, training, savedir, tessdir) + if strings.HasSuffix(err.Error(), "context canceled") { + progressBar.SetValue(0.0) + return + } if err != nil { msg := fmt.Sprintf("Error during processing: %v\n", err) dialog.ShowError(errors.New(msg), myWindow) @@ -410,6 +437,7 @@ func startGui(log log.Logger, cmd string, training string, tessdir string) error for _, v := range []fyne.Disableable{folderBtn, pdfBtn, gbookBtn, trainingOpts, gobtn} { v.Enable() } + abortbtn.Disable() return } @@ -419,6 +447,7 @@ func startGui(log log.Logger, cmd string, training string, tessdir string) error for _, v := range []fyne.Disableable{folderBtn, pdfBtn, gbookBtn, trainingOpts, gobtn} { v.Enable() } + abortbtn.Disable() }() }) gobtn.Disable() @@ -429,8 +458,8 @@ func startGui(log log.Logger, cmd string, training string, tessdir string) error trainingBits := container.New(layout.NewBorderLayout(nil, nil, trainingLabel, nil), trainingLabel, trainingOpts) - fullContent = container.NewVBox(choices, chosen, trainingBits, gobtn, progressBar, detail) - startContent := container.NewVBox(choices, trainingBits, gobtn, progressBar, detail) + fullContent = container.NewVBox(choices, chosen, trainingBits, gobtn, abortbtn, progressBar, detail) + startContent := container.NewVBox(choices, trainingBits, gobtn, abortbtn, progressBar, detail) myWindow.SetContent(startContent) diff --git a/cmd/rescribe/main.go b/cmd/rescribe/main.go index 3f7bd71..cd242af 100644 --- a/cmd/rescribe/main.go +++ b/cmd/rescribe/main.go @@ -12,6 +12,7 @@ package main import ( "archive/zip" "bytes" + "context" _ "embed" "flag" "fmt" @@ -284,7 +285,9 @@ These training files are included in rescribe, and are always available: ispdf = true } - err = startProcess(*verboselog, tessCommand, bookdir, bookname, trainingName, savedir, tessdir) + var ctx context.Context + + err = startProcess(ctx, *verboselog, tessCommand, bookdir, bookname, trainingName, savedir, tessdir) if err != nil { log.Fatalln(err) } @@ -413,7 +416,7 @@ func rmIfNotImage(f string) error { return nil } -func startProcess(logger log.Logger, tessCommand string, bookdir string, bookname string, trainingName string, savedir string, tessdir string) error { +func startProcess(ctx context.Context, logger log.Logger, tessCommand string, bookdir string, bookname string, trainingName string, savedir string, tessdir string) error { _, err := exec.Command(tessCommand, "--help").Output() if err != nil { errmsg := "Error, Can't run Tesseract\n" @@ -441,14 +444,14 @@ func startProcess(logger log.Logger, tessCommand string, bookdir string, booknam fmt.Printf("Copying book to pipeline\n") - err = uploadbook(bookdir, bookname, conn) + err = uploadbook(ctx, bookdir, bookname, conn) if err != nil { _ = os.RemoveAll(tempdir) return fmt.Errorf("Error uploading book: %v", err) } fmt.Printf("Processing book\n") - err = processbook(trainingName, tessCommand, conn) + err = processbook(ctx, trainingName, tessCommand, conn) if err != nil { _ = os.RemoveAll(tempdir) return fmt.Errorf("Error processing book: %v", err) @@ -554,16 +557,16 @@ func addTxtVersion(hocrfn string) error { return nil } -func uploadbook(dir string, name string, conn Pipeliner) error { +func uploadbook(ctx context.Context, dir string, name string, conn Pipeliner) error { _, err := os.Stat(dir) if err != nil && !os.IsExist(err) { return fmt.Errorf("Error: directory %s not found", dir) } - err = pipeline.CheckImages(dir) + err = pipeline.CheckImages(ctx, dir) if err != nil { return fmt.Errorf("Error with images in %s: %v", dir, err) } - err = pipeline.UploadImages(dir, name, conn) + err = pipeline.UploadImages(ctx, dir, name, conn) if err != nil { return fmt.Errorf("Error saving images to process from %s: %v", dir, err) } @@ -602,7 +605,7 @@ func downloadbook(dir string, name string, conn Pipeliner) error { return nil } -func processbook(training string, tesscmd string, conn Pipeliner) error { +func processbook(ctx context.Context, training string, tesscmd string, conn Pipeliner) error { origPattern := regexp.MustCompile(`[0-9]{4}.(jpg|png)$`) wipePattern := regexp.MustCompile(`[0-9]{4,6}(.bin)?.(jpg|png)$`) ocredPattern := regexp.MustCompile(`.hocr$`) @@ -624,6 +627,8 @@ func processbook(training string, tesscmd string, conn Pipeliner) error { for { select { + case <-ctx.Done(): + return ctx.Err() case <-checkPreQueue: msg, err := conn.CheckQueue(conn.PreQueueId(), QueueTimeoutSecs) checkPreQueue = time.After(PauseBetweenChecks) @@ -637,12 +642,12 @@ func processbook(training string, tesscmd string, conn Pipeliner) error { stopTimer(stopIfQuiet) conn.Log("Message received on preprocess queue, processing", msg.Body) fmt.Printf(" Preprocessing book (binarising and wiping)\n") - err = pipeline.ProcessBook(msg, conn, pipeline.Preprocess(thresholds), origPattern, conn.PreQueueId(), conn.OCRPageQueueId()) - fmt.Printf(" OCRing pages ") // this is expected to be added to with dots by OCRPage output + err = pipeline.ProcessBook(ctx, msg, conn, pipeline.Preprocess(thresholds), origPattern, conn.PreQueueId(), conn.OCRPageQueueId()) resetTimer(stopIfQuiet, quietTime) if err != nil { return fmt.Errorf("Error during preprocess: %v", err) } + fmt.Printf(" OCRing pages ") // this is expected to be added to with dots by OCRPage output case <-checkWipeQueue: msg, err := conn.CheckQueue(conn.WipeQueueId(), QueueTimeoutSecs) checkWipeQueue = time.After(PauseBetweenChecks) @@ -656,7 +661,7 @@ func processbook(training string, tesscmd string, conn Pipeliner) error { stopTimer(stopIfQuiet) conn.Log("Message received on wipeonly queue, processing", msg.Body) fmt.Printf(" Preprocessing book (wiping only)\n") - err = pipeline.ProcessBook(msg, conn, pipeline.Wipe, wipePattern, conn.WipeQueueId(), conn.OCRPageQueueId()) + err = pipeline.ProcessBook(ctx, msg, conn, pipeline.Wipe, wipePattern, conn.WipeQueueId(), conn.OCRPageQueueId()) fmt.Printf(" OCRing pages ") // this is expected to be added to with dots by OCRPage output resetTimer(stopIfQuiet, quietTime) if err != nil { @@ -677,7 +682,7 @@ func processbook(training string, tesscmd string, conn Pipeliner) error { stopTimer(stopIfQuiet) conn.Log("Message received on OCR Page queue, processing", msg.Body) fmt.Printf(".") - err = pipeline.OcrPage(msg, conn, pipeline.Ocr(training, tesscmd), conn.OCRPageQueueId(), conn.AnalyseQueueId()) + err = pipeline.OcrPage(ctx, msg, conn, pipeline.Ocr(training, tesscmd), conn.OCRPageQueueId(), conn.AnalyseQueueId()) resetTimer(stopIfQuiet, quietTime) if err != nil { return fmt.Errorf("\nError during OCR Page process: %v", err) @@ -695,7 +700,7 @@ func processbook(training string, tesscmd string, conn Pipeliner) error { stopTimer(stopIfQuiet) conn.Log("Message received on analyse queue, processing", msg.Body) fmt.Printf("\n Analysing OCR and compiling PDFs\n") - err = pipeline.ProcessBook(msg, conn, pipeline.Analyse(conn), ocredPattern, conn.AnalyseQueueId(), "") + err = pipeline.ProcessBook(ctx, msg, conn, pipeline.Analyse(conn), ocredPattern, conn.AnalyseQueueId(), "") resetTimer(stopIfQuiet, quietTime) if err != nil { return fmt.Errorf("Error during analysis: %v", err) diff --git a/internal/pipeline/pipeline.go b/internal/pipeline/pipeline.go index d5e8e1c..b4a9d92 100644 --- a/internal/pipeline/pipeline.go +++ b/internal/pipeline/pipeline.go @@ -11,6 +11,7 @@ package pipeline import ( "bytes" + "context" "fmt" "io/ioutil" "log" @@ -129,8 +130,17 @@ func GetMailSettings() (mailSettings, error) { // dir, putting each successfully downloaded file name into the // process channel. If an error occurs it is sent to the errc channel // and the function returns early. -func download(dl chan string, process chan string, conn Downloader, dir string, errc chan error, logger *log.Logger) { +func download(ctx context.Context, dl chan string, process chan string, conn Downloader, dir string, errc chan error, logger *log.Logger) { for key := range dl { + select { + case <-ctx.Done(): + for range dl { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + close(process) + return + default: + } fn := filepath.Join(dir, filepath.Base(key)) logger.Println("Downloading", key) err := conn.Download(conn.WIPStorageId(), key, fn) @@ -151,8 +161,16 @@ func download(dl chan string, process chan string, conn Downloader, dir string, // once it has been successfully uploaded. The done channel is // then written to to signal completion. If an error occurs it // is sent to the errc channel and the function returns early. -func up(c chan string, done chan bool, conn Uploader, bookname string, errc chan error, logger *log.Logger) { +func up(ctx context.Context, c chan string, done chan bool, conn Uploader, bookname string, errc chan error, logger *log.Logger) { for path := range c { + select { + case <-ctx.Done(): + for range c { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } name := filepath.Base(path) key := bookname + "/" + name logger.Println("Uploading", key) @@ -181,8 +199,16 @@ func up(c chan string, done chan bool, conn Uploader, bookname string, errc chan // added to the toQueue once it has been uploaded. The done channel // is then written to to signal completion. If an error occurs it // is sent to the errc channel and the function returns early. -func upAndQueue(c chan string, done chan bool, toQueue string, conn UploadQueuer, bookname string, training string, errc chan error, logger *log.Logger) { +func upAndQueue(ctx context.Context, c chan string, done chan bool, toQueue string, conn UploadQueuer, bookname string, training string, errc chan error, logger *log.Logger) { for path := range c { + select { + case <-ctx.Done(): + for range c { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } name := filepath.Base(path) key := bookname + "/" + name logger.Println("Uploading", key) @@ -213,9 +239,17 @@ func upAndQueue(c chan string, done chan bool, toQueue string, conn UploadQueuer done <- true } -func Preprocess(thresholds []float64) func(chan string, chan string, chan error, *log.Logger) { - return func(pre chan string, up chan string, errc chan error, logger *log.Logger) { +func Preprocess(thresholds []float64) func(context.Context, chan string, chan string, chan error, *log.Logger) { + return func(ctx context.Context, pre chan string, up chan string, errc chan error, logger *log.Logger) { for path := range pre { + select { + case <-ctx.Done(): + for range pre { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } logger.Println("Preprocessing", path) done, err := preproc.PreProcMulti(path, thresholds, "binary", 0, true, 5, 30, 120, 30) if err != nil { @@ -233,8 +267,16 @@ func Preprocess(thresholds []float64) func(chan string, chan string, chan error, } } -func Wipe(towipe chan string, up chan string, errc chan error, logger *log.Logger) { +func Wipe(ctx context.Context, towipe chan string, up chan string, errc chan error, logger *log.Logger) { for path := range towipe { + select { + case <-ctx.Done(): + for range towipe { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } logger.Println("Wiping", path) s := strings.Split(path, ".") base := strings.Join(s[:len(s)-1], "") @@ -251,12 +293,20 @@ func Wipe(towipe chan string, up chan string, errc chan error, logger *log.Logge close(up) } -func Ocr(training string, tesscmd string) func(chan string, chan string, chan error, *log.Logger) { - return func(toocr chan string, up chan string, errc chan error, logger *log.Logger) { +func Ocr(training string, tesscmd string) func(context.Context, chan string, chan string, chan error, *log.Logger) { + return func(ctx context.Context, toocr chan string, up chan string, errc chan error, logger *log.Logger) { if tesscmd == "" { tesscmd = "tesseract" } for path := range toocr { + select { + case <-ctx.Done(): + for range toocr { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } logger.Println("OCRing", path) name := strings.Replace(path, ".png", "", 1) cmd := exec.Command(tesscmd, "-l", training, path, name, "-c", "tessedit_create_hocr=1", "-c", "hocr_font_info=0") @@ -276,13 +326,21 @@ func Ocr(training string, tesscmd string) func(chan string, chan string, chan er } } -func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Logger) { - return func(toanalyse chan string, up chan string, errc chan error, logger *log.Logger) { +func Analyse(conn Downloader) func(context.Context, chan string, chan string, chan error, *log.Logger) { + return func(ctx context.Context, toanalyse chan string, up chan string, errc chan error, logger *log.Logger) { confs := make(map[string][]*bookpipeline.Conf) bestconfs := make(map[string]*bookpipeline.Conf) savedir := "" for path := range toanalyse { + select { + case <-ctx.Done(): + for range toanalyse { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } if savedir == "" { savedir = filepath.Dir(path) } @@ -316,6 +374,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } defer f.Close() + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Finding best confidence for each page, and saving all confidences") for base, conf := range confs { var best float64 @@ -334,6 +399,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo f.Close() up <- fn + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Creating best file listing the best file for each page") fn = filepath.Join(savedir, "best") f, err = os.Create(fn) @@ -354,6 +426,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } sort.Strings(pgs) + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Downloading binarised and original images to create PDFs") bookname, err := filepath.Rel(os.TempDir(), savedir) if err != nil { @@ -374,6 +453,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } binhascontent, colourhascontent := false, false + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + var colourimgs, binimgs []pageimg for _, pg := range pgs { @@ -393,6 +479,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } for _, pg := range binimgs { + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Downloading binarised page to add to PDF", pg.img) err := conn.Download(conn.WIPStorageId(), bookname+"/"+pg.img, filepath.Join(savedir, pg.img)) if err != nil { @@ -412,6 +505,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } } + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + if binhascontent { fn = filepath.Join(savedir, bookname+".binarised.pdf") err = binarisedpdf.Save(fn) @@ -423,6 +523,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } for _, pg := range colourimgs { + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Downloading colour page to add to PDF", pg.img) colourfn := pg.img err = conn.Download(conn.WIPStorageId(), bookname+"/"+colourfn, filepath.Join(savedir, colourfn)) @@ -448,6 +555,14 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } } } + + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + if colourhascontent { fn = filepath.Join(savedir, bookname+".colour.pdf") err = colourpdf.Save(fn) @@ -458,6 +573,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo up <- fn } + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Creating graph") fn = filepath.Join(savedir, "graph.png") f, err = os.Create(fn) @@ -474,6 +596,14 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo errc <- fmt.Errorf("Error rendering graph: %s", err) return } + + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + if err == nil { up <- fn } @@ -546,7 +676,7 @@ func allOCRed(bookname string, conn Lister) bool { // OcrPage OCRs a page based on a message. It may make sense to // roll this back into processBook (on which it is based) once // working well. -func OcrPage(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, chan string, chan error, *log.Logger), fromQueue string, toQueue string) error { +func OcrPage(ctx context.Context, msg bookpipeline.Qmsg, conn Pipeliner, process func(context.Context, chan string, chan string, chan error, *log.Logger), fromQueue string, toQueue string) error { dl := make(chan string) msgc := make(chan bookpipeline.Qmsg) processc := make(chan string) @@ -570,19 +700,23 @@ func OcrPage(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, ch go heartbeat(conn, t, msg, fromQueue, msgc, errc) // these functions will do their jobs when their channels have data - go download(dl, processc, conn, d, errc, conn.GetLogger()) - go process(processc, upc, errc, conn.GetLogger()) - go up(upc, done, conn, bookname, errc, conn.GetLogger()) + go download(ctx, dl, processc, conn, d, errc, conn.GetLogger()) + go process(ctx, processc, upc, errc, conn.GetLogger()) + go up(ctx, upc, done, conn, bookname, errc, conn.GetLogger()) dl <- msgparts[0] close(dl) - // wait for either the done or errc channel to be sent to + // wait for either the done or errc channels to be sent to select { case err = <-errc: t.Stop() _ = os.RemoveAll(d) return err + case <-ctx.Done(): + t.Stop() + _ = os.RemoveAll(d) + return ctx.Err() case <-done: } @@ -624,7 +758,7 @@ func OcrPage(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, ch return nil } -func ProcessBook(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, chan string, chan error, *log.Logger), match *regexp.Regexp, fromQueue string, toQueue string) error { +func ProcessBook(ctx context.Context, msg bookpipeline.Qmsg, conn Pipeliner, process func(context.Context, chan string, chan string, chan error, *log.Logger), match *regexp.Regexp, fromQueue string, toQueue string) error { dl := make(chan string) msgc := make(chan bookpipeline.Qmsg) processc := make(chan string) @@ -650,12 +784,12 @@ func ProcessBook(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string go heartbeat(conn, t, msg, fromQueue, msgc, errc) // these functions will do their jobs when their channels have data - go download(dl, processc, conn, d, errc, conn.GetLogger()) - go process(processc, upc, errc, conn.GetLogger()) + go download(ctx, dl, processc, conn, d, errc, conn.GetLogger()) + go process(ctx, processc, upc, errc, conn.GetLogger()) if toQueue == conn.OCRPageQueueId() { - go upAndQueue(upc, done, toQueue, conn, bookname, training, errc, conn.GetLogger()) + go upAndQueue(ctx, upc, done, toQueue, conn, bookname, training, errc, conn.GetLogger()) } else { - go up(upc, done, conn, bookname, errc, conn.GetLogger()) + go up(ctx, upc, done, conn, bookname, errc, conn.GetLogger()) } conn.Log("Getting list of objects to download") @@ -716,6 +850,10 @@ func ProcessBook(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string } } return err + case <-ctx.Done(): + t.Stop() + _ = os.RemoveAll(d) + return ctx.Err() case <-done: } diff --git a/internal/pipeline/put.go b/internal/pipeline/put.go index d44f74f..68ad70e 100644 --- a/internal/pipeline/put.go +++ b/internal/pipeline/put.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "fmt" "image" _ "image/jpeg" @@ -43,7 +44,7 @@ func (f fileWalk) Walk(path string, info os.FileInfo, err error) error { // CheckImages checks that all files with a ".jpg" or ".png" suffix // in a directory are images that can be decoded (skipping dotfiles) -func CheckImages(dir string) error { +func CheckImages(ctx context.Context, dir string) error { checker := make(fileWalk) go func() { _ = filepath.Walk(dir, checker.Walk) @@ -51,6 +52,11 @@ func CheckImages(dir string) error { }() for path := range checker { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } suffix := filepath.Ext(path) lsuffix := strings.ToLower(suffix) if lsuffix != ".jpg" && lsuffix != ".png" { @@ -89,7 +95,7 @@ func DetectQueueType(dir string, conn Queuer) string { // slash. It also appends all file names with sequential numbers, like // 0001, to ensure they are appropriately named for further processing // in the pipeline. -func UploadImages(dir string, bookname string, conn Uploader) error { +func UploadImages(ctx context.Context, dir string, bookname string, conn Uploader) error { files, err := ioutil.ReadDir(dir) if err != nil { fmt.Errorf("Failed to read directory %s: %v", dir, err) @@ -97,6 +103,11 @@ func UploadImages(dir string, bookname string, conn Uploader) error { filenum := 0 for _, file := range files { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } if file.IsDir() { continue } |