Diffstat (limited to 'bookpipeline')
-rw-r--r--   bookpipeline/aws.go    269
-rw-r--r--   bookpipeline/main.go   343
2 files changed, 612 insertions, 0 deletions
diff --git a/bookpipeline/aws.go b/bookpipeline/aws.go
new file mode 100644
index 0000000..1ac06de
--- /dev/null
+++ b/bookpipeline/aws.go
@@ -0,0 +1,269 @@
+package main
+
+import (
+	"errors"
+	"fmt"
+	"log"
+	"os"
+	"regexp"
+	"time"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/s3"
+	"github.com/aws/aws-sdk-go/service/s3/s3manager"
+	"github.com/aws/aws-sdk-go/service/sqs"
+)
+
+const PreprocPattern = `_bin[0-9].[0-9].png`
+const HeartbeatTime = 60
+
+type awsConn struct {
+	// these need to be set before running Init()
+	region string
+	logger *log.Logger
+
+	// these are used internally
+	sess *session.Session
+	s3svc *s3.S3
+	sqssvc *sqs.SQS
+	downloader *s3manager.Downloader
+	uploader *s3manager.Uploader
+	prequrl, ocrqurl, analysequrl string
+}
+
+func (a *awsConn) Init() error {
+	if a.region == "" {
+		return errors.New("No region set")
+	}
+	if a.logger == nil {
+		return errors.New("No logger set")
+	}
+
+	var err error
+	a.sess, err = session.NewSession(&aws.Config{
+		Region: aws.String(a.region),
+	})
+	if err != nil {
+		return errors.New(fmt.Sprintf("Failed to set up aws session: %s", err))
+	}
+	a.s3svc = s3.New(a.sess)
+	a.sqssvc = sqs.New(a.sess)
+	a.downloader = s3manager.NewDownloader(a.sess)
+	a.uploader = s3manager.NewUploader(a.sess)
+
+	a.logger.Println("Getting preprocess queue URL")
+	result, err := a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+		QueueName: aws.String("rescribepreprocess"),
+	})
+	if err != nil {
+		return errors.New(fmt.Sprintf("Error getting preprocess queue URL: %s", err))
+	}
+	a.prequrl = *result.QueueUrl
+
+	a.logger.Println("Getting OCR queue URL")
+	result, err = a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+		QueueName: aws.String("rescribeocr"),
+	})
+	if err != nil {
+		return errors.New(fmt.Sprintf("Error getting OCR queue URL: %s", err))
+	}
+	a.ocrqurl = *result.QueueUrl
+
+	a.logger.Println("Getting analyse queue URL")
+	result, err = a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+		QueueName: aws.String("rescribeanalyse"),
+	})
+	if err != nil {
+		return errors.New(fmt.Sprintf("Error getting analyse queue URL: %s", err))
+	}
+	a.analysequrl = *result.QueueUrl
+
+	return nil
+}
+
+func (a *awsConn) CheckQueue(url string) (Qmsg, error) {
+	msgResult, err := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
+		MaxNumberOfMessages: aws.Int64(1),
+		VisibilityTimeout: aws.Int64(HeartbeatTime * 2),
+		WaitTimeSeconds: aws.Int64(20),
+		QueueUrl: &url,
+	})
+	if err != nil {
+		return Qmsg{}, err
+	}
+
+	if len(msgResult.Messages) > 0 {
+		msg := Qmsg{ Handle: *msgResult.Messages[0].ReceiptHandle, Body: *msgResult.Messages[0].Body }
+		a.logger.Println("Message received:", msg.Body)
+		return msg, nil
+	} else {
+		return Qmsg{}, nil
+	}
+}
+
+func (a *awsConn) CheckPreQueue() (Qmsg, error) {
+	a.logger.Println("Checking preprocessing queue for new messages")
+	return a.CheckQueue(a.prequrl)
+}
+
+func (a *awsConn) CheckOCRQueue() (Qmsg, error) {
+	a.logger.Println("Checking OCR queue for new messages")
+	return a.CheckQueue(a.ocrqurl)
+}
+
+func (a *awsConn) CheckAnalyseQueue() (Qmsg, error) {
+	a.logger.Println("Checking analyse queue for new messages")
+	return a.CheckQueue(a.analysequrl)
+}
+
+func (a *awsConn) QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error {
+	for _ = range t.C {
+		duration := int64(HeartbeatTime * 2)
+		_, err := a.sqssvc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{
+			ReceiptHandle: &msgHandle,
+			QueueUrl: &qurl,
+			VisibilityTimeout: &duration,
+		})
+		if err != nil {
+			return errors.New(fmt.Sprintf("Heartbeat error updating queue duration: %s", err))
+		}
+	}
+	return nil
+}
+
+func (a *awsConn) PreQueueHeartbeat(t *time.Ticker, msgHandle string) error {
+	a.logger.Println("Starting preprocess queue heartbeat")
+	return a.QueueHeartbeat(t, msgHandle, a.prequrl)
+}
+
+func (a *awsConn) OCRQueueHeartbeat(t *time.Ticker, msgHandle string) error {
+	a.logger.Println("Starting ocr queue heartbeat")
+	return a.QueueHeartbeat(t, msgHandle, a.ocrqurl)
+}
+
+func (a *awsConn) ListObjects(bucket string, prefix string) ([]string, error) {
+	var names []string
+	err := a.s3svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{
+		Bucket: aws.String(bucket),
+		Prefix: aws.String(prefix),
+	}, func(page *s3.ListObjectsV2Output, last bool) bool {
+		for _, r := range page.Contents {
+			names = append(names, *r.Key)
+		}
+		return true
+	})
+	return names, err
+}
+
+func (a *awsConn) ListToPreprocess(bookname string) ([]string, error) {
+	var names []string
+	preprocessed := regexp.MustCompile(PreprocPattern)
+	objs, err := a.ListObjects("rescribeinprogress", bookname)
+	if err != nil {
+		return names, err
+	}
+	// Filter out any object that looks like it's already been preprocessed
+	for _, n := range objs {
+		if preprocessed.MatchString(n) {
+			a.logger.Println("Skipping item that looks like it has already been processed", n)
+			continue
+		}
+		names = append(names, n)
+	}
+	return names, nil
+}
+
+func (a *awsConn) ListToOCR(bookname string) ([]string, error) {
+	var names []string
+	preprocessed := regexp.MustCompile(PreprocPattern)
+	objs, err := a.ListObjects("rescribeinprogress", bookname)
+	if err != nil {
+		return names, err
+	}
+	// Filter out any object that looks like it hasn't already been preprocessed
+	for _, n := range objs {
+		if !preprocessed.MatchString(n) {
+			a.logger.Println("Skipping item that looks like it is not preprocessed", n)
+			continue
+		}
+		names = append(names, n)
+	}
+	return names, nil
+}
+
+func (a *awsConn) AddToQueue(url string, msg string) error {
+	_, err := a.sqssvc.SendMessage(&sqs.SendMessageInput{
+		MessageBody: &msg,
+		QueueUrl: &url,
+	})
+	return err
+}
+
+func (a *awsConn) AddToOCRQueue(msg string) error {
+	return a.AddToQueue(a.ocrqurl, msg)
+}
+
+func (a *awsConn) AddToAnalyseQueue(msg string) error {
+	return a.AddToQueue(a.analysequrl, msg)
+}
+
+func (a *awsConn) DelFromQueue(url string, handle string) error {
+	_, err := a.sqssvc.DeleteMessage(&sqs.DeleteMessageInput{
+		QueueUrl: &url,
+		ReceiptHandle: &handle,
+	})
+	return err
+}
+
+func (a *awsConn) DelFromPreQueue(handle string) error {
+	return a.DelFromQueue(a.prequrl, handle)
+}
+
+func (a *awsConn) DelFromOCRQueue(handle string) error {
+	return a.DelFromQueue(a.ocrqurl, handle)
+}
+
+func (a *awsConn) Download(bucket string, key string, path string) error {
+	f, err := os.Create(path)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	_, err = a.downloader.Download(f,
+		&s3.GetObjectInput{
+			Bucket: aws.String(bucket),
+			Key: &key,
+	})
+	return err
+}
+
+func (a *awsConn) DownloadFromInProgress(key string, path string) error {
+	a.logger.Println("Downloading", key)
+	return a.Download("rescribeinprogress", key, path)
+}
+
+func (a *awsConn) Upload(bucket string, key string, path string) error {
+	file, err := os.Open(path)
+	if err != nil {
+		log.Fatalln("Failed to open file", path, err)
+	}
+	defer file.Close()
+
+	_, err = a.uploader.Upload(&s3manager.UploadInput{
+		Bucket: aws.String(bucket),
+		Key:    aws.String(key),
+		Body:   file,
+	})
+	return err
+}
+
+func (a *awsConn) UploadToInProgress(key string, path string) error {
+	a.logger.Println("Uploading", path)
+	return a.Upload("rescribeinprogress", key, path)
+}
+
+func (a *awsConn) Logger() *log.Logger {
+	return a.logger
+}
diff --git a/bookpipeline/main.go b/bookpipeline/main.go
new file mode 100644
index 0000000..c1fb547
--- /dev/null
+++ b/bookpipeline/main.go
@@ -0,0 +1,343 @@
+package main
+// TODO: have logs go somewhere useful, like email
+// TODO: check if images are prebinarised and if so skip multiple binarisation
+
+import (
+	"errors"
+	"fmt"
+	"log"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"rescribe.xyz/go.git/preproc"
+)
+
+const usage = `Usage: bookpipeline [-v]
+
+Watches the preprocess, ocr and analyse queues for book names. When
+one is found this general process is followed:
+
+- The book name is hidden from the queue, and a 'heartbeat' is
+  started which keeps it hidden (this will time out after 2 minutes
+  if the program is terminated)
+- The necessary files from bookname/ are downloaded
+- The files are processed
+- The resulting files are uploaded to bookname/
+- The heartbeat is stopped
+- The book name is removed from the queue it was taken from, and
+  added to the next queue for future processing
+
+-v  verbose
+`
+
+const training = "rescribealphav5" // TODO: allow to set on cmdline
+
+// null writer to enable non-verbose logging to be discarded
+type NullWriter bool
+func (w NullWriter) Write(p []byte) (n int, err error) {
+	return len(p), nil
+}
+
+const PauseBetweenChecks = 3 * time.Minute
+
+type Clouder interface {
+	Init() error
+	ListObjects(bucket string, prefix string) ([]string, error)
+	Download(bucket string, key string, fn string) error
+	Upload(bucket string, key string, path string) error
+	CheckQueue(url string) (Qmsg, error)
+	AddToQueue(url string, msg string) error
+	DelFromQueue(url string, handle string) error
+	QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error
+}
+
+type Pipeliner interface {
+	Clouder
+	ListToPreprocess(bookname string) ([]string, error)
+	ListToOCR(bookname string) ([]string, error)
+	DownloadFromInProgress(key string, fn string) error
+	UploadToInProgress(key string, path string) error
+	CheckPreQueue() (Qmsg, error)
+	CheckOCRQueue() (Qmsg, error)
+	CheckAnalyseQueue() (Qmsg, error)
+	AddToOCRQueue(msg string) error
+	AddToAnalyseQueue(msg string) error
+	DelFromPreQueue(handle string) error
+	DelFromOCRQueue(handle string) error
+	PreQueueHeartbeat(t *time.Ticker, msgHandle string) error
+	OCRQueueHeartbeat(t *time.Ticker, msgHandle string) error
+	Logger() *log.Logger
+}
+
+type Qmsg struct {
+	Handle, Body string
+}
+
+func download(dl chan string, process chan string, conn Pipeliner, dir string, errc chan error) {
+	for key := range dl {
+		fn := filepath.Join(dir, filepath.Base(key))
+		err := conn.DownloadFromInProgress(key, fn)
+		if err != nil {
+			close(process)
+			errc <- err
+			return
+		}
+		process <- fn
+	}
+	close(process)
+}
+
+func up(c chan string, done chan bool, conn Pipeliner, bookname string, errc chan error) {
+	for path := range c {
+		name := filepath.Base(path)
+		key := filepath.Join(bookname, name)
+		err := conn.UploadToInProgress(key, path)
+		if err != nil {
+			errc <- err
+			return
+		}
+	}
+
+	done <- true
+}
+
+func preprocess(pre chan string, up chan string, logger *log.Logger, errc chan error) {
+	for path := range pre {
+		logger.Println("Preprocessing", path)
+		done, err := preproc.PreProcMulti(path, []float64{0.1, 0.2, 0.4, 0.5}, "binary", 0, true, 5, 30)
+		if err != nil {
+			close(up)
+			errc <- err
+			return
+		}
+		for _, p := range done {
+			up <- p
+		}
+	}
+	close(up)
+}
+
+// TODO: use Tesseract API rather than calling the executable
+func ocr(toocr chan string, up chan string, logger *log.Logger, errc chan error) {
+	for path := range toocr {
+		logger.Println("OCRing", path)
+		name := strings.Replace(path, ".png", "", 1) // TODO: handle any file extension
+		cmd := exec.Command("tesseract", "-l", training, path, name, "hocr")
+		err := cmd.Run()
+		if err != nil {
+			close(up)
+			errc <- errors.New(fmt.Sprintf("Error ocring %s: %s", path, err))
+			return
+		}
+		up <- name + ".hocr"
+	}
+	close(up)
+}
+
+func preprocBook(msg Qmsg, conn Pipeliner) error {
+	bookname := msg.Body
+
+	t := time.NewTicker(HeartbeatTime * time.Second)
+	go conn.PreQueueHeartbeat(t, msg.Handle)
+
+	d := filepath.Join(os.TempDir(), bookname)
+	err := os.MkdirAll(d, 0755)
+	if err != nil {
+		t.Stop()
+		return errors.New(fmt.Sprintf("Failed to create directory %s: %s", d, err))
+	}
+
+	dl := make(chan string)
+	pre := make(chan string)
+	upc := make(chan string)
+	done := make(chan bool)
+	errc := make(chan error)
+
+	// these functions will do their jobs when their channels have data
+	go download(dl, pre, conn, d, errc)
+	go preprocess(pre, upc, conn.Logger(), errc)
+	go up(upc, done, conn, bookname, errc)
+
+	conn.Logger().Println("Getting list of objects to download")
+	todl, err := conn.ListToPreprocess(bookname)
+	if err != nil {
+		t.Stop()
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Failed to get list of files for book %s: %s", bookname, err))
+	}
+	for _, d := range todl {
+		dl <- d
+	}
+	close(dl)
+
+	// wait for either the done or errc channel to be sent to
+	select {
+		case err = <-errc:
+			t.Stop()
+			_ = os.RemoveAll(d)
+			return err
+		case <-done:
+	}
+
+	conn.Logger().Println("Sending", bookname, "to OCR queue")
+	err = conn.AddToOCRQueue(bookname)
+	if err != nil {
+		t.Stop()
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Error adding to ocr queue %s: %s", bookname, err))
+	}
+
+	t.Stop()
+
+	conn.Logger().Println("Deleting original message from preprocessing queue")
+	err = conn.DelFromPreQueue(msg.Handle)
+	if err != nil {
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Error deleting message from preprocessing queue: %s", err))
+	}
+
+	err = os.RemoveAll(d)
+	if err != nil {
+		return errors.New(fmt.Sprintf("Failed to remove directory %s: %s", d, err))
+	}
+
+	return nil
+}
+
+// TODO: this is very similar to preprocBook; try to at least mostly merge them
+func ocrBook(msg Qmsg, conn Pipeliner) error {
+	bookname := msg.Body
+
+	t := time.NewTicker(HeartbeatTime * time.Second)
+	go conn.OCRQueueHeartbeat(t, msg.Handle)
+
+	d := filepath.Join(os.TempDir(), bookname)
+	err := os.MkdirAll(d, 0755)
+	if err != nil {
+		t.Stop()
+		return errors.New(fmt.Sprintf("Failed to create directory %s: %s", d, err))
+	}
+
+	dl := make(chan string)
+	ocrc := make(chan string)
+	upc := make(chan string)
+	done := make(chan bool)
+	errc := make(chan error)
+
+	// these functions will do their jobs when their channels have data
+	go download(dl, ocrc, conn, d, errc)
+	go ocr(ocrc, upc, conn.Logger(), errc)
+	go up(upc, done, conn, bookname, errc)
+
+	conn.Logger().Println("Getting list of objects to download")
+	todl, err := conn.ListToOCR(bookname)
+	if err != nil {
+		t.Stop()
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Failed to get list of files for book %s: %s", bookname, err))
+	}
+	for _, a := range todl {
+		dl <- a
+	}
+	close(dl)
+
+	// wait for either the done or errc channel to be sent to
+	select {
+		case err = <-errc:
+			t.Stop()
+			_ = os.RemoveAll(d)
+			return err
+		case <-done:
+	}
+
+	conn.Logger().Println("Sending", bookname, "to analyse queue")
+	err = conn.AddToAnalyseQueue(bookname)
+	if err != nil {
+		t.Stop()
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Error adding to analyse queue %s: %s", bookname, err))
+	}
+
+	t.Stop()
+
+	conn.Logger().Println("Deleting original message from OCR queue")
+	err = conn.DelFromOCRQueue(msg.Handle)
+	if err != nil {
+		_ = os.RemoveAll(d)
+		return errors.New(fmt.Sprintf("Error deleting message from OCR queue: %s", err))
+	}
+
+	err = os.RemoveAll(d)
+	if err != nil {
+		return errors.New(fmt.Sprintf("Failed to remove directory %s: %s", d, err))
+	}
+
+	return nil
+}
+
+func main() {
+	var verboselog *log.Logger
+	if len(os.Args) > 1 {
+		if os.Args[1] == "-v" {
+			verboselog = log.New(os.Stdout, "", log.LstdFlags)
+		} else {
+			log.Fatal(usage)
+		}
+	} else {
+		var n NullWriter
+		verboselog = log.New(n, "", log.LstdFlags)
+	}
+
+	var conn Pipeliner
+	conn = &awsConn{ region: "eu-west-2", logger: verboselog }
+
+	verboselog.Println("Setting up AWS session")
+	err := conn.Init()
+	if err != nil {
+		log.Fatalln("Error setting up cloud connection:", err)
+	}
+	verboselog.Println("Finished setting up AWS session")
+
+	var checkPreQueue <-chan time.Time
+	var checkOCRQueue <-chan time.Time
+	checkPreQueue = time.After(0)
+	checkOCRQueue = time.After(0)
+
+	for {
+		select {
+		case <- checkPreQueue:
+			msg, err := conn.CheckPreQueue()
+			checkPreQueue = time.After(PauseBetweenChecks)
+			if err != nil {
+				log.Println("Error checking preprocess queue", err)
+				continue
+			}
+			if msg.Handle == "" {
+				verboselog.Println("No message received on preprocess queue, sleeping")
+				continue
+			}
+			err = preprocBook(msg, conn)
+			if err != nil {
+				log.Println("Error during preprocess", err)
+			}
+		case <- checkOCRQueue:
+			msg, err := conn.CheckOCRQueue()
+			checkOCRQueue = time.After(PauseBetweenChecks)
+			if err != nil {
+				log.Println("Error checking OCR queue", err)
+				continue
+			}
+			if msg.Handle == "" {
+				verboselog.Println("No message received on OCR queue, sleeping")
+				continue
+			}
+			err = ocrBook(msg, conn)
+			if err != nil {
+				log.Println("Error during OCR process", err)
+			}
+		}
+	}
+}
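The download, process and upload stages in preprocBook and ocrBook communicate only through channels, so the concurrency pattern can be exercised without any AWS setup. Below is a minimal, self-contained sketch of that flow; the stage bodies and file names are hypothetical stand-ins, not part of this commit; in the real pipeline they call DownloadFromInProgress, preproc.PreProcMulti or tesseract, and UploadToInProgress.

package main

import (
	"fmt"
	"log"
)

// download stands in for the real download(): it passes keys along and
// closes its output so the next stage's range loop can finish.
func download(dl <-chan string, out chan<- string, errc chan<- error) {
	defer close(out)
	for key := range dl {
		out <- key + ".local" // real code: conn.DownloadFromInProgress(key, fn)
	}
}

// process stands in for preprocess() or ocr().
func process(in <-chan string, out chan<- string, errc chan<- error) {
	defer close(out)
	for fn := range in {
		out <- fn + ".out" // real code: preproc.PreProcMulti or tesseract
	}
}

// upload stands in for up(): it signals done once everything is stored.
func upload(in <-chan string, done chan<- bool, errc chan<- error) {
	for fn := range in {
		fmt.Println("would upload", fn) // real code: conn.UploadToInProgress
	}
	done <- true
}

func main() {
	dl := make(chan string)
	proc := make(chan string)
	upc := make(chan string)
	done := make(chan bool)
	errc := make(chan error)

	// start the stages; each blocks on its input channel until fed
	go download(dl, proc, errc)
	go process(proc, upc, errc)
	go upload(upc, done, errc)

	// feed the first stage (preprocBook/ocrBook feed it from
	// ListToPreprocess/ListToOCR) and close it so the stages can drain
	for _, key := range []string{"book/0001.png", "book/0002.png"} {
		dl <- key
	}
	close(dl)

	// first error aborts the run, otherwise wait for the uploader,
	// mirroring the select in preprocBook and ocrBook
	select {
	case err := <-errc:
		log.Fatalln("pipeline failed:", err)
	case <-done:
		fmt.Println("pipeline finished")
	}
}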
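The commit does not show how a book first enters the system, but the code above implies the contract: page images are stored under bookname/ in the rescribeinprogress bucket, and the message body on the rescribepreprocess queue is the bare book name. A hedged sketch of a feeder along those lines follows; the standalone program and the example book name are assumptions for illustration, not part of this commit, though the region, queue and bucket names are taken from the code above.

package main

import (
	"log"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/sqs"
)

func main() {
	bookname := "examplebook" // hypothetical; pages must already be under examplebook/ in rescribeinprogress

	sess, err := session.NewSession(&aws.Config{
		Region: aws.String("eu-west-2"),
	})
	if err != nil {
		log.Fatalln("Failed to set up aws session:", err)
	}
	sqssvc := sqs.New(sess)

	q, err := sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
		QueueName: aws.String("rescribepreprocess"),
	})
	if err != nil {
		log.Fatalln("Error getting preprocess queue URL:", err)
	}

	// the message body is just the book name; bookpipeline looks for the
	// page images under bookname/ in the rescribeinprogress bucket
	_, err = sqssvc.SendMessage(&sqs.SendMessageInput{
		MessageBody: aws.String(bookname),
		QueueUrl:    q.QueueUrl,
	})
	if err != nil {
		log.Fatalln("Error adding book to preprocess queue:", err)
	}
}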
