-rw-r--r--  pipelinepreprocess/main.go | 352
1 file changed, 250 insertions(+), 102 deletions(-)
diff --git a/pipelinepreprocess/main.go b/pipelinepreprocess/main.go
index 6c58d98..c075fa9 100644
--- a/pipelinepreprocess/main.go
+++ b/pipelinepreprocess/main.go
@@ -6,6 +6,8 @@ package main
 // TODO: check if images are prebinarised and if so skip multiple binarisation
 
 import (
+	"errors"
+	"fmt"
 	"log"
 	"os"
 	"path/filepath"
@@ -28,7 +30,8 @@ type NullWriter bool
 func (w NullWriter) Write(p []byte) (n int, err error) {
 	return len(p), nil
 }
-var verboselog *log.Logger
+
+var alreadydone *regexp.Regexp
 
 const HeartbeatTime = 60
 const PauseBetweenChecks = 60 * time.Second
@@ -44,20 +47,225 @@ const PreprocPattern = `_bin[0-9].[0-9].png`
 //
 // TODO: consider having the download etc functions return a channel like a generator, like in rob pike's talk
 
-func download(dl chan string, pre chan string, downloader *s3manager.Downloader, dir string) {
-	for key := range dl {
-		verboselog.Println("Downloading", key)
-		fn := filepath.Join(dir, filepath.Base(key))
-		f, err := os.Create(fn)
+type Clouder interface {
+	Init() error
+	ListObjects(bucket string, prefix string, names chan string) error
+	Download(bucket string, key string, fn string) error
+	Upload(bucket string, key string, path string) error
+	CheckQueue(url string) (qmsg, error)
+	AddToQueue(url string, msg string) error
+	DelFromQueue(url string, handle string) error
+	QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error
+}
+
+type Pipeliner interface {
+	Clouder
+	ListInProgress(bookname string, names chan string) error
+	DownloadFromInProgress(key string, fn string) error
+	UploadToInProgress(key string, path string) error
+	CheckPreQueue() (qmsg, error)
+	AddToOCRQueue(msg string) error
+	DelFromPreQueue(handle string) error
+	PreQueueHeartbeat(t *time.Ticker, msgHandle string) error
+}
+
+type qmsg struct {
+	Handle, Body string
+}
+
+type awsConn struct {
+	// these need to be set before running Init()
+	region string
+	logger *log.Logger
+
+	// these are used internally
+	sess *session.Session
+        s3svc *s3.S3
+        sqssvc *sqs.SQS
+        downloader *s3manager.Downloader
+	uploader *s3manager.Uploader
+	prequrl, ocrqurl string
+}
+
+func (a awsConn) Init() error {
+	if a.region == "" {
+		return errors.New("No region set")
+	}
+	if a.logger == nil {
+		return errors.New("No logger set")
+	}
+
+	var err error
+	a.sess, err = session.NewSession(&aws.Config{
+		Region: aws.String(a.region),
+	})
+	if err != nil {
+		return errors.New(fmt.Sprintf("Failed to set up aws session: %s", err))
+	}
+	a.s3svc = s3.New(a.sess)
+	a.sqssvc = sqs.New(a.sess)
+	a.downloader = s3manager.NewDownloader(a.sess)
+	a.uploader = s3manager.NewUploader(a.sess)
+
+        a.logger.Println("Getting preprocess queue URL")
+        result, err := a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+                QueueName: aws.String("rescribepreprocess"),
+        })
+        if err != nil {
+                return errors.New(fmt.Sprintf("Error getting preprocess queue URL: %s", err))
+        }
+        a.prequrl = *result.QueueUrl
+        a.logger.Println("preprocess queue URL", a.prequrl)
+
+        a.logger.Println("Getting OCR queue URL")
+        result, err = a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
+                QueueName: aws.String("rescribeocr"),
+        })
+        if err != nil {
+                return errors.New(fmt.Sprintf("Error getting OCR queue URL: %s", err))
+        }
+        a.ocrqurl = *result.QueueUrl
+	return nil
+}
+
+func (a awsConn) CheckQueue(url string) (qmsg, error) {
+	msgResult, err := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
+		MaxNumberOfMessages: aws.Int64(1),
+		VisibilityTimeout: aws.Int64(HeartbeatTime * 2),
+		WaitTimeSeconds: aws.Int64(20),
+		QueueUrl: &url,
+	})
+	if err != nil {
+		return qmsg{}, err
+	}
+
+	if len(msgResult.Messages) > 0 {
+		msg := qmsg{ Handle: *msgResult.Messages[0].ReceiptHandle, Body: *msgResult.Messages[0].Body }
+		a.logger.Println("Message received:", msg.Body)
+		return msg, nil
+	} else {
+		return qmsg{}, nil
+	}
+}
+
+func (a awsConn) CheckPreQueue() (qmsg, error) {
+	a.logger.Println("Checking preprocessing queue for new messages:", a.prequrl)
+	return a.CheckQueue(a.prequrl)
+}
+
+func (a awsConn) QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error {
+	for _ = range t.C {
+		duration := int64(HeartbeatTime * 2)
+		_, err := a.sqssvc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{
+			ReceiptHandle: &msgHandle,
+			QueueUrl: &qurl,
+			VisibilityTimeout: &duration,
+		})
 		if err != nil {
-			log.Fatalln("Failed to create file", fn, err)
+			return errors.New(fmt.Sprintf("Heartbeat error updating queue duration: %s", err))
 		}
-		defer f.Close()
+	}
+	return nil
+}
+
+func (a awsConn) PreQueueHeartbeat(t *time.Ticker, msgHandle string) error {
+	a.logger.Println("Starting preprocess queue heartbeat for", msgHandle)
+	return a.QueueHeartbeat(t, msgHandle, a.prequrl)
+}
+
+func (a awsConn) ListObjects(bucket string, prefix string, names chan string) error {
+	err := a.s3svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{
+		Bucket: aws.String(bucket),
+		Prefix: aws.String(prefix),
+	}, func(page *s3.ListObjectsV2Output, last bool) bool {
+		for _, r := range page.Contents {
+			if alreadydone.MatchString(*r.Key) {
+				a.logger.Println("Skipping item that looks like it has already been processed", *r.Key)
+				continue
+			}
+			names <- *r.Key
+		}
+		return true
+	})
+	close(names)
+	return err
+}
+
+func (a awsConn) ListInProgress(bookname string, names chan string) error {
+	return a.ListObjects("rescribeinprogress", bookname, names)
+}
+
+func (a awsConn) AddToQueue(url string, msg string) error {
+	_, err := a.sqssvc.SendMessage(&sqs.SendMessageInput{
+		MessageBody: &msg,
+		QueueUrl: &url,
+	})
+	return err
+}
+
+func (a awsConn) AddToOCRQueue(msg string) error {
+	return a.AddToQueue(a.ocrqurl, msg)
+}
 
-		_, err = downloader.Download(f,
-			&s3.GetObjectInput{
-				Bucket: aws.String("rescribeinprogress"),
-				Key: &key })
+func (a awsConn) DelFromQueue(url string, handle string) error {
+	_, err := a.sqssvc.DeleteMessage(&sqs.DeleteMessageInput{
+		QueueUrl: &url,
+		ReceiptHandle: &handle,
+	})
+	return err
+}
+
+func (a awsConn) DelFromPreQueue(handle string) error {
+	return a.DelFromQueue(a.prequrl, handle)
+}
+
+func (a awsConn) Download(bucket string, key string, path string) error {
+	f, err := os.Create(path)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	_, err = a.downloader.Download(f,
+		&s3.GetObjectInput{
+			Bucket: aws.String(bucket),
+			Key: &key,
+	})
+	return err
+}
+
+func (a awsConn) DownloadFromInProgress(key string, path string) error {
+	a.logger.Println("Downloading", key)
+	return a.Download("rescribeinprogress", key, path)
+}
+
+func (a awsConn) Upload(bucket string, key string, path string) error {
+	file, err := os.Open(path)
+	if err != nil {
+		log.Fatalln("Failed to open file", path, err)
+	}
+	defer file.Close()
+
+	_, err = a.uploader.Upload(&s3manager.UploadInput{
+		Bucket: aws.String(bucket),
+		Key:    aws.String(key),
+		Body:   file,
+	})
+	return err
+}
+
+func (a awsConn) UploadToInProgress(key string, path string) error {
+	a.logger.Println("Uploading", path)
+	return a.Upload("rescribeinprogress", key, path)
+}
+
+
+
+
+func download(dl chan string, pre chan string, conn Pipeliner, dir string) {
+	for key := range dl {
+		fn := filepath.Join(dir, filepath.Base(key))
+		err := conn.DownloadFromInProgress(key, fn)
 		if err != nil {
 			log.Fatalln("Failed to download", key, err)
 		}
@@ -66,9 +274,9 @@ func download(dl chan string, pre chan string, downloader *s3manager.Downloader,
 	close(pre)
 }
 
-func preprocess(pre chan string, up chan string) {
+func preprocess(pre chan string, up chan string, logger *log.Logger) {
 	for path := range pre {
-		verboselog.Println("Preprocessing", path)
+		logger.Println("Preprocessing", path)
 		done, err := preproc.PreProcMulti(path, []float64{0.1, 0.2, 0.4, 0.5}, "binary", 0, true, 5, 30)
 		if err != nil {
 			log.Fatalln("Error preprocessing", path, err)
@@ -80,21 +288,11 @@ func preprocess(pre chan string, up chan string) {
 	close(up)
 }
 
-func up(c chan string, done chan bool, uploader *s3manager.Uploader, bookname string) {
+func up(c chan string, done chan bool, conn Pipeliner, bookname string) {
 	for path := range c {
-		verboselog.Println("Uploading", path)
 		name := filepath.Base(path)
-		file, err := os.Open(path)
-		if err != nil {
-			log.Fatalln("Failed to open file", path, err)
-		}
-		defer file.Close()
-
-		_, err = uploader.Upload(&s3manager.UploadInput{
-			Bucket: aws.String("rescribeinprogress"),
-			Key:    aws.String(filepath.Join(bookname, name)),
-			Body:   file,
-		})
+		key := filepath.Join(bookname, name)
+		err := conn.UploadToInProgress(key, path)
 		if err != nil {
 			log.Fatalln("Failed to upload", path, err)
 		}
@@ -118,6 +316,7 @@ func heartbeat(h *time.Ticker, msgHandle string, qurl string, sqssvc *sqs.SQS) {
 }
 
 func main() {
+	var verboselog *log.Logger
 	if len(os.Args) > 1 {
 		if os.Args[1] == "-v" {
 			verboselog = log.New(os.Stdout, "", log.LstdFlags)
@@ -129,65 +328,31 @@ func main() {
 		verboselog = log.New(n, "", log.LstdFlags)
 	}
 
-	alreadydone := regexp.MustCompile(PreprocPattern)
+	alreadydone = regexp.MustCompile(PreprocPattern)
 
-	verboselog.Println("Setting up AWS session")
-	sess, err := session.NewSession(&aws.Config{
-		Region: aws.String("eu-west-2"),
-	})
-	if err != nil {
-		log.Fatalln("Error: failed to set up aws session:", err)
-	}
-	s3svc := s3.New(sess)
-	sqssvc := sqs.New(sess)
-	downloader := s3manager.NewDownloader(sess)
-	uploader := s3manager.NewUploader(sess)
-
-	preqname := "rescribepreprocess"
-	verboselog.Println("Getting Queue URL for", preqname)
-	result, err := sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
-		QueueName: aws.String(preqname),
-	})
-	if err != nil {
-		log.Fatalln("Error getting queue URL for", preqname, ":", err)
-	}
-	prequrl := *result.QueueUrl
+	var conn Pipeliner
+	conn = awsConn{ region: "eu-west-2", logger: verboselog }
 
-	ocrqname := "rescribeocr"
-	verboselog.Println("Getting Queue URL for", ocrqname)
-	result, err = sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
-		QueueName: aws.String(ocrqname),
-	})
+	verboselog.Println("Setting up AWS session")
+	err := conn.Init()
 	if err != nil {
-		log.Fatalln("Error getting queue URL for", ocrqname, ":", err)
+		log.Fatalln("Error setting up cloud connection:", err)
 	}
-	ocrqurl := *result.QueueUrl
 
 	for {
-		verboselog.Println("Checking preprocessing queue for new messages")
-		msgResult, err := sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
-			MaxNumberOfMessages: aws.Int64(1),
-			VisibilityTimeout: aws.Int64(HeartbeatTime * 2),
-			WaitTimeSeconds: aws.Int64(20),
-			QueueUrl: &prequrl,
-		})
+		msg, err := conn.CheckPreQueue()
 		if err != nil {
-			log.Fatalln("Error checking queue", preqname, ":", err)
+			log.Fatalln("Error checking preprocess queue", err)
 		}
-
-		var bookname string
-		if len(msgResult.Messages) > 0 {
-			bookname = *msgResult.Messages[0].Body
-			verboselog.Println("Message received:", bookname)
-		} else {
+		if msg.Handle == "" {
 			verboselog.Println("No message received, sleeping")
 			time.Sleep(PauseBetweenChecks)
			continue
 		}
+		bookname := msg.Body
 
-		verboselog.Println("Starting heartbeat every", HeartbeatTime, "seconds")
 		t := time.NewTicker(HeartbeatTime * time.Second)
-		go heartbeat(t, *msgResult.Messages[0].ReceiptHandle, prequrl, sqssvc)
+		go conn.PreQueueHeartbeat(t, msg.Handle)
 
 		d := filepath.Join(os.TempDir(), bookname)
 
@@ -202,49 +367,32 @@ func main() {
 		done := make(chan bool) // this is just to communicate when up has finished, so the queues can be updated
 
 		// these functions will do their jobs when their channels have data
-		go download(dl, pre, downloader, d)
-		go preprocess(pre, upc)
-		go up(upc, done, uploader, bookname)
+		go download(dl, pre, conn, d)
+		go preprocess(pre, upc, verboselog)
+		go up(upc, done, conn, bookname)
 
 		verboselog.Println("Getting list of objects to download")
-		err = s3svc.ListObjectsV2Pages(&s3.ListObjectsV2Input{
-			Bucket: aws.String("rescribeinprogress"),
-			Prefix: aws.String(bookname),
-		}, func(page *s3.ListObjectsV2Output, last bool) bool {
-			for _, r := range page.Contents {
-				if alreadydone.MatchString(*r.Key) {
-					verboselog.Println("Skipping item that looks like it has already been processed", *r.Key)
-					continue
-				}
-				dl <- *r.Key
-			}
-			return true
-		})
-		close(dl)
-
+		err = conn.ListInProgress(bookname, dl)
+		if err != nil {
+			log.Fatalln("Failed to get list of files for book", bookname, err)
+		}
 
 		// wait for the done channel to be posted to
 		<-done
 
-		verboselog.Println("Sending", bookname, "to queue", ocrqurl)
-		_, err = sqssvc.SendMessage(&sqs.SendMessageInput{
-			MessageBody: aws.String(bookname),
-			QueueUrl: &ocrqurl,
-		})
+		verboselog.Println("Sending", bookname, "to OCR queue")
+		err = conn.AddToOCRQueue(bookname)
 		if err != nil {
-			log.Fatalln("Error sending message to queue", ocrqname, ":", err)
+			log.Fatalln("Error adding to ocr queue", bookname, err)
 		}
 
 		t.Stop()
 
-		verboselog.Println("Deleting original message from queue", prequrl)
-		_, err = sqssvc.DeleteMessage(&sqs.DeleteMessageInput{
-			QueueUrl: &prequrl,
-			ReceiptHandle: msgResult.Messages[0].ReceiptHandle,
-		})
+		verboselog.Println("Deleting original message from preprocessing queue")
+		err = conn.DelFromPreQueue(msg.Handle)
 		if err != nil {
-			log.Fatalln("Error deleting message from queue", preqname, ":", err)
+			log.Fatalln("Error deleting message from preprocessing queue", err)
 		}
 
 		err = os.RemoveAll(d)
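Note on the TODO kept in the patch ("consider having the download etc functions return a channel like a generator, like in rob pike's talk"): a rough sketch of that shape is below. It is not part of the commit; downloadGen is a hypothetical name, and the snippet assumes it sits in the same package as the code above, which already imports "log" and "path/filepath" and defines Pipeliner.

// downloadGen is a generator-style variant of download: it creates and
// returns its output channel rather than being handed one.
func downloadGen(keys <-chan string, conn Pipeliner, dir string) <-chan string {
	out := make(chan string)
	go func() {
		defer close(out) // tell the next stage there is nothing more to come
		for key := range keys {
			fn := filepath.Join(dir, filepath.Base(key))
			if err := conn.DownloadFromInProgress(key, fn); err != nil {
				log.Fatalln("Failed to download", key, err)
			}
			out <- fn
		}
	}()
	return out
}

main() could then wire the stages as pre := downloadGen(dl, conn, d) instead of creating the pre channel itself, and the same shape would work for preprocess and up.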
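One payoff of moving the AWS calls behind the Clouder/Pipeliner interfaces is that the pipeline stages can be exercised without S3 or SQS. The sketch below is an illustration only, not part of the commit: the stubConn type, the test name and the placeholder file contents are invented, and it assumes the qmsg type, Pipeliner interface and download function from the patch, compiled in the same package (e.g. as main_test.go). It relies only on behaviour visible in the patch: download calls DownloadFromInProgress for each key and closes pre when dl is drained.

package main

import (
	"io/ioutil"
	"os"
	"path/filepath"
	"testing"
	"time"
)

// stubConn satisfies Pipeliner without touching the network. Only
// DownloadFromInProgress does anything useful; the rest are no-ops
// that exist to fill out the interface.
type stubConn struct{}

func (s stubConn) Init() error                                          { return nil }
func (s stubConn) ListObjects(bucket string, prefix string, names chan string) error {
	close(names) // mirror the real implementation, which closes the channel when done
	return nil
}
func (s stubConn) Download(bucket string, key string, fn string) error  { return nil }
func (s stubConn) Upload(bucket string, key string, path string) error  { return nil }
func (s stubConn) CheckQueue(url string) (qmsg, error)                  { return qmsg{}, nil }
func (s stubConn) AddToQueue(url string, msg string) error              { return nil }
func (s stubConn) DelFromQueue(url string, handle string) error         { return nil }
func (s stubConn) QueueHeartbeat(t *time.Ticker, msgHandle string, qurl string) error { return nil }
func (s stubConn) ListInProgress(bookname string, names chan string) error {
	close(names)
	return nil
}
func (s stubConn) DownloadFromInProgress(key string, fn string) error {
	// "download" by writing a placeholder file to the requested path
	return ioutil.WriteFile(fn, []byte("fake page image for "+key), 0644)
}
func (s stubConn) UploadToInProgress(key string, path string) error     { return nil }
func (s stubConn) CheckPreQueue() (qmsg, error)                         { return qmsg{}, nil }
func (s stubConn) AddToOCRQueue(msg string) error                       { return nil }
func (s stubConn) DelFromPreQueue(handle string) error                  { return nil }
func (s stubConn) PreQueueHeartbeat(t *time.Ticker, msgHandle string) error { return nil }

func TestDownloadStage(t *testing.T) {
	dir, err := ioutil.TempDir("", "preproctest")
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll(dir)

	dl := make(chan string)
	pre := make(chan string)
	go download(dl, pre, stubConn{}, dir)

	dl <- "testbook/0001.png"
	close(dl)

	// download closes pre once dl is drained, so this loop terminates
	for range pre {
	}

	if _, err := os.Stat(filepath.Join(dir, "0001.png")); err != nil {
		t.Fatalf("expected downloaded file to exist: %v", err)
	}
}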
