summaryrefslogtreecommitdiff
path: root/pipelinepreprocess/main.go
diff options
context:
space:
mode:
authorNick White <git@njw.name>2019-08-19 17:01:18 +0100
committerNick White <git@njw.name>2019-08-19 17:01:18 +0100
commit79aa0597bb1a21c2078fe0d3f27d9037604d91bf (patch)
tree15108d8d98765a2d34d165f21fad6a37935d4558 /pipelinepreprocess/main.go
parent312dcbe96e45330e933f7d542e3b2ef2bf76ec08 (diff)
Work in progress rearchitecture to use interfaces; currently pointers are screwy causing segfaults
Diffstat (limited to 'pipelinepreprocess/main.go')
-rw-r--r--pipelinepreprocess/main.go352
1 files changed, 250 insertions, 102 deletions
diff --git a/pipelinepreprocess/main.go b/pipelinepreprocess/main.go
index 6c58d98..c075fa9 100644
--- a/pipelinepreprocess/main.go
+++ b/pipelinepreprocess/main.go
@@ -6,6 +6,8 @@ package main
// TODO: check if images are prebinarised and if so skip multiple binarisation
import (
+ "errors"
+ "fmt"
"log"
"os"
"path/filepath"
@@ -28,7 +30,8 @@ type NullWriter bool
func (w NullWriter) Write(p []byte) (n int, err error) {
return len(p), nil
}
-var verboselog *log.Logger
+
+var alreadydone *regexp.Regexp
const HeartbeatTime = 60
const PauseBetweenChecks = 60 * time.Second
@@ -44,20 +47,225 @@ const PreprocPattern = `_bin[0-9].[0-9].png`
//
// TODO: consider having the download etc functions return a channel like a generator, like in rob pike's talk
-func download(dl chan string, pre chan string, downloader *s3manager.Downloader, dir string) {
- for key := range dl {
- verboselog.Println("Downloading", key)
- fn := filepath.Join(dir, filepath.Base(key))
- f, err := os.Create(fn)
+type Clouder interface {
+ Init() error
+ ListObjects(bucket string, prefix string, names chan string) error
+ Download(bucket string, key string, fn string) error
+ Upload(bucket string, key string, path string) error
+ CheckQueue(url string) (qmsg, error)
+ AddToQueue(url string, msg string) error
+ DelFromQueue(url string, handle string) error
+ QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error
+}
+
+type Pipeliner interface {
+ Clouder
+ ListInProgress(bookname string, names chan string) error
+ DownloadFromInProgress(key string, fn string) error
+ UploadToInProgress(key string, path string) error
+ CheckPreQueue() (qmsg, error)
+ AddToOCRQueue(msg string) error
+ DelFromPreQueue(handle string) error
+ PreQueueHeartbeat(t *time.Ticker, msgHandle string) error
+}
+
+type qmsg struct {
+ Handle, Body string
+}
+
+type awsConn struct {
+ // these need to be set before running Init()
+ region string
+ logger *log.Logger
+
+ // these are used internally
+ sess *session.Session
+ s3svc *s3.S3
+ sqssvc *sqs.SQS
+ downloader *s3manager.Downloader
+ uploader *s3manager.Uploader
+ prequrl, ocrqurl string
+}
+
+func (a awsConn) Init() error {
+ if a.region == "" {
+ return errors.New("No region set")
+ }
+ if a.logger == nil {
+ return errors.New("No logger set")
+ }
+
+ var err error
+ a.sess, err = session.NewSession(&aws.Config{
+ Region: aws.String(a.region),
+ })
+ if err != nil {
+ return errors.New(fmt.Sprintf("Failed to set up aws session: %s", err))
+ }
+ a.s3svc = s3.New(a.sess)
+ a.sqssvc = sqs.New(a.sess)
+ a.downloader = s3manager.NewDownloader(a.sess)
+ a.uploader = s3manager.NewUploader(a.sess)
+
+ a.logger.Println("Getting preprocess queue URL")
+ result, err := a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+ QueueName: aws.String("rescribepreprocess"),
+ })
+ if err != nil {
+ return errors.New(fmt.Sprintf("Error getting preprocess queue URL: %s", err))
+ }
+ a.prequrl = *result.QueueUrl
+ a.logger.Println("preprocess queue URL", a.prequrl)
+
+ a.logger.Println("Getting OCR queue URL")
+ result, err = a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+ QueueName: aws.String("rescribeocr"),
+ })
+ if err != nil {
+ return errors.New(fmt.Sprintf("Error getting OCR queue URL: %s", err))
+ }
+ a.ocrqurl = *result.QueueUrl
+ return nil
+}
+
+func (a awsConn) CheckQueue(url string) (qmsg, error) {
+ msgResult, err := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
+ MaxNumberOfMessages: aws.Int64(1),
+ VisibilityTimeout: aws.Int64(HeartbeatTime * 2),
+ WaitTimeSeconds: aws.Int64(20),
+ QueueUrl: &url,
+ })
+ if err != nil {
+ return qmsg{}, err
+ }
+
+ if len(msgResult.Messages) > 0 {
+ msg := qmsg{ Handle: *msgResult.Messages[0].ReceiptHandle, Body: *msgResult.Messages[0].Body }
+ a.logger.Println("Message received:", msg.Body)
+ return msg, nil
+ } else {
+ return qmsg{}, nil
+ }
+}
+
+func (a awsConn) CheckPreQueue() (qmsg, error) {
+ a.logger.Println("Checking preprocessing queue for new messages:", a.prequrl)
+ return a.CheckQueue(a.prequrl)
+}
+
+func (a awsConn) QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error {
+ for _ = range t.C {
+ duration := int64(HeartbeatTime * 2)
+ _, err := a.sqssvc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{
+ ReceiptHandle: &msgHandle,
+ QueueUrl: &qurl,
+ VisibilityTimeout: &duration,
+ })
if err != nil {
- log.Fatalln("Failed to create file", fn, err)
+ return errors.New(fmt.Sprintf("Heartbeat error updating queue duration: %s", err))
}
- defer f.Close()
+ }
+ return nil
+}
+
+func (a awsConn) PreQueueHeartbeat(t *time.Ticker, msgHandle string) error {
+ a.logger.Println("Starting preprocess queue heartbeat for", msgHandle)
+ return a.QueueHeartbeat(t, msgHandle, a.prequrl)
+}
+
+func (a awsConn) ListObjects(bucket string, prefix string, names chan string) error {
+ err := a.s3svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{
+ Bucket: aws.String(bucket),
+ Prefix: aws.String(prefix),
+ }, func(page *s3.ListObjectsV2Output, last bool) bool {
+ for _, r := range page.Contents {
+ if alreadydone.MatchString(*r.Key) {
+ a.logger.Println("Skipping item that looks like it has already been processed", *r.Key)
+ continue
+ }
+ names <- *r.Key
+ }
+ return true
+ })
+ close(names)
+ return err
+}
+
+func (a awsConn) ListInProgress(bookname string, names chan string) error {
+ return a.ListObjects("rescribeinprogress", bookname, names)
+}
+
+func (a awsConn) AddToQueue(url string, msg string) error {
+ _, err := a.sqssvc.SendMessage(&sqs.SendMessageInput{
+ MessageBody: &msg,
+ QueueUrl: &url,
+ })
+ return err
+}
+
+func (a awsConn) AddToOCRQueue(msg string) error {
+ return a.AddToQueue(a.ocrqurl, msg)
+}
- _, err = downloader.Download(f,
- &s3.GetObjectInput{
- Bucket: aws.String("rescribeinprogress"),
- Key: &key })
+func (a awsConn) DelFromQueue(url string, handle string) error {
+ _, err := a.sqssvc.DeleteMessage(&sqs.DeleteMessageInput{
+ QueueUrl: &url,
+ ReceiptHandle: &handle,
+ })
+ return err
+}
+
+func (a awsConn) DelFromPreQueue(handle string) error {
+ return a.DelFromQueue(a.prequrl, handle)
+}
+
+func (a awsConn) Download(bucket string, key string, path string) error {
+ f, err := os.Create(path)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+
+ _, err = a.downloader.Download(f,
+ &s3.GetObjectInput{
+ Bucket: aws.String(bucket),
+ Key: &key,
+ })
+ return err
+}
+
+func (a awsConn) DownloadFromInProgress(key string, path string) error {
+ a.logger.Println("Downloading", key)
+ return a.Download("rescribeinprogress", key, path)
+}
+
+func (a awsConn) Upload(bucket string, key string, path string) error {
+ file, err := os.Open(path)
+ if err != nil {
+ log.Fatalln("Failed to open file", path, err)
+ }
+ defer file.Close()
+
+ _, err = a.uploader.Upload(&s3manager.UploadInput{
+ Bucket: aws.String(bucket),
+ Key: aws.String(key),
+ Body: file,
+ })
+ return err
+}
+
+func (a awsConn) UploadToInProgress(key string, path string) error {
+ a.logger.Println("Uploading", path)
+ return a.Upload("rescribeinprogress", key, path)
+}
+
+
+
+
+func download(dl chan string, pre chan string, conn Pipeliner, dir string) {
+ for key := range dl {
+ fn := filepath.Join(dir, filepath.Base(key))
+ err := conn.DownloadFromInProgress(key, fn)
if err != nil {
log.Fatalln("Failed to download", key, err)
}
@@ -66,9 +274,9 @@ func download(dl chan string, pre chan string, downloader *s3manager.Downloader,
close(pre)
}
-func preprocess(pre chan string, up chan string) {
+func preprocess(pre chan string, up chan string, logger *log.Logger) {
for path := range pre {
- verboselog.Println("Preprocessing", path)
+ logger.Println("Preprocessing", path)
done, err := preproc.PreProcMulti(path, []float64{0.1, 0.2, 0.4, 0.5}, "binary", 0, true, 5, 30)
if err != nil {
log.Fatalln("Error preprocessing", path, err)
@@ -80,21 +288,11 @@ func preprocess(pre chan string, up chan string) {
close(up)
}
-func up(c chan string, done chan bool, uploader *s3manager.Uploader, bookname string) {
+func up(c chan string, done chan bool, conn Pipeliner, bookname string) {
for path := range c {
- verboselog.Println("Uploading", path)
name := filepath.Base(path)
- file, err := os.Open(path)
- if err != nil {
- log.Fatalln("Failed to open file", path, err)
- }
- defer file.Close()
-
- _, err = uploader.Upload(&s3manager.UploadInput{
- Bucket: aws.String("rescribeinprogress"),
- Key: aws.String(filepath.Join(bookname, name)),
- Body: file,
- })
+ key := filepath.Join(bookname, name)
+ err := conn.UploadToInProgress(key, path)
if err != nil {
log.Fatalln("Failed to upload", path, err)
}
@@ -118,6 +316,7 @@ func heartbeat(h *time.Ticker, msgHandle string, qurl string, sqssvc *sqs.SQS) {
}
func main() {
+ var verboselog *log.Logger
if len(os.Args) > 1 {
if os.Args[1] == "-v" {
verboselog = log.New(os.Stdout, "", log.LstdFlags)
@@ -129,65 +328,31 @@ func main() {
verboselog = log.New(n, "", log.LstdFlags)
}
- alreadydone := regexp.MustCompile(PreprocPattern)
+ alreadydone = regexp.MustCompile(PreprocPattern)
- verboselog.Println("Setting up AWS session")
- sess, err := session.NewSession(&aws.Config{
- Region: aws.String("eu-west-2"),
- })
- if err != nil {
- log.Fatalln("Error: failed to set up aws session:", err)
- }
- s3svc := s3.New(sess)
- sqssvc := sqs.New(sess)
- downloader := s3manager.NewDownloader(sess)
- uploader := s3manager.NewUploader(sess)
-
- preqname := "rescribepreprocess"
- verboselog.Println("Getting Queue URL for", preqname)
- result, err := sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
- QueueName: aws.String(preqname),
- })
- if err != nil {
- log.Fatalln("Error getting queue URL for", preqname, ":", err)
- }
- prequrl := *result.QueueUrl
+ var conn Pipeliner
+ conn = awsConn{ region: "eu-west-2", logger: verboselog }
- ocrqname := "rescribeocr"
- verboselog.Println("Getting Queue URL for", ocrqname)
- result, err = sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
- QueueName: aws.String(ocrqname),
- })
+ verboselog.Println("Setting up AWS session")
+ err := conn.Init()
if err != nil {
- log.Fatalln("Error getting queue URL for", ocrqname, ":", err)
+ log.Fatalln("Error setting up cloud connection:", err)
}
- ocrqurl := *result.QueueUrl
for {
- verboselog.Println("Checking preprocessing queue for new messages")
- msgResult, err := sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
- MaxNumberOfMessages: aws.Int64(1),
- VisibilityTimeout: aws.Int64(HeartbeatTime * 2),
- WaitTimeSeconds: aws.Int64(20),
- QueueUrl: &prequrl,
- })
+ msg, err := conn.CheckPreQueue()
if err != nil {
- log.Fatalln("Error checking queue", preqname, ":", err)
+ log.Fatalln("Error checking preprocess queue", err)
}
-
- var bookname string
- if len(msgResult.Messages) > 0 {
- bookname = *msgResult.Messages[0].Body
- verboselog.Println("Message received:", bookname)
- } else {
+ if msg.Handle == "" {
verboselog.Println("No message received, sleeping")
time.Sleep(PauseBetweenChecks)
continue
}
+ bookname := msg.Body
- verboselog.Println("Starting heartbeat every", HeartbeatTime, "seconds")
t := time.NewTicker(HeartbeatTime * time.Second)
- go heartbeat(t, *msgResult.Messages[0].ReceiptHandle, prequrl, sqssvc)
+ go conn.PreQueueHeartbeat(t, msg.Handle)
d := filepath.Join(os.TempDir(), bookname)
@@ -202,49 +367,32 @@ func main() {
done := make(chan bool) // this is just to communicate when up has finished, so the queues can be updated
// these functions will do their jobs when their channels have data
- go download(dl, pre, downloader, d)
- go preprocess(pre, upc)
- go up(upc, done, uploader, bookname)
+ go download(dl, pre, conn, d)
+ go preprocess(pre, upc, verboselog)
+ go up(upc, done, conn, bookname)
verboselog.Println("Getting list of objects to download")
- err = s3svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{
- Bucket: aws.String("rescribeinprogress"),
- Prefix: aws.String(bookname),
- }, func(page *s3.ListObjectsV2Output, last bool) bool {
- for _, r := range page.Contents {
- if alreadydone.MatchString(*r.Key) {
- verboselog.Println("Skipping item that looks like it has already been processed", *r.Key)
- continue
- }
- dl <- *r.Key
- }
- return true
- })
- close(dl)
-
+ err = conn.ListInProgress(bookname, dl)
+ if err != nil {
+ log.Fatalln("Failed to get list of files for book", bookname, err)
+ }
// wait for the done channel to be posted to
<-done
- verboselog.Println("Sending", bookname, "to queue", ocrqurl)
- _, err = sqssvc.SendMessage(&sqs.SendMessageInput{
- MessageBody: aws.String(bookname),
- QueueUrl: &ocrqurl,
- })
+ verboselog.Println("Sending", bookname, "to OCR queue")
+ err = conn.AddToOCRQueue(bookname)
if err != nil {
- log.Fatalln("Error sending message to queue", ocrqname, ":", err)
+ log.Fatalln("Error adding to ocr queue", bookname, err)
}
t.Stop()
- verboselog.Println("Deleting original message from queue", prequrl)
- _, err = sqssvc.DeleteMessage(&sqs.DeleteMessageInput{
- QueueUrl: &prequrl,
- ReceiptHandle: msgResult.Messages[0].ReceiptHandle,
- })
+ verboselog.Println("Deleting original message from preprocessing queue")
+ err = conn.DelFromPreQueue(msg.Handle)
if err != nil {
- log.Fatalln("Error deleting message from queue", preqname, ":", err)
+ log.Fatalln("Error deleting message from preprocessing queue", err)
}
err = os.RemoveAll(d)