author     Nick White <git@njw.name>    2019-08-19 17:01:18 +0100
committer  Nick White <git@njw.name>    2019-08-19 17:01:18 +0100
commit     79aa0597bb1a21c2078fe0d3f27d9037604d91bf (patch)
tree       15108d8d98765a2d34d165f21fad6a37935d4558 /pipelinepreprocess/main.go
parent     312dcbe96e45330e933f7d542e3b2ef2bf76ec08 (diff)
Work-in-progress rearchitecture to use interfaces; currently pointers are screwy, causing segfaults
Diffstat (limited to 'pipelinepreprocess/main.go')
-rw-r--r--   pipelinepreprocess/main.go   352
1 file changed, 250 insertions, 102 deletions
diff --git a/pipelinepreprocess/main.go b/pipelinepreprocess/main.go
index 6c58d98..c075fa9 100644
--- a/pipelinepreprocess/main.go
+++ b/pipelinepreprocess/main.go
@@ -6,6 +6,8 @@ package main
 // TODO: check if images are prebinarised and if so skip multiple binarisation
 
 import (
+    "errors"
+    "fmt"
     "log"
     "os"
     "path/filepath"
@@ -28,7 +30,8 @@ type NullWriter bool
 func (w NullWriter) Write(p []byte) (n int, err error) {
     return len(p), nil
 }
-var verboselog *log.Logger
+
+var alreadydone *regexp.Regexp
 
 const HeartbeatTime = 60
 const PauseBetweenChecks = 60 * time.Second
@@ -44,20 +47,225 @@ const PreprocPattern = `_bin[0-9].[0-9].png`
 //
 // TODO: consider having the download etc functions return a channel like a generator, like in rob pike's talk
 
-func download(dl chan string, pre chan string, downloader *s3manager.Downloader, dir string) {
-    for key := range dl {
-        verboselog.Println("Downloading", key)
-        fn := filepath.Join(dir, filepath.Base(key))
-        f, err := os.Create(fn)
+type Clouder interface {
+    Init() error
+    ListObjects(bucket string, prefix string, names chan string) error
+    Download(bucket string, key string, fn string) error
+    Upload(bucket string, key string, path string) error
+    CheckQueue(url string) (qmsg, error)
+    AddToQueue(url string, msg string) error
+    DelFromQueue(url string, handle string) error
+    QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error
+}
+
+type Pipeliner interface {
+    Clouder
+    ListInProgress(bookname string, names chan string) error
+    DownloadFromInProgress(key string, fn string) error
+    UploadToInProgress(key string, path string) error
+    CheckPreQueue() (qmsg, error)
+    AddToOCRQueue(msg string) error
+    DelFromPreQueue(handle string) error
+    PreQueueHeartbeat(t *time.Ticker, msgHandle string) error
+}
+
+type qmsg struct {
+    Handle, Body string
+}
+
+type awsConn struct {
+    // these need to be set before running Init()
+    region string
+    logger *log.Logger
+
+    // these are used internally
+    sess *session.Session
+    s3svc *s3.S3
+    sqssvc *sqs.SQS
+    downloader *s3manager.Downloader
+    uploader *s3manager.Uploader
+    prequrl, ocrqurl string
+}
+
+func (a awsConn) Init() error {
+    if a.region == "" {
+        return errors.New("No region set")
+    }
+    if a.logger == nil {
+        return errors.New("No logger set")
+    }
+
+    var err error
+    a.sess, err = session.NewSession(&aws.Config{
+        Region: aws.String(a.region),
+    })
+    if err != nil {
+        return errors.New(fmt.Sprintf("Failed to set up aws session: %s", err))
+    }
+    a.s3svc = s3.New(a.sess)
+    a.sqssvc = sqs.New(a.sess)
+    a.downloader = s3manager.NewDownloader(a.sess)
+    a.uploader = s3manager.NewUploader(a.sess)
+
+    a.logger.Println("Getting preprocess queue URL")
+    result, err := a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+        QueueName: aws.String("rescribepreprocess"),
+    })
+    if err != nil {
+        return errors.New(fmt.Sprintf("Error getting preprocess queue URL: %s", err))
+    }
+    a.prequrl = *result.QueueUrl
+    a.logger.Println("preprocess queue URL", a.prequrl)
+
+    a.logger.Println("Getting OCR queue URL")
+    result, err = a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+        QueueName: aws.String("rescribeocr"),
+    })
+    if err != nil {
+        return errors.New(fmt.Sprintf("Error getting OCR queue URL: %s", err))
+    }
+    a.ocrqurl = *result.QueueUrl
+    return nil
+}
+
+func (a awsConn) CheckQueue(url string) (qmsg, error) {
+    msgResult, err := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
+        MaxNumberOfMessages: aws.Int64(1),
+        VisibilityTimeout: aws.Int64(HeartbeatTime * 2),
+        WaitTimeSeconds: aws.Int64(20),
+        QueueUrl: &url,
+    })
+    if err != nil {
+        return qmsg{}, err
+    }
+
+    if len(msgResult.Messages) > 0 {
+        msg := qmsg{ Handle: *msgResult.Messages[0].ReceiptHandle, Body: *msgResult.Messages[0].Body }
+        a.logger.Println("Message received:", msg.Body)
+        return msg, nil
+    } else {
+        return qmsg{}, nil
+    }
+}
+
+func (a awsConn) CheckPreQueue() (qmsg, error) {
+    a.logger.Println("Checking preprocessing queue for new messages:", a.prequrl)
+    return a.CheckQueue(a.prequrl)
+}
+
+func (a awsConn) QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error {
+    for _ = range t.C {
+        duration := int64(HeartbeatTime * 2)
+        _, err := a.sqssvc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{
+            ReceiptHandle: &msgHandle,
+            QueueUrl: &qurl,
+            VisibilityTimeout: &duration,
+        })
         if err != nil {
-            log.Fatalln("Failed to create file", fn, err)
+            return errors.New(fmt.Sprintf("Heartbeat error updating queue duration: %s", err))
         }
-        defer f.Close()
+    }
+    return nil
+}
+
+func (a awsConn) PreQueueHeartbeat(t *time.Ticker, msgHandle string) error {
+    a.logger.Println("Starting preprocess queue heartbeat for", msgHandle)
+    return a.QueueHeartbeat(t, msgHandle, a.prequrl)
+}
+
+func (a awsConn) ListObjects(bucket string, prefix string, names chan string) error {
+    err := a.s3svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{
+        Bucket: aws.String(bucket),
+        Prefix: aws.String(prefix),
+    }, func(page *s3.ListObjectsV2Output, last bool) bool {
+        for _, r := range page.Contents {
+            if alreadydone.MatchString(*r.Key) {
+                a.logger.Println("Skipping item that looks like it has already been processed", *r.Key)
+                continue
+            }
+            names <- *r.Key
+        }
+        return true
+    })
+    close(names)
+    return err
+}
+
+func (a awsConn) ListInProgress(bookname string, names chan string) error {
+    return a.ListObjects("rescribeinprogress", bookname, names)
+}
+
+func (a awsConn) AddToQueue(url string, msg string) error {
+    _, err := a.sqssvc.SendMessage(&sqs.SendMessageInput{
+        MessageBody: &msg,
+        QueueUrl: &url,
+    })
+    return err
+}
+
+func (a awsConn) AddToOCRQueue(msg string) error {
+    return a.AddToQueue(a.ocrqurl, msg)
+}
 
-        _, err = downloader.Download(f,
-            &s3.GetObjectInput{
-                Bucket: aws.String("rescribeinprogress"),
-                Key: &key })
+func (a awsConn) DelFromQueue(url string, handle string) error {
+    _, err := a.sqssvc.DeleteMessage(&sqs.DeleteMessageInput{
+        QueueUrl: &url,
+        ReceiptHandle: &handle,
+    })
+    return err
+}
+
+func (a awsConn) DelFromPreQueue(handle string) error {
+    return a.DelFromQueue(a.prequrl, handle)
+}
+
+func (a awsConn) Download(bucket string, key string, path string) error {
+    f, err := os.Create(path)
+    if err != nil {
+        return err
+    }
+    defer f.Close()
+
+    _, err = a.downloader.Download(f,
+        &s3.GetObjectInput{
+            Bucket: aws.String(bucket),
+            Key: &key,
+        })
+    return err
+}
+
+func (a awsConn) DownloadFromInProgress(key string, path string) error {
+    a.logger.Println("Downloading", key)
+    return a.Download("rescribeinprogress", key, path)
+}
+
+func (a awsConn) Upload(bucket string, key string, path string) error {
+    file, err := os.Open(path)
+    if err != nil {
+        log.Fatalln("Failed to open file", path, err)
+    }
+    defer file.Close()
+
+    _, err = a.uploader.Upload(&s3manager.UploadInput{
+        Bucket: aws.String(bucket),
+        Key: aws.String(key),
+        Body: file,
+    })
+    return err
+}
+
+func (a awsConn) UploadToInProgress(key string, path string) error {
+    a.logger.Println("Uploading", path)
+    return a.Upload("rescribeinprogress", key, path)
+}
+
+
+
+
+func download(dl chan string, pre chan string, conn Pipeliner, dir string) {
+    for key := range dl {
+        fn := filepath.Join(dir, filepath.Base(key))
+        err := conn.DownloadFromInProgress(key, fn)
         if err != nil {
            log.Fatalln("Failed to download", key, err)
         }
@@ -66,9 +274,9 @@ func download(dl chan string, pre chan string, downloader *s3manager.Downloader,
     close(pre)
 }
 
-func preprocess(pre chan string, up chan string) {
+func preprocess(pre chan string, up chan string, logger *log.Logger) {
     for path := range pre {
-        verboselog.Println("Preprocessing", path)
+        logger.Println("Preprocessing", path)
         done, err := preproc.PreProcMulti(path, []float64{0.1, 0.2, 0.4, 0.5}, "binary", 0, true, 5, 30)
         if err != nil {
             log.Fatalln("Error preprocessing", path, err)
@@ -80,21 +288,11 @@ func preprocess(pre chan string, up chan string) {
     close(up)
 }
 
-func up(c chan string, done chan bool, uploader *s3manager.Uploader, bookname string) {
+func up(c chan string, done chan bool, conn Pipeliner, bookname string) {
     for path := range c {
-        verboselog.Println("Uploading", path)
         name := filepath.Base(path)
-        file, err := os.Open(path)
-        if err != nil {
-            log.Fatalln("Failed to open file", path, err)
-        }
-        defer file.Close()
-
-        _, err = uploader.Upload(&s3manager.UploadInput{
-            Bucket: aws.String("rescribeinprogress"),
-            Key: aws.String(filepath.Join(bookname, name)),
-            Body: file,
-        })
+        key := filepath.Join(bookname, name)
+        err := conn.UploadToInProgress(key, path)
         if err != nil {
             log.Fatalln("Failed to upload", path, err)
         }
@@ -118,6 +316,7 @@ func heartbeat(h *time.Ticker, msgHandle string, qurl string, sqssvc *sqs.SQS) {
 }
 
 func main() {
+    var verboselog *log.Logger
     if len(os.Args) > 1 {
         if os.Args[1] == "-v" {
             verboselog = log.New(os.Stdout, "", log.LstdFlags)
@@ -129,65 +328,31 @@ func main() {
         verboselog = log.New(n, "", log.LstdFlags)
     }
 
-    alreadydone := regexp.MustCompile(PreprocPattern)
+    alreadydone = regexp.MustCompile(PreprocPattern)
 
-    verboselog.Println("Setting up AWS session")
-    sess, err := session.NewSession(&aws.Config{
-        Region: aws.String("eu-west-2"),
-    })
-    if err != nil {
-        log.Fatalln("Error: failed to set up aws session:", err)
-    }
-    s3svc := s3.New(sess)
-    sqssvc := sqs.New(sess)
-    downloader := s3manager.NewDownloader(sess)
-    uploader := s3manager.NewUploader(sess)
-
-    preqname := "rescribepreprocess"
-    verboselog.Println("Getting Queue URL for", preqname)
-    result, err := sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
-        QueueName: aws.String(preqname),
-    })
-    if err != nil {
-        log.Fatalln("Error getting queue URL for", preqname, ":", err)
-    }
-    prequrl := *result.QueueUrl
+    var conn Pipeliner
+    conn = awsConn{ region: "eu-west-2", logger: verboselog }
 
-    ocrqname := "rescribeocr"
-    verboselog.Println("Getting Queue URL for", ocrqname)
-    result, err = sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
-        QueueName: aws.String(ocrqname),
-    })
+    verboselog.Println("Setting up AWS session")
+    err := conn.Init()
     if err != nil {
-        log.Fatalln("Error getting queue URL for", ocrqname, ":", err)
+        log.Fatalln("Error setting up cloud connection:", err)
     }
-    ocrqurl := *result.QueueUrl
 
     for {
-        verboselog.Println("Checking preprocessing queue for new messages")
-        msgResult, err := sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
-            MaxNumberOfMessages: aws.Int64(1),
-            VisibilityTimeout: aws.Int64(HeartbeatTime * 2),
-            WaitTimeSeconds: aws.Int64(20),
-            QueueUrl: &prequrl,
-        })
+        msg, err := conn.CheckPreQueue()
         if err != nil {
-            log.Fatalln("Error checking queue", preqname, ":", err)
+            log.Fatalln("Error checking preprocess queue", err)
        }
-
-        var bookname string
-        if len(msgResult.Messages) > 0 {
-            bookname = *msgResult.Messages[0].Body
-            verboselog.Println("Message received:", bookname)
-        } else {
+        if msg.Handle == "" {
             verboselog.Println("No message received, sleeping")
             time.Sleep(PauseBetweenChecks)
             continue
         }
+        bookname := msg.Body
 
-        verboselog.Println("Starting heartbeat every", HeartbeatTime, "seconds")
         t := time.NewTicker(HeartbeatTime * time.Second)
-        go heartbeat(t, *msgResult.Messages[0].ReceiptHandle, prequrl, sqssvc)
+        go conn.PreQueueHeartbeat(t, msg.Handle)
 
         d := filepath.Join(os.TempDir(), bookname)
@@ -202,49 +367,32 @@ func main() {
         done := make(chan bool) // this is just to communicate when up has finished, so the queues can be updated
 
         // these functions will do their jobs when their channels have data
-        go download(dl, pre, downloader, d)
-        go preprocess(pre, upc)
-        go up(upc, done, uploader, bookname)
+        go download(dl, pre, conn, d)
+        go preprocess(pre, upc, verboselog)
+        go up(upc, done, conn, bookname)
 
         verboselog.Println("Getting list of objects to download")
-        err = s3svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{
-            Bucket: aws.String("rescribeinprogress"),
-            Prefix: aws.String(bookname),
-        }, func(page *s3.ListObjectsV2Output, last bool) bool {
-            for _, r := range page.Contents {
-                if alreadydone.MatchString(*r.Key) {
-                    verboselog.Println("Skipping item that looks like it has already been processed", *r.Key)
-                    continue
-                }
-                dl <- *r.Key
-            }
-            return true
-        })
-        close(dl)
-
+        err = conn.ListInProgress(bookname, dl)
+        if err != nil {
+            log.Fatalln("Failed to get list of files for book", bookname, err)
+        }
         // wait for the done channel to be posted to
         <-done
 
-        verboselog.Println("Sending", bookname, "to queue", ocrqurl)
-        _, err = sqssvc.SendMessage(&sqs.SendMessageInput{
-            MessageBody: aws.String(bookname),
-            QueueUrl: &ocrqurl,
-        })
+        verboselog.Println("Sending", bookname, "to OCR queue")
+        err = conn.AddToOCRQueue(bookname)
         if err != nil {
-            log.Fatalln("Error sending message to queue", ocrqname, ":", err)
+            log.Fatalln("Error adding to ocr queue", bookname, err)
        }
 
        t.Stop()
 
-        verboselog.Println("Deleting original message from queue", prequrl)
-        _, err = sqssvc.DeleteMessage(&sqs.DeleteMessageInput{
-            QueueUrl: &prequrl,
-            ReceiptHandle: msgResult.Messages[0].ReceiptHandle,
-        })
+        verboselog.Println("Deleting original message from preprocessing queue")
+        err = conn.DelFromPreQueue(msg.Handle)
         if err != nil {
-            log.Fatalln("Error deleting message from queue", preqname, ":", err)
+            log.Fatalln("Error deleting message from preprocessing queue", err)
        }
 
        err = os.RemoveAll(d)
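
A note on the segfaults mentioned in the commit message: a likely culprit is that the awsConn methods above use value receivers, so the assignments inside Init() (a.sess, a.s3svc, a.prequrl and so on) are made to a copy and are discarded when Init() returns, leaving the fields of the awsConn stored in the Pipeliner interface nil and triggering nil pointer dereferences on the first real use. The following is a minimal, self-contained sketch of the difference, not code from this repository; the conn type and its methods are hypothetical stand-ins for awsConn.

package main

import "fmt"

// conn is a hypothetical stand-in for awsConn.
type conn struct {
    url string
}

// Value receiver: the method operates on a copy of c, so the caller's
// struct is left unchanged (the pattern used by awsConn.Init above).
func (c conn) InitByValue() {
    c.url = "set"
}

// Pointer receiver: the method mutates the caller's struct directly.
func (c *conn) InitByPointer() {
    c.url = "set"
}

func main() {
    var a conn
    a.InitByValue()
    fmt.Printf("after value receiver: %q\n", a.url) // "" - still unset

    var b conn
    b.InitByPointer()
    fmt.Printf("after pointer receiver: %q\n", b.url) // "set"
}

If that diagnosis is right, switching the awsConn methods to pointer receivers (func (a *awsConn) Init() error, etc.) and storing a pointer in the interface (conn = &awsConn{...}) would let the setup done in Init() survive the call.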