From bd8688f692bc8cf8836893bf42fdc46e8fb236bf Mon Sep 17 00:00:00 2001 From: Nick White Date: Tue, 20 Aug 2019 10:25:29 +0100 Subject: Split aws implementation from main.go in pipelinepreprocess --- pipelinepreprocess/aws.go | 205 +++++++++++++++++++++++++++++++++++++++++ pipelinepreprocess/main.go | 222 --------------------------------------------- 2 files changed, 205 insertions(+), 222 deletions(-) create mode 100644 pipelinepreprocess/aws.go diff --git a/pipelinepreprocess/aws.go b/pipelinepreprocess/aws.go new file mode 100644 index 0000000..bb969ed --- /dev/null +++ b/pipelinepreprocess/aws.go @@ -0,0 +1,205 @@ +package main + +import ( + "errors" + "fmt" + "log" + "os" + "regexp" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/aws/aws-sdk-go/service/sqs" +) + +const PreprocPattern = `_bin[0-9].[0-9].png` +const HeartbeatTime = 60 + +type awsConn struct { + // these need to be set before running Init() + region string + logger *log.Logger + + // these are used internally + sess *session.Session + s3svc *s3.S3 + sqssvc *sqs.SQS + downloader *s3manager.Downloader + uploader *s3manager.Uploader + prequrl, ocrqurl string +} + +func (a *awsConn) Init() error { + if a.region == "" { + return errors.New("No region set") + } + if a.logger == nil { + return errors.New("No logger set") + } + + var err error + a.sess, err = session.NewSession(&aws.Config{ + Region: aws.String(a.region), + }) + if err != nil { + return errors.New(fmt.Sprintf("Failed to set up aws session: %s", err)) + } + a.s3svc = s3.New(a.sess) + a.sqssvc = sqs.New(a.sess) + a.downloader = s3manager.NewDownloader(a.sess) + a.uploader = s3manager.NewUploader(a.sess) + + a.logger.Println("Getting preprocess queue URL") + result, err := a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{ + QueueName: aws.String("rescribepreprocess"), + }) + if err != nil { + return errors.New(fmt.Sprintf("Error getting preprocess queue URL: %s", err)) + } + a.prequrl = *result.QueueUrl + + a.logger.Println("Getting OCR queue URL") + result, err = a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{ + QueueName: aws.String("rescribeocr"), + }) + if err != nil { + return errors.New(fmt.Sprintf("Error getting OCR queue URL: %s", err)) + } + a.ocrqurl = *result.QueueUrl + return nil +} + +func (a *awsConn) CheckQueue(url string) (Qmsg, error) { + msgResult, err := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{ + MaxNumberOfMessages: aws.Int64(1), + VisibilityTimeout: aws.Int64(HeartbeatTime * 2), + WaitTimeSeconds: aws.Int64(20), + QueueUrl: &url, + }) + if err != nil { + return Qmsg{}, err + } + + if len(msgResult.Messages) > 0 { + msg := Qmsg{ Handle: *msgResult.Messages[0].ReceiptHandle, Body: *msgResult.Messages[0].Body } + a.logger.Println("Message received:", msg.Body) + return msg, nil + } else { + return Qmsg{}, nil + } +} + +func (a *awsConn) CheckPreQueue() (Qmsg, error) { + a.logger.Println("Checking preprocessing queue for new messages") + return a.CheckQueue(a.prequrl) +} + +func (a *awsConn) QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error { + for _ = range t.C { + duration := int64(HeartbeatTime * 2) + _, err := a.sqssvc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{ + ReceiptHandle: &msgHandle, + QueueUrl: &qurl, + VisibilityTimeout: &duration, + }) + if err != nil { + return errors.New(fmt.Sprintf("Heartbeat error updating queue duration: %s", err)) + } + } + return nil +} + +func (a *awsConn) PreQueueHeartbeat(t *time.Ticker, msgHandle string) error { + a.logger.Println("Starting preprocess queue heartbeat for", msgHandle) + return a.QueueHeartbeat(t, msgHandle, a.prequrl) +} + +func (a *awsConn) ListObjects(bucket string, prefix string, names chan string) error { + alreadydone := regexp.MustCompile(PreprocPattern) + err := a.s3svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + Prefix: aws.String(prefix), + }, func(page *s3.ListObjectsV2Output, last bool) bool { + for _, r := range page.Contents { + if alreadydone.MatchString(*r.Key) { + a.logger.Println("Skipping item that looks like it has already been processed", *r.Key) + continue + } + names <- *r.Key + } + return true + }) + close(names) + return err +} + +func (a *awsConn) ListInProgress(bookname string, names chan string) error { + return a.ListObjects("rescribeinprogress", bookname, names) +} + +func (a *awsConn) AddToQueue(url string, msg string) error { + _, err := a.sqssvc.SendMessage(&sqs.SendMessageInput{ + MessageBody: &msg, + QueueUrl: &url, + }) + return err +} + +func (a *awsConn) AddToOCRQueue(msg string) error { + return a.AddToQueue(a.ocrqurl, msg) +} + +func (a *awsConn) DelFromQueue(url string, handle string) error { + _, err := a.sqssvc.DeleteMessage(&sqs.DeleteMessageInput{ + QueueUrl: &url, + ReceiptHandle: &handle, + }) + return err +} + +func (a *awsConn) DelFromPreQueue(handle string) error { + return a.DelFromQueue(a.prequrl, handle) +} + +func (a *awsConn) Download(bucket string, key string, path string) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + _, err = a.downloader.Download(f, + &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: &key, + }) + return err +} + +func (a *awsConn) DownloadFromInProgress(key string, path string) error { + a.logger.Println("Downloading", key) + return a.Download("rescribeinprogress", key, path) +} + +func (a *awsConn) Upload(bucket string, key string, path string) error { + file, err := os.Open(path) + if err != nil { + log.Fatalln("Failed to open file", path, err) + } + defer file.Close() + + _, err = a.uploader.Upload(&s3manager.UploadInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + Body: file, + }) + return err +} + +func (a *awsConn) UploadToInProgress(key string, path string) error { + a.logger.Println("Uploading", path) + return a.Upload("rescribeinprogress", key, path) +} diff --git a/pipelinepreprocess/main.go b/pipelinepreprocess/main.go index 407591f..a223d0b 100644 --- a/pipelinepreprocess/main.go +++ b/pipelinepreprocess/main.go @@ -6,20 +6,11 @@ package main // TODO: check if images are prebinarised and if so skip multiple binarisation import ( - "errors" - "fmt" "log" "os" "path/filepath" - "regexp" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" - "github.com/aws/aws-sdk-go/service/sqs" - "rescribe.xyz/go.git/preproc" ) @@ -31,17 +22,8 @@ func (w NullWriter) Write(p []byte) (n int, err error) { return len(p), nil } -var alreadydone *regexp.Regexp - -const HeartbeatTime = 60 const PauseBetweenChecks = 60 * time.Second -const PreprocPattern = `_bin[0-9].[0-9].png` -// TODO: could restructure like so: -// have the goroutine functions run outside of the main loop in the program, -// so use them for multiple books indefinitely. would require finding a way to -// signal when the queues need to be updated (e.g. when a book is finished) -// // TODO: consider having the download etc functions return a channel like a generator, like in rob pike's talk type Clouder interface { @@ -70,194 +52,6 @@ type Qmsg struct { Handle, Body string } -type awsConn struct { - // these need to be set before running Init() - region string - logger *log.Logger - - // these are used internally - sess *session.Session - s3svc *s3.S3 - sqssvc *sqs.SQS - downloader *s3manager.Downloader - uploader *s3manager.Uploader - prequrl, ocrqurl string -} - -func (a *awsConn) Init() error { - if a.region == "" { - return errors.New("No region set") - } - if a.logger == nil { - return errors.New("No logger set") - } - - var err error - a.sess, err = session.NewSession(&aws.Config{ - Region: aws.String(a.region), - }) - if err != nil { - return errors.New(fmt.Sprintf("Failed to set up aws session: %s", err)) - } - a.s3svc = s3.New(a.sess) - a.sqssvc = sqs.New(a.sess) - a.downloader = s3manager.NewDownloader(a.sess) - a.uploader = s3manager.NewUploader(a.sess) - - a.logger.Println("Getting preprocess queue URL") - result, err := a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{ - QueueName: aws.String("rescribepreprocess"), - }) - if err != nil { - return errors.New(fmt.Sprintf("Error getting preprocess queue URL: %s", err)) - } - a.prequrl = *result.QueueUrl - - a.logger.Println("Getting OCR queue URL") - result, err = a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{ - QueueName: aws.String("rescribeocr"), - }) - if err != nil { - return errors.New(fmt.Sprintf("Error getting OCR queue URL: %s", err)) - } - a.ocrqurl = *result.QueueUrl - return nil -} - -func (a *awsConn) CheckQueue(url string) (Qmsg, error) { - msgResult, err := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{ - MaxNumberOfMessages: aws.Int64(1), - VisibilityTimeout: aws.Int64(HeartbeatTime * 2), - WaitTimeSeconds: aws.Int64(20), - QueueUrl: &url, - }) - if err != nil { - return Qmsg{}, err - } - - if len(msgResult.Messages) > 0 { - msg := Qmsg{ Handle: *msgResult.Messages[0].ReceiptHandle, Body: *msgResult.Messages[0].Body } - a.logger.Println("Message received:", msg.Body) - return msg, nil - } else { - return Qmsg{}, nil - } -} - -func (a *awsConn) CheckPreQueue() (Qmsg, error) { - a.logger.Println("Checking preprocessing queue for new messages") - return a.CheckQueue(a.prequrl) -} - -func (a *awsConn) QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error { - for _ = range t.C { - duration := int64(HeartbeatTime * 2) - _, err := a.sqssvc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{ - ReceiptHandle: &msgHandle, - QueueUrl: &qurl, - VisibilityTimeout: &duration, - }) - if err != nil { - return errors.New(fmt.Sprintf("Heartbeat error updating queue duration: %s", err)) - } - } - return nil -} - -func (a *awsConn) PreQueueHeartbeat(t *time.Ticker, msgHandle string) error { - a.logger.Println("Starting preprocess queue heartbeat for", msgHandle) - return a.QueueHeartbeat(t, msgHandle, a.prequrl) -} - -func (a *awsConn) ListObjects(bucket string, prefix string, names chan string) error { - err := a.s3svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{ - Bucket: aws.String(bucket), - Prefix: aws.String(prefix), - }, func(page *s3.ListObjectsV2Output, last bool) bool { - for _, r := range page.Contents { - if alreadydone.MatchString(*r.Key) { - a.logger.Println("Skipping item that looks like it has already been processed", *r.Key) - continue - } - names <- *r.Key - } - return true - }) - close(names) - return err -} - -func (a *awsConn) ListInProgress(bookname string, names chan string) error { - return a.ListObjects("rescribeinprogress", bookname, names) -} - -func (a *awsConn) AddToQueue(url string, msg string) error { - _, err := a.sqssvc.SendMessage(&sqs.SendMessageInput{ - MessageBody: &msg, - QueueUrl: &url, - }) - return err -} - -func (a *awsConn) AddToOCRQueue(msg string) error { - return a.AddToQueue(a.ocrqurl, msg) -} - -func (a *awsConn) DelFromQueue(url string, handle string) error { - _, err := a.sqssvc.DeleteMessage(&sqs.DeleteMessageInput{ - QueueUrl: &url, - ReceiptHandle: &handle, - }) - return err -} - -func (a *awsConn) DelFromPreQueue(handle string) error { - return a.DelFromQueue(a.prequrl, handle) -} - -func (a *awsConn) Download(bucket string, key string, path string) error { - f, err := os.Create(path) - if err != nil { - return err - } - defer f.Close() - - _, err = a.downloader.Download(f, - &s3.GetObjectInput{ - Bucket: aws.String(bucket), - Key: &key, - }) - return err -} - -func (a *awsConn) DownloadFromInProgress(key string, path string) error { - a.logger.Println("Downloading", key) - return a.Download("rescribeinprogress", key, path) -} - -func (a *awsConn) Upload(bucket string, key string, path string) error { - file, err := os.Open(path) - if err != nil { - log.Fatalln("Failed to open file", path, err) - } - defer file.Close() - - _, err = a.uploader.Upload(&s3manager.UploadInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - Body: file, - }) - return err -} - -func (a *awsConn) UploadToInProgress(key string, path string) error { - a.logger.Println("Uploading", path) - return a.Upload("rescribeinprogress", key, path) -} - - - - func download(dl chan string, pre chan string, conn Pipeliner, dir string) { for key := range dl { fn := filepath.Join(dir, filepath.Base(key)) @@ -297,20 +91,6 @@ func up(c chan string, done chan bool, conn Pipeliner, bookname string) { done <- true } -func heartbeat(h *time.Ticker, msgHandle string, qurl string, sqssvc *sqs.SQS) { - for _ = range h.C { - duration := int64(HeartbeatTime * 2) - _, err := sqssvc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{ - ReceiptHandle: &msgHandle, - QueueUrl: &qurl, - VisibilityTimeout: &duration, - }) - if err != nil { - log.Fatalln("Error updating queue duration:", err) - } - } -} - func main() { var verboselog *log.Logger if len(os.Args) > 1 { @@ -324,8 +104,6 @@ func main() { verboselog = log.New(n, "", log.LstdFlags) } - alreadydone = regexp.MustCompile(PreprocPattern) - var conn Pipeliner conn = &awsConn{ region: "eu-west-2", logger: verboselog } -- cgit v1.2.1-24-ge1ad