From a60577af86c64001773e31f1691bf4853efc4772 Mon Sep 17 00:00:00 2001
From: Nick White <git@njw.name>
Date: Thu, 22 Aug 2019 10:40:57 +0100
Subject: Update usage string and comments

---
 bookpipeline/aws.go  | 269 ++++++++++++++++++++++++++++++++++++++++
 bookpipeline/main.go | 343 +++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 612 insertions(+)
 create mode 100644 bookpipeline/aws.go
 create mode 100644 bookpipeline/main.go

diff --git a/bookpipeline/aws.go b/bookpipeline/aws.go
new file mode 100644
index 0000000..1ac06de
--- /dev/null
+++ b/bookpipeline/aws.go
@@ -0,0 +1,269 @@
+package main
+
+import (
+	"errors"
+	"fmt"
+	"log"
+	"os"
+	"regexp"
+	"time"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3manager"
+	"github.com/aws/aws-sdk-go/service/sqs"
+)
+
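+// PreprocPattern matches the filenames of pages that have already been
+// preprocessed (binarised at a particular threshold), and HeartbeatTime
+// is the number of seconds between queue heartbeats.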
+const PreprocPattern = `_bin[0-9]\.[0-9]\.png`
+const HeartbeatTime = 60
+
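+// awsConn is the AWS implementation of the pipeline's cloud connection,
+// holding the session, S3 and SQS clients and the queue URLs.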
+type awsConn struct {
+	// these need to be set before running Init()
+	region string
+	logger *log.Logger
+
+	// these are used internally
+	sess *session.Session
+	s3svc *s3.S3
+	sqssvc *sqs.SQS
+	downloader *s3manager.Downloader
+	uploader *s3manager.Uploader
+	prequrl, ocrqurl, analysequrl string
+}
+
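+// Init checks that the region and logger have been set, creates the AWS
+// session, S3 and SQS clients, and looks up the URLs of the preprocess,
+// OCR and analyse queues.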
+func (a *awsConn) Init() error {
+	if a.region == "" {
+		return errors.New("No region set")
+	}
+	if a.logger == nil {
+		return errors.New("No logger set")
+	}
+
+	var err error
+	a.sess, err = session.NewSession(&aws.Config{
+		Region: aws.String(a.region),
+	})
+	if err != nil {
+		return errors.New(fmt.Sprintf("Failed to set up aws session: %s", err))
+	}
+	a.s3svc = s3.New(a.sess)
+	a.sqssvc = sqs.New(a.sess)
+	a.downloader = s3manager.NewDownloader(a.sess)
+	a.uploader = s3manager.NewUploader(a.sess)
+
+	a.logger.Println("Getting preprocess queue URL")
+	result, err := a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+		QueueName: aws.String("rescribepreprocess"),
+	})
+	if err != nil {
+		return errors.New(fmt.Sprintf("Error getting preprocess queue URL: %s", err))
+	}
+	a.prequrl = *result.QueueUrl
+
+	a.logger.Println("Getting OCR queue URL")
+	result, err = a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+		QueueName: aws.String("rescribeocr"),
+	})
+	if err != nil {
+		return errors.New(fmt.Sprintf("Error getting OCR queue URL: %s", err))
+	}
+	a.ocrqurl = *result.QueueUrl
+
+	a.logger.Println("Getting analyse queue URL")
+	result, err = a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+		QueueName: aws.String("rescribeanalyse"),
+	})
+	if err != nil {
+		return errors.New(fmt.Sprintf("Error getting analyse queue URL: %s", err))
+	}
+	a.analysequrl = *result.QueueUrl
+
+	return nil
+}
+
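+// CheckQueue receives at most one message from the given queue, waiting
+// up to 20 seconds for one to arrive and hiding it for HeartbeatTime * 2
+// seconds. If no message was available an empty Qmsg is returned.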
+func (a *awsConn) CheckQueue(url string) (Qmsg, error) {
+	msgResult, err := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
+		MaxNumberOfMessages: aws.Int64(1),
+		VisibilityTimeout: aws.Int64(HeartbeatTime * 2),
+		WaitTimeSeconds: aws.Int64(20),
+		QueueUrl: &url,
+	})
+	if err != nil {
+		return Qmsg{}, err
+	}
+
+	if len(msgResult.Messages) > 0 {
+		msg := Qmsg{Handle: *msgResult.Messages[0].ReceiptHandle, Body: *msgResult.Messages[0].Body}
+		a.logger.Println("Message received:", msg.Body)
+		return msg, nil
+	} else {
+		return Qmsg{}, nil
+	}
+}
+
+func (a *awsConn) CheckPreQueue() (Qmsg, error) {
+	a.logger.Println("Checking preprocessing queue for new messages")
+	return a.CheckQueue(a.prequrl)
+}
+
+func (a *awsConn) CheckOCRQueue() (Qmsg, error) {
+	a.logger.Println("Checking OCR queue for new messages")
+	return a.CheckQueue(a.ocrqurl)
+}
+
+func (a *awsConn) CheckAnalyseQueue() (Qmsg, error) {
+	a.logger.Println("Checking analyse queue for new messages")
+	return a.CheckQueue(a.analysequrl)
+}
+
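+// QueueHeartbeat extends the visibility timeout of a queue message each
+// time the ticker fires, so the message stays hidden from other workers
+// while it is being processed.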
+func (a *awsConn) QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error {
+	for range t.C {
+		duration := int64(HeartbeatTime * 2)
+		_, err := a.sqssvc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{
+			ReceiptHandle: &msgHandle,
+			QueueUrl: &qurl,
+			VisibilityTimeout: &duration,
+		})
+		if err != nil {
+			return errors.New(fmt.Sprintf("Heartbeat error updating queue duration: %s", err))
+		}
+	}
+	return nil
+}
+
+func (a *awsConn) PreQueueHeartbeat(t *time.Ticker, msgHandle string) error {
+	a.logger.Println("Starting preprocess queue heartbeat")
+	return a.QueueHeartbeat(t, msgHandle, a.prequrl)
+}
+
+func (a *awsConn) OCRQueueHeartbeat(t *time.Ticker, msgHandle string) error {
+	a.logger.Println("Starting ocr queue heartbeat")
+	return a.QueueHeartbeat(t, msgHandle, a.ocrqurl)
+}
+
+func (a *awsConn) ListObjects(bucket string, prefix string) ([]string, error) {
+	var names []string
+	err := a.s3svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{
+		Bucket: aws.String(bucket),
+		Prefix: aws.String(prefix),
+	}, func(page *s3.ListObjectsV2Output, last bool) bool {
+		for _, r := range page.Contents {
+			names = append(names, *r.Key)
+		}
+		return true
+	})
+	return names, err
+}
+
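+// ListToPreprocess lists the pages of a book that have not yet been
+// preprocessed, skipping any that match PreprocPattern.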
+func (a *awsConn) ListToPreprocess(bookname string) ([]string, error) {
+	var names []string
+	preprocessed := regexp.MustCompile(PreprocPattern)
+	objs, err := a.ListObjects("rescribeinprogress", bookname)
+	if err != nil {
+		return names, err
+	}
+	// Filter out any object that looks like it's already been preprocessed
+	for _, n := range objs {
+		if preprocessed.MatchString(n) {
+			a.logger.Println("Skipping item that looks like it has already been processed", n)
+			continue
+		}
+		names = append(names, n)
+	}
+	return names, nil
+}
+
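+// ListToOCR lists the pages of a book that have already been
+// preprocessed and so are ready to be OCRed.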
+func (a *awsConn) ListToOCR(bookname string) ([]string, error) {
+	var names []string
+	preprocessed := regexp.MustCompile(PreprocPattern)
+	objs, err := a.ListObjects("rescribeinprogress", bookname)
+	if err != nil {
+		return names, err
+	}
+	// Filter out any object that looks like it hasn't already been preprocessed
+	for _, n := range objs {
+		if !preprocessed.MatchString(n) {
+			a.logger.Println("Skipping item that looks like it is not preprocessed", n)
+			continue
+		}
+		names = append(names, n)
+	}
+	return names, nil
+}
+
+func (a *awsConn) AddToQueue(url string, msg string) error {
+	_, err := a.sqssvc.SendMessage(&sqs.SendMessageInput{
+		MessageBody: &msg,
+		QueueUrl: &url,
+	})
+	return err
+}
+
+func (a *awsConn) AddToOCRQueue(msg string) error {
+	return a.AddToQueue(a.ocrqurl, msg)
+}
+
+func (a *awsConn) AddToAnalyseQueue(msg string) error {
+	return a.AddToQueue(a.analysequrl, msg)
+}
+
+func (a *awsConn) DelFromQueue(url string, handle string) error {
+	_, err := a.sqssvc.DeleteMessage(&sqs.DeleteMessageInput{
+		QueueUrl: &url,
+		ReceiptHandle: &handle,
+	})
+	return err
+}
+
+func (a *awsConn) DelFromPreQueue(handle string) error {
+	return a.DelFromQueue(a.prequrl, handle)
+}
+
+func (a *awsConn) DelFromOCRQueue(handle string) error {
+	return a.DelFromQueue(a.ocrqurl, handle)
+}
+
+func (a *awsConn) Download(bucket string, key string, path string) error {
+	f, err := os.Create(path)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	_, err = a.downloader.Download(f,
+		&s3.GetObjectInput{
+			Bucket: aws.String(bucket),
+			Key: &key,
+	})
+	return err
+}
+
+func (a *awsConn) DownloadFromInProgress(key string, path string) error {
+	a.logger.Println("Downloading", key)
+	return a.Download("rescribeinprogress", key, path)
+}
+
+func (a *awsConn) Upload(bucket string, key string, path string) error {
+	file, err := os.Open(path)
+	if err != nil {
+		return errors.New(fmt.Sprintf("Failed to open file %s: %s", path, err))
+	}
+	defer file.Close()
+
+	_, err = a.uploader.Upload(&s3manager.UploadInput{
+		Bucket: aws.String(bucket),
+		Key:    aws.String(key),
+		Body:   file,
+	})
+	return err
+}
+
+func (a *awsConn) UploadToInProgress(key string, path string) error {
+	a.logger.Println("Uploading", path)
+	return a.Upload("rescribeinprogress", key, path)
+}
+
+func (a *awsConn) Logger() *log.Logger {
+	return a.logger
+}
diff --git a/bookpipeline/main.go b/bookpipeline/main.go
new file mode 100644
index 0000000..c1fb547
--- /dev/null
+++ b/bookpipeline/main.go
@@ -0,0 +1,343 @@
+package main
+// TODO: have logs go somewhere useful, like email
+// TODO: check if images are prebinarised and if so skip multiple binarisation
+
+import (
+	"errors"
+	"fmt"
+	"log"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"rescribe.xyz/go.git/preproc"
+)
+
+const usage = `Usage: bookpipeline [-v]
+
+Watches the preprocess, OCR and analyse queues for book names. When
+one is found, this general process is followed:
+
+- The book name is hidden from the queue, and a 'heartbeat' is
+  started which keeps it hidden (this will time out after 2 minutes
+  if the program is terminated)
+- The necessary files from bookname/ are downloaded
+- The files are processed
+- The resulting files are uploaded to bookname/
+- The heartbeat is stopped
+- The book name is removed from the queue it was taken from, and
+  added to the next queue for future processing
+
+-v  verbose
+`
+
+const training = "rescribealphav5" // TODO: allow to set on cmdline
+
+// NullWriter is a writer which discards everything written to it, so
+// that log output can be thrown away when not running verbosely.
+type NullWriter bool
+
+func (w NullWriter) Write(p []byte) (n int, err error) {
+	return len(p), nil
+}
+
+const PauseBetweenChecks = 3 * time.Minute
+
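+// Clouder is the generic interface for the cloud storage and queue
+// operations needed by the pipeline.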
+type Clouder interface {
+	Init() error
+	ListObjects(bucket string, prefix string) ([]string, error)
+	Download(bucket string, key string, fn string) error
+	Upload(bucket string, key string, path string) error
+	CheckQueue(url string) (Qmsg, error)
+	AddToQueue(url string, msg string) error
+	DelFromQueue(url string, handle string) error
+	QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error
+}
+
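+// Pipeliner builds on Clouder with the higher level calls used by the
+// pipeline, tied to the specific queues and buckets it works with.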
+type Pipeliner interface {
+	Clouder
+	ListToPreprocess(bookname string) ([]string, error)
+	ListToOCR(bookname string) ([]string, error)
+	DownloadFromInProgress(key string, fn string) error
+	UploadToInProgress(key string, path string) error
+	CheckPreQueue() (Qmsg, error)
+	CheckOCRQueue() (Qmsg, error)
+	CheckAnalyseQueue() (Qmsg, error)
+	AddToOCRQueue(msg string) error
+	AddToAnalyseQueue(msg string) error
+	DelFromPreQueue(handle string) error
+	DelFromOCRQueue(handle string) error
+	PreQueueHeartbeat(t *time.Ticker, msgHandle string) error
+	OCRQueueHeartbeat(t *time.Ticker, msgHandle string) error
+	Logger() *log.Logger
+}
+
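+// Qmsg is a queue message; Body holds the book name, and Handle is
+// needed to delete the message or extend its visibility timeout.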
+type Qmsg struct {
+	Handle, Body string
+}
+
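+// download receives keys on the dl channel, downloads them into dir,
+// and sends the local file names on the process channel, closing it
+// once dl is closed or an error occurs.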
+func download(dl chan string, process chan string, conn Pipeliner, dir string, errc chan error) {
+	for key := range dl {
+		fn := filepath.Join(dir, filepath.Base(key))
+		err := conn.DownloadFromInProgress(key, fn)
+		if err != nil {
+			close(process)
+			errc <- err
+			return
+		}
+		process <- fn
+	}
+	close(process)
+}
+
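+// up receives local file paths on the c channel and uploads them under
+// bookname/, sending on done once c has been drained.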
+func up(c chan string, done chan bool, conn Pipeliner, bookname string, errc chan error) {
+	for path := range c {
+		name := filepath.Base(path)
+		key := filepath.Join(bookname, name)
+		err := conn.UploadToInProgress(key, path)
+		if err != nil {
+			errc <- err
+			return
+		}
+	}
+
+	done <- true
+}
+
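+// preprocess receives image paths on the pre channel, binarises each at
+// several thresholds with preproc.PreProcMulti, and sends the resulting
+// file paths on the up channel.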
+func preprocess(pre chan string, up chan string, logger *log.Logger, errc chan error) {
+	for path := range pre {
+		logger.Println("Preprocessing", path)
+		done, err := preproc.PreProcMulti(path, []float64{0.1, 0.2, 0.4, 0.5}, "binary", 0, true, 5, 30)
+		if err != nil {
+			close(up)
+			errc <- err
+			return
+		}
+		for _, p := range done {
+			up <- p
+		}
+	}
+	close(up)
+}
+
+// TODO: use Tesseract API rather than calling the executable
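+// ocr receives image paths on the toocr channel, runs tesseract on each
+// using the training set by the training constant, and sends the
+// resulting .hocr paths on the up channel.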
+func ocr(toocr chan string, up chan string, logger *log.Logger, errc chan error) {
+	for path := range toocr {
+		logger.Println("OCRing", path)
+		name := strings.Replace(path, ".png", "", 1) // TODO: handle any file extension
+		cmd := exec.Command("tesseract", "-l", training, path, name, "hocr")
+		err := cmd.Run()
+		if err != nil {
+			close(up)
+			errc <- errors.New(fmt.Sprintf("Error ocring %s: %s", path, err))
+			return
+		}
+		up <- name + ".hocr"
+	}
+	close(up)
+}
+
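+// preprocBook keeps the book hidden on the preprocess queue with a
+// heartbeat while its unprocessed pages are downloaded, preprocessed
+// and uploaded, then moves the book name on to the OCR queue.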
+func preprocBook(msg Qmsg, conn Pipeliner) error {
+	bookname := msg.Body
+
+	t := time.NewTicker(HeartbeatTime * time.Second)
+	go conn.PreQueueHeartbeat(t, msg.Handle)
+
+	d := filepath.Join(os.TempDir(), bookname)
+	err := os.MkdirAll(d, 0755)
+	if err != nil {
+		t.Stop()
+		return errors.New(fmt.Sprintf("Failed to create directory %s: %s", d, err))
+	}
+
+	dl := make(chan string)
+	pre := make(chan string)
+	upc := make(chan string)
+	done := make(chan bool)
+	errc := make(chan error)
+
+	// these functions will do their jobs when their channels have data
+	go download(dl, pre, conn, d, errc)
+	go preprocess(pre, upc, conn.Logger(), errc)
+	go up(upc, done, conn, bookname, errc)
+
+	conn.Logger().Println("Getting list of objects to download")
+	todl, err := conn.ListToPreprocess(bookname)
+	if err != nil {
+		t.Stop()
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Failed to get list of files for book %s: %s", bookname, err))
+	}
+	for _, d := range todl {
+		dl <- d
+	}
+	close(dl)
+
+	// wait for a message on either the errc or done channel
+	select {
+	case err = <-errc:
+		t.Stop()
+		_ = os.RemoveAll(d)
+		return err
+	case <-done:
+	}
+
+	conn.Logger().Println("Sending", bookname, "to OCR queue")
+	err = conn.AddToOCRQueue(bookname)
+	if err != nil {
+		t.Stop()
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Error adding to ocr queue %s: %s", bookname, err))
+	}
+
+	t.Stop()
+
+	conn.Logger().Println("Deleting original message from preprocessing queue")
+	err = conn.DelFromPreQueue(msg.Handle)
+	if err != nil {
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Error deleting message from preprocessing queue: %s", err))
+	}
+
+	err = os.RemoveAll(d)
+	if err != nil {
+		return errors.New(fmt.Sprintf("Failed to remove directory %s: %s", d, err))
+	}
+
+	return nil
+}
+
+// TODO: this is very similar to preprocBook; try to at least mostly merge them
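+// ocrBook keeps the book hidden on the OCR queue with a heartbeat while
+// its preprocessed pages are downloaded, OCRed and uploaded, then moves
+// the book name on to the analyse queue.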
+func ocrBook(msg Qmsg, conn Pipeliner) error {
+	bookname := msg.Body
+
+	t := time.NewTicker(HeartbeatTime * time.Second)
+	go conn.OCRQueueHeartbeat(t, msg.Handle)
+
+	d := filepath.Join(os.TempDir(), bookname)
+	err := os.MkdirAll(d, 0755)
+	if err != nil {
+		t.Stop()
+		return errors.New(fmt.Sprintf("Failed to create directory %s: %s", d, err))
+	}
+
+	dl := make(chan string)
+	ocrc := make(chan string)
+	upc := make(chan string)
+	done := make(chan bool)
+	errc := make(chan error)
+
+	// these functions will do their jobs when their channels have data
+	go download(dl, ocrc, conn, d, errc)
+	go ocr(ocrc, upc, conn.Logger(), errc)
+	go up(upc, done, conn, bookname, errc)
+
+	conn.Logger().Println("Getting list of objects to download")
+	todl, err := conn.ListToOCR(bookname)
+	if err != nil {
+		t.Stop()
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Failed to get list of files for book %s: %s", bookname, err))
+	}
+	for _, a := range todl {
+		dl <- a
+	}
+	close(dl)
+
+	// wait for a message on either the errc or done channel
+	select {
+	case err = <-errc:
+		t.Stop()
+		_ = os.RemoveAll(d)
+		return err
+	case <-done:
+	}
+
+	conn.Logger().Println("Sending", bookname, "to analyse queue")
+	err = conn.AddToAnalyseQueue(bookname)
+	if err != nil {
+		t.Stop()
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Error adding to analyse queue %s: %s", bookname, err))
+	}
+
+	t.Stop()
+
+	conn.Logger().Println("Deleting original message from OCR queue")
+	err = conn.DelFromOCRQueue(msg.Handle)
+	if err != nil {
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Error deleting message from OCR queue: %s", err))
+	}
+
+	err = os.RemoveAll(d)
+	if err != nil {
+		return errors.New(fmt.Sprintf("Failed to remove directory %s: %s", d, err))
+	}
+
+	return nil
+}
+
+func main() {
+	var verboselog *log.Logger
+	if len(os.Args) > 1 {
+		if os.Args[1] == "-v" {
+			verboselog = log.New(os.Stdout, "", log.LstdFlags)
+		} else {
+			log.Fatal(usage)
+		}
+	} else {
+		var n NullWriter
+		verboselog = log.New(n, "", log.LstdFlags)
+	}
+
+	var conn Pipeliner
+	conn = &awsConn{region: "eu-west-2", logger: verboselog}
+
+	verboselog.Println("Setting up AWS session")
+	err := conn.Init()
+	if err != nil {
+		log.Fatalln("Error setting up cloud connection:", err)
+	}
+	verboselog.Println("Finished setting up AWS session")
+
+	var checkPreQueue <-chan time.Time
+	var checkOCRQueue <-chan time.Time
+	checkPreQueue = time.After(0)
+	checkOCRQueue = time.After(0)
+
+	for {
+		select {
+		case <-checkPreQueue:
+			msg, err := conn.CheckPreQueue()
+			checkPreQueue = time.After(PauseBetweenChecks)
+			if err != nil {
+				log.Println("Error checking preprocess queue", err)
+				continue
+			}
+			if msg.Handle == "" {
+				verboselog.Println("No message received on preprocess queue, sleeping")
+				continue
+			}
+			err = preprocBook(msg, conn)
+			if err != nil {
+				log.Println("Error during preprocess", err)
+			}
+		case <-checkOCRQueue:
+			msg, err := conn.CheckOCRQueue()
+			checkOCRQueue = time.After(PauseBetweenChecks)
+			if err != nil {
+				log.Println("Error checking OCR queue", err)
+				continue
+			}
+			if msg.Handle == "" {
+				verboselog.Println("No message received on OCR queue, sleeping")
+				continue
+			}
+			err = ocrBook(msg, conn)
+			if err != nil {
+				log.Println("Error during OCR process", err)
+			}
+		}
+	}
+}
-- 
cgit v1.2.1-24-ge1ad