summary | refs | log | tree | commit | diff
path: root/internal/pipeline/pipeline.go
diff options
context:
space:
mode:
author: Nick White <git@njw.name> 2022-01-31 14:11:21 +0000
committer: Nick White <git@njw.name> 2022-01-31 14:11:21 +0000
commit: 550752fa2ab493fb6d10aa9d963fc45996c0d100 (patch)
tree: 279d2c7c7d062f6232f363d1462539738b7e4cc8 /internal/pipeline/pipeline.go
parent: 57a3dc6da88e08951060e2e6e11605eb807f54ac (diff)
Make pipeline context-aware, so the rescribe tool can cancel jobs
Diffstat (limited to 'internal/pipeline/pipeline.go')
-rw-r--r-- internal/pipeline/pipeline.go 178
1 files changed, 158 insertions, 20 deletions
diff --git a/internal/pipeline/pipeline.go b/internal/pipeline/pipeline.go
index d5e8e1c..b4a9d92 100644
--- a/internal/pipeline/pipeline.go
+++ b/internal/pipeline/pipeline.go
@@ -11,6 +11,7 @@ package pipeline
import (
"bytes"
+ "context"
"fmt"
"io/ioutil"
"log"
@@ -129,8 +130,17 @@ func GetMailSettings() (mailSettings, error) {
// dir, putting each successfully downloaded file name into the
// process channel. If an error occurs it is sent to the errc channel
// and the function returns early.
-func download(dl chan string, process chan string, conn Downloader, dir string, errc chan error, logger *log.Logger) {
+func download(ctx context.Context, dl chan string, process chan string, conn Downloader, dir string, errc chan error, logger *log.Logger) {
for key := range dl {
+ select {
+ case <-ctx.Done():
+ for range dl {
+ } // consume the rest of the receiving channel so it isn't blocked
+ errc <- ctx.Err()
+ close(process)
+ return
+ default:
+ }
fn := filepath.Join(dir, filepath.Base(key))
logger.Println("Downloading", key)
err := conn.Download(conn.WIPStorageId(), key, fn)
@@ -151,8 +161,16 @@ func download(dl chan string, process chan string, conn Downloader, dir string,
// once it has been successfully uploaded. The done channel is
// then written to to signal completion. If an error occurs it
// is sent to the errc channel and the function returns early.
-func up(c chan string, done chan bool, conn Uploader, bookname string, errc chan error, logger *log.Logger) {
+func up(ctx context.Context, c chan string, done chan bool, conn Uploader, bookname string, errc chan error, logger *log.Logger) {
for path := range c {
+ select {
+ case <-ctx.Done():
+ for range c {
+ } // consume the rest of the receiving channel so it isn't blocked
+ errc <- ctx.Err()
+ return
+ default:
+ }
name := filepath.Base(path)
key := bookname + "/" + name
logger.Println("Uploading", key)
@@ -181,8 +199,16 @@ func up(c chan string, done chan bool, conn Uploader, bookname string, errc chan
// added to the toQueue once it has been uploaded. The done channel
// is then written to to signal completion. If an error occurs it
// is sent to the errc channel and the function returns early.
-func upAndQueue(c chan string, done chan bool, toQueue string, conn UploadQueuer, bookname string, training string, errc chan error, logger *log.Logger) {
+func upAndQueue(ctx context.Context, c chan string, done chan bool, toQueue string, conn UploadQueuer, bookname string, training string, errc chan error, logger *log.Logger) {
for path := range c {
+ select {
+ case <-ctx.Done():
+ for range c {
+ } // consume the rest of the receiving channel so it isn't blocked
+ errc <- ctx.Err()
+ return
+ default:
+ }
name := filepath.Base(path)
key := bookname + "/" + name
logger.Println("Uploading", key)
@@ -213,9 +239,17 @@ func upAndQueue(c chan string, done chan bool, toQueue string, conn UploadQueuer
done <- true
}
-func Preprocess(thresholds []float64) func(chan string, chan string, chan error, *log.Logger) {
- return func(pre chan string, up chan string, errc chan error, logger *log.Logger) {
+func Preprocess(thresholds []float64) func(context.Context, chan string, chan string, chan error, *log.Logger) {
+ return func(ctx context.Context, pre chan string, up chan string, errc chan error, logger *log.Logger) {
for path := range pre {
+ select {
+ case <-ctx.Done():
+ for range pre {
+ } // consume the rest of the receiving channel so it isn't blocked
+ errc <- ctx.Err()
+ return
+ default:
+ }
logger.Println("Preprocessing", path)
done, err := preproc.PreProcMulti(path, thresholds, "binary", 0, true, 5, 30, 120, 30)
if err != nil {
@@ -233,8 +267,16 @@ func Preprocess(thresholds []float64) func(chan string, chan string, chan error,
}
}
-func Wipe(towipe chan string, up chan string, errc chan error, logger *log.Logger) {
+func Wipe(ctx context.Context, towipe chan string, up chan string, errc chan error, logger *log.Logger) {
for path := range towipe {
+ select {
+ case <-ctx.Done():
+ for range towipe {
+ } // consume the rest of the receiving channel so it isn't blocked
+ errc <- ctx.Err()
+ return
+ default:
+ }
logger.Println("Wiping", path)
s := strings.Split(path, ".")
base := strings.Join(s[:len(s)-1], "")
@@ -251,12 +293,20 @@ func Wipe(towipe chan string, up chan string, errc chan error, logger *log.Logge
close(up)
}
-func Ocr(training string, tesscmd string) func(chan string, chan string, chan error, *log.Logger) {
- return func(toocr chan string, up chan string, errc chan error, logger *log.Logger) {
+func Ocr(training string, tesscmd string) func(context.Context, chan string, chan string, chan error, *log.Logger) {
+ return func(ctx context.Context, toocr chan string, up chan string, errc chan error, logger *log.Logger) {
if tesscmd == "" {
tesscmd = "tesseract"
}
for path := range toocr {
+ select {
+ case <-ctx.Done():
+ for range toocr {
+ } // consume the rest of the receiving channel so it isn't blocked
+ errc <- ctx.Err()
+ return
+ default:
+ }
logger.Println("OCRing", path)
name := strings.Replace(path, ".png", "", 1)
cmd := exec.Command(tesscmd, "-l", training, path, name, "-c", "tessedit_create_hocr=1", "-c", "hocr_font_info=0")
@@ -276,13 +326,21 @@ func Ocr(training string, tesscmd string) func(chan string, chan string, chan er
}
}
-func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Logger) {
- return func(toanalyse chan string, up chan string, errc chan error, logger *log.Logger) {
+func Analyse(conn Downloader) func(context.Context, chan string, chan string, chan error, *log.Logger) {
+ return func(ctx context.Context, toanalyse chan string, up chan string, errc chan error, logger *log.Logger) {
confs := make(map[string][]*bookpipeline.Conf)
bestconfs := make(map[string]*bookpipeline.Conf)
savedir := ""
for path := range toanalyse {
+ select {
+ case <-ctx.Done():
+ for range toanalyse {
+ } // consume the rest of the receiving channel so it isn't blocked
+ errc <- ctx.Err()
+ return
+ default:
+ }
if savedir == "" {
savedir = filepath.Dir(path)
}
@@ -316,6 +374,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo
}
defer f.Close()
+ select {
+ case <-ctx.Done():
+ errc <- ctx.Err()
+ return
+ default:
+ }
+
logger.Println("Finding best confidence for each page, and saving all confidences")
for base, conf := range confs {
var best float64
@@ -334,6 +399,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo
f.Close()
up <- fn
+ select {
+ case <-ctx.Done():
+ errc <- ctx.Err()
+ return
+ default:
+ }
+
logger.Println("Creating best file listing the best file for each page")
fn = filepath.Join(savedir, "best")
f, err = os.Create(fn)
@@ -354,6 +426,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo
}
sort.Strings(pgs)
+ select {
+ case <-ctx.Done():
+ errc <- ctx.Err()
+ return
+ default:
+ }
+
logger.Println("Downloading binarised and original images to create PDFs")
bookname, err := filepath.Rel(os.TempDir(), savedir)
if err != nil {
@@ -374,6 +453,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo
}
binhascontent, colourhascontent := false, false
+ select {
+ case <-ctx.Done():
+ errc <- ctx.Err()
+ return
+ default:
+ }
+
var colourimgs, binimgs []pageimg
for _, pg := range pgs {
@@ -393,6 +479,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo
}
for _, pg := range binimgs {
+ select {
+ case <-ctx.Done():
+ errc <- ctx.Err()
+ return
+ default:
+ }
+
logger.Println("Downloading binarised page to add to PDF", pg.img)
err := conn.Download(conn.WIPStorageId(), bookname+"/"+pg.img, filepath.Join(savedir, pg.img))
if err != nil {
@@ -412,6 +505,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo
}
}
+ select {
+ case <-ctx.Done():
+ errc <- ctx.Err()
+ return
+ default:
+ }
+
if binhascontent {
fn = filepath.Join(savedir, bookname+".binarised.pdf")
err = binarisedpdf.Save(fn)
@@ -423,6 +523,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo
}
for _, pg := range colourimgs {
+ select {
+ case <-ctx.Done():
+ errc <- ctx.Err()
+ return
+ default:
+ }
+
logger.Println("Downloading colour page to add to PDF", pg.img)
colourfn := pg.img
err = conn.Download(conn.WIPStorageId(), bookname+"/"+colourfn, filepath.Join(savedir, colourfn))
@@ -448,6 +555,14 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo
}
}
}
+
+ select {
+ case <-ctx.Done():
+ errc <- ctx.Err()
+ return
+ default:
+ }
+
if colourhascontent {
fn = filepath.Join(savedir, bookname+".colour.pdf")
err = colourpdf.Save(fn)
@@ -458,6 +573,13 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo
up <- fn
}
+ select {
+ case <-ctx.Done():
+ errc <- ctx.Err()
+ return
+ default:
+ }
+
logger.Println("Creating graph")
fn = filepath.Join(savedir, "graph.png")
f, err = os.Create(fn)
@@ -474,6 +596,14 @@ func Analyse(conn Downloader) func(chan string, chan string, chan error, *log.Lo
errc <- fmt.Errorf("Error rendering graph: %s", err)
return
}
+
+ select {
+ case <-ctx.Done():
+ errc <- ctx.Err()
+ return
+ default:
+ }
+
if err == nil {
up <- fn
}
@@ -546,7 +676,7 @@ func allOCRed(bookname string, conn Lister) bool {
// OcrPage OCRs a page based on a message. It may make sense to
// roll this back into processBook (on which it is based) once
// working well.
-func OcrPage(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, chan string, chan error, *log.Logger), fromQueue string, toQueue string) error {
+func OcrPage(ctx context.Context, msg bookpipeline.Qmsg, conn Pipeliner, process func(context.Context, chan string, chan string, chan error, *log.Logger), fromQueue string, toQueue string) error {
dl := make(chan string)
msgc := make(chan bookpipeline.Qmsg)
processc := make(chan string)
@@ -570,19 +700,23 @@ func OcrPage(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, ch
go heartbeat(conn, t, msg, fromQueue, msgc, errc)
// these functions will do their jobs when their channels have data
- go download(dl, processc, conn, d, errc, conn.GetLogger())
- go process(processc, upc, errc, conn.GetLogger())
- go up(upc, done, conn, bookname, errc, conn.GetLogger())
+ go download(ctx, dl, processc, conn, d, errc, conn.GetLogger())
+ go process(ctx, processc, upc, errc, conn.GetLogger())
+ go up(ctx, upc, done, conn, bookname, errc, conn.GetLogger())
dl <- msgparts[0]
close(dl)
- // wait for either the done or errc channel to be sent to
+ // wait for either the done or errc channels to be sent to
select {
case err = <-errc:
t.Stop()
_ = os.RemoveAll(d)
return err
+ case <-ctx.Done():
+ t.Stop()
+ _ = os.RemoveAll(d)
+ return ctx.Err()
case <-done:
}
@@ -624,7 +758,7 @@ func OcrPage(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, ch
return nil
}
-func ProcessBook(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string, chan string, chan error, *log.Logger), match *regexp.Regexp, fromQueue string, toQueue string) error {
+func ProcessBook(ctx context.Context, msg bookpipeline.Qmsg, conn Pipeliner, process func(context.Context, chan string, chan string, chan error, *log.Logger), match *regexp.Regexp, fromQueue string, toQueue string) error {
dl := make(chan string)
msgc := make(chan bookpipeline.Qmsg)
processc := make(chan string)
@@ -650,12 +784,12 @@ func ProcessBook(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string
go heartbeat(conn, t, msg, fromQueue, msgc, errc)
// these functions will do their jobs when their channels have data
- go download(dl, processc, conn, d, errc, conn.GetLogger())
- go process(processc, upc, errc, conn.GetLogger())
+ go download(ctx, dl, processc, conn, d, errc, conn.GetLogger())
+ go process(ctx, processc, upc, errc, conn.GetLogger())
if toQueue == conn.OCRPageQueueId() {
- go upAndQueue(upc, done, toQueue, conn, bookname, training, errc, conn.GetLogger())
+ go upAndQueue(ctx, upc, done, toQueue, conn, bookname, training, errc, conn.GetLogger())
} else {
- go up(upc, done, conn, bookname, errc, conn.GetLogger())
+ go up(ctx, upc, done, conn, bookname, errc, conn.GetLogger())
}
conn.Log("Getting list of objects to download")
@@ -716,6 +850,10 @@ func ProcessBook(msg bookpipeline.Qmsg, conn Pipeliner, process func(chan string
}
}
return err
+ case <-ctx.Done():
+ t.Stop()
+ _ = os.RemoveAll(d)
+ return ctx.Err()
case <-done:
}