diff options
author | Nick White <git@njw.name> | 2022-01-31 14:11:21 +0000 |
---|---|---|
committer | Nick White <git@njw.name> | 2022-01-31 14:11:21 +0000 |
commit | 550752fa2ab493fb6d10aa9d963fc45996c0d100 (patch) | |
tree | 279d2c7c7d062f6232f363d1462539738b7e4cc8 /internal/pipeline | |
parent | 57a3dc6da88e08951060e2e6e11605eb807f54ac (diff) |
Make pipeline context-aware, so the rescribe tool can cancel jobs
Diffstat (limited to 'internal/pipeline')
-rw-r--r-- | internal/pipeline/pipeline.go | 178 | ||||
-rw-r--r-- | internal/pipeline/put.go | 15 |
2 files changed, 171 insertions, 22 deletions
diff --git a/internal/pipeline/pipeline.go b/internal/pipeline/pipeline.go index d5e8e1c..b4a9d92 100644 --- a/internal/pipeline/pipeline.go +++ b/internal/pipeline/pipeline.go @@ -11,6 +11,7 @@ package pipeline import ( "bytes" + "context" "fmt" "io/ioutil" "log" @@ -129,8 +130,17 @@ func GetMailSettings() (mailSettings, error) { // dir, putting each successfully downloaded file name into the // process channel. If an error occurs it is sent to the errc channel // and the function returns early. -func download(dl chan string, process chan string, conn Downloader, dir string, errc chan error, logger *log.Logger) { +func download(ctx context.Context, dl chan string, process chan string, conn Downloader, dir string, errc chan error, logger *log.Logger) { for key := range dl { + select { + case <-ctx.Done(): + for range dl { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + close(process) + return + default: + } fn := filepath.Join(dir, filepath.Base(key)) logger.Println("Downloading", key) err := conn.Download(conn.WIPStorageId(), key, fn) @@ -151,8 +161,16 @@ func download(dl chan string, process chan string, conn Downloader, dir string, // once it has been successfully uploaded. The done channel is // then written to to signal completion. If an error occurs it // is sent to the errc channel and the function returns early. -func up(c chan string, done chan bool, conn Uploader, bookname string, errc chan error, logger *log.Logger) { +func up(ctx context.Context, c chan string, done chan bool, conn Uploader, bookname string, errc chan error, logger *log.Logger) { for path := range c { + select { + case <-ctx.Done(): + for range c { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } name := filepath.Base(path) key := bookname + "/" + name logger.Println("Uploading", key) @@ -181,8 +199,16 @@ func up(c chan string, done chan bool, conn Uploader, bookname string, errc chan // added to the toQueue once it has been uploaded. The done channel // is then written to to signal completion. If an error occurs it // is sent to the errc channel and the function returns early. -func upAndQueue(c chan string, done chan bool, toQueue string, conn UploadQueuer, bookname string, training string, errc chan error, logger *log.Logger) { +func upAndQueue(ctx context.Context, c chan string, done chan bool, toQueue string, conn UploadQueuer, bookname string, training string, errc chan error, logger *log.Logger) { for path := range c { + select { + case <-ctx.Done(): + for range c { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } name := filepath.Base(path) key := bookname + "/" + name logger.Println("Uploading", key) @@ -213,9 +239,17 @@ func upAndQueue(c chan string, done chan bool, toQueue string, conn UploadQueuer done <- true } -func Preprocess(thresholds []float64) func(chan string, chan string, chan error, *log.Logger) { - return func(pre chan string, up chan string, errc chan error, logger *log.Logger) { +func Preprocess(thresholds []float64) func(context.Context, chan string, chan string, chan error, *log.Logger) { + return func(ctx context.Context, pre chan string, up chan string, errc chan error, logger *log.Logger) { for path := range pre { + select { + case <-ctx.Done(): + for range pre { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } logger.Println("Preprocessing", path) done, err := preproc.PreProcMulti(path, thresholds, "binary", 0, true, 5, 30, 120, 30) if err != nil { @@ -233,8 +267,16 @@ func Preprocess(thresholds []float64) func(chan string, chan string, chan error, } } -func Wipe(towipe chan string, up chan string, errc chan error, logger *log.Logger) { +func Wipe(ctx context.Context, towipe chan string, up chan string, errc chan error, logger *log.Logger) { for path := range towipe { + select { + case <-ctx.Done(): + for range towipe { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } logger.Println("Wiping", path) s := strings.Split(path, ".") base := strings.Join(s[:len(s)-1], "") @@ -251,12 +293,20 @@ func Wipe(towipe chan string, up chan string, errc chan error, logger *log.Logge close(up) } -func Ocr(training string, tesscmd string) func(chan string, chan string, chan error, *log.Logger) { - return func(toocr chan string, up chan string, errc chan error, logger *log.Logger) { +func Ocr(training string, tesscmd string) func(context.Context, chan string, chan string, chan error, *log.Logger) { + return func(ctx context.Context, toocr chan string, up chan string, errc chan error, logger *log.Logger) { if tesscmd == "" { tesscmd = "tesseract" } for path := range toocr { + select { + case <-ctx.Done(): + for range toocr { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } logger.Println("OCRing", path) name := strings.Replace(path, ".png", "", 1) cmd := exec.Command(tesscmd, "-l", training, path, name, "-c", "tessedit_create_hocr=1", "-c", "hocr_font_info=0") @@ -276,13 +326,21 @@ func Ocr(training string, tesscmd string) func(chan string, chan string, chan er } } -func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Logger) { - return func(toanalyse chan string, up chan string, errc chan error, logger *log.Logger) { +func Analyse(conn Downloader) func(context.Context, chan string, chan string, chan error, *log.Logger) { + return func(ctx context.Context, toanalyse chan string, up chan string, errc chan error, logger *log.Logger) { confs := make(map[string][]*bookpipeline.Conf) bestconfs := make(map[string]*bookpipeline.Conf) savedir := "" for path := range toanalyse { + select { + case <-ctx.Done(): + for range toanalyse { + } // consume the rest of the receiving channel so it isn't blocked + errc <- ctx.Err() + return + default: + } if savedir == "" { savedir = filepath.Dir(path) } @@ -316,6 +374,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } defer f.Close() + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Finding best confidence for each page, and saving all confidences") for base, conf := range confs { var best float64 @@ -334,6 +399,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo f.Close() up <- fn + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Creating best file listing the best file for each page") fn = filepath.Join(savedir, "best") f, err = os.Create(fn) @@ -354,6 +426,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } sort.Strings(pgs) + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Downloading binarised and original images to create PDFs") bookname, err := filepath.Rel(os.TempDir(), savedir) if err != nil { @@ -374,6 +453,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } binhascontent, colourhascontent := false, false + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + var colourimgs, binimgs []pageimg for _, pg := range pgs { @@ -393,6 +479,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } for _, pg := range binimgs { + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Downloading binarised page to add to PDF", pg.img) err := conn.Download(conn.WIPStorageId(), bookname+"/"+pg.img, filepath.Join(savedir, pg.img)) if err != nil { @@ -412,6 +505,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } } + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + if binhascontent { fn = filepath.Join(savedir, bookname+".binarised.pdf") err = binarisedpdf.Save(fn) @@ -423,6 +523,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } for _, pg := range colourimgs { + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Downloading colour page to add to PDF", pg.img) colourfn := pg.img err = conn.Download(conn.WIPStorageId(), bookname+"/"+colourfn, filepath.Join(savedir, colourfn)) @@ -448,6 +555,14 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo } } } + + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + if colourhascontent { fn = filepath.Join(savedir, bookname+".colour.pdf") err = colourpdf.Save(fn) @@ -458,6 +573,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo up <- fn } + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + logger.Println("Creating graph") fn = filepath.Join(savedir, "graph.png") f, err = os.Create(fn) @@ -474,6 +596,14 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo errc <- fmt.Errorf("Error rendering graph: %s", err) return } + + select { + case <-ctx.Done(): + errc <- ctx.Err() + return + default: + } + if err == nil { up <- fn } @@ -546,7 +676,7 @@ func allOCRed(bookname string, conn Lister) bool { // OcrPage OCRs a page based on a message. It may make sense to // roll this back into processBook (on which it is based) once // working well. -func OcrPage(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, chan string, chan error, *log.Logger), fromQueue string, toQueue string) error { +func OcrPage(ctx context.Context, msg bookpipeline.Qmsg, conn Pipeliner, process func(context.Context, chan string, chan string, chan error, *log.Logger), fromQueue string, toQueue string) error { dl := make(chan string) msgc := make(chan bookpipeline.Qmsg) processc := make(chan string) @@ -570,19 +700,23 @@ func OcrPage(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, ch go heartbeat(conn, t, msg, fromQueue, msgc, errc) // these functions will do their jobs when their channels have data - go download(dl, processc, conn, d, errc, conn.GetLogger()) - go process(processc, upc, errc, conn.GetLogger()) - go up(upc, done, conn, bookname, errc, conn.GetLogger()) + go download(ctx, dl, processc, conn, d, errc, conn.GetLogger()) + go process(ctx, processc, upc, errc, conn.GetLogger()) + go up(ctx, upc, done, conn, bookname, errc, conn.GetLogger()) dl <- msgparts[0] close(dl) - // wait for either the done or errc channel to be sent to + // wait for either the done or errc channels to be sent to select { case err = <-errc: t.Stop() _ = os.RemoveAll(d) return err + case <-ctx.Done(): + t.Stop() + _ = os.RemoveAll(d) + return ctx.Err() case <-done: } @@ -624,7 +758,7 @@ func OcrPage(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, ch return nil } -func ProcessBook(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, chan string, chan error, *log.Logger), match *regexp.Regexp, fromQueue string, toQueue string) error { +func ProcessBook(ctx context.Context, msg bookpipeline.Qmsg, conn Pipeliner, process func(context.Context, chan string, chan string, chan error, *log.Logger), match *regexp.Regexp, fromQueue string, toQueue string) error { dl := make(chan string) msgc := make(chan bookpipeline.Qmsg) processc := make(chan string) @@ -650,12 +784,12 @@ func ProcessBook(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string go heartbeat(conn, t, msg, fromQueue, msgc, errc) // these functions will do their jobs when their channels have data - go download(dl, processc, conn, d, errc, conn.GetLogger()) - go process(processc, upc, errc, conn.GetLogger()) + go download(ctx, dl, processc, conn, d, errc, conn.GetLogger()) + go process(ctx, processc, upc, errc, conn.GetLogger()) if toQueue == conn.OCRPageQueueId() { - go upAndQueue(upc, done, toQueue, conn, bookname, training, errc, conn.GetLogger()) + go upAndQueue(ctx, upc, done, toQueue, conn, bookname, training, errc, conn.GetLogger()) } else { - go up(upc, done, conn, bookname, errc, conn.GetLogger()) + go up(ctx, upc, done, conn, bookname, errc, conn.GetLogger()) } conn.Log("Getting list of objects to download") @@ -716,6 +850,10 @@ func ProcessBook(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string } } return err + case <-ctx.Done(): + t.Stop() + _ = os.RemoveAll(d) + return ctx.Err() case <-done: } diff --git a/internal/pipeline/put.go b/internal/pipeline/put.go index d44f74f..68ad70e 100644 --- a/internal/pipeline/put.go +++ b/internal/pipeline/put.go @@ -5,6 +5,7 @@ package pipeline import ( + "context" "fmt" "image" _ "image/jpeg" @@ -43,7 +44,7 @@ func (f fileWalk) Walk(path string, info os.FileInfo, err error) error { // CheckImages checks that all files with a ".jpg" or ".png" suffix // in a directory are images that can be decoded (skipping dotfiles) -func CheckImages(dir string) error { +func CheckImages(ctx context.Context, dir string) error { checker := make(fileWalk) go func() { _ = filepath.Walk(dir, checker.Walk) @@ -51,6 +52,11 @@ func CheckImages(dir string) error { }() for path := range checker { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } suffix := filepath.Ext(path) lsuffix := strings.ToLower(suffix) if lsuffix != ".jpg" && lsuffix != ".png" { @@ -89,7 +95,7 @@ func DetectQueueType(dir string, conn Queuer) string { // slash. It also appends all file names with sequential numbers, like // 0001, to ensure they are appropriately named for further processing // in the pipeline. -func UploadImages(dir string, bookname string, conn Uploader) error { +func UploadImages(ctx context.Context, dir string, bookname string, conn Uploader) error { files, err := ioutil.ReadDir(dir) if err != nil { fmt.Errorf("Failed to read directory %s: %v", dir, err) @@ -97,6 +103,11 @@ func UploadImages(dir string, bookname string, conn Uploader) error { filenum := 0 for _, file := range files { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } if file.IsDir() { continue } |