// Copyright 2019 Nick White.
// Use of this source code is governed by the GPLv3
// license that can be found in the LICENSE file.

package bookpipeline

import (
	"errors"
	"fmt"
	"log"
	"os"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/ec2"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/aws/aws-sdk-go/service/sqs"
)

// defaultAwsRegion is the region used by MinimalInit when AwsConn.Region
// has been left empty.
const defaultAwsRegion = "eu-west-2"

// Qmsg holds the parts of a queue message that the pipeline uses.
type Qmsg struct {
	// Id is the message identifier.
	Id string
	// Handle is the receipt handle, needed to delete the message or
	// change its visibility.
	Handle string
	// Body is the message content.
	Body string
}

// InstanceDetails describes a single EC2 instance. Every field is
// stored as a string.
type InstanceDetails struct {
	// Id is the instance ID.
	Id string
	// Name is the value of the instance's "Name" tag, if it has one.
	Name string
	// Ip is the public IP address, if the instance has one.
	Ip string
	// Spot is the spot instance request ID, if the instance has one.
	Spot string
	// Type is the instance type.
	Type string
	// State is the name of the instance's current state.
	State string
	// LaunchTime is the launch time formatted as a string.
	LaunchTime string
}

// ObjMeta pairs the name of a stored object with a date; the listing
// methods fill Date with the object's last modified time.
type ObjMeta struct {
	Name string    // object key
	Date time.Time // last modified time
}

// AwsConn contains the necessary things to interact with various AWS
// services in ways useful for the bookpipeline package. It is
// designed to be generic enough to swap in other backends easily.
//
// The exported fields are configuration; the unexported fields are
// populated by MinimalInit, Init and TestInit.
type AwsConn struct {
	// these should be set before running Init(), or left to defaults
	Region string      // AWS region; MinimalInit falls back to defaultAwsRegion
	Logger *log.Logger // MinimalInit falls back to a logger writing to os.Stdout

	// AWS session and the service clients created from it by MinimalInit
	sess         *session.Session
	ec2svc       *ec2.EC2
	s3svc        *s3.S3
	sqssvc       *sqs.SQS
	downloader   *s3manager.Downloader
	uploader     *s3manager.Uploader
	// SQS queue URLs; all but testqurl are looked up by Init
	wipequrl     string // queueWipeOnly
	prequrl      string // queuePreProc
	prenwqurl    string // queuePreNoWipe
	ocrpgqurl    string // queueOcrPage
	analysequrl  string // queueAnalyse
	testqurl     string // queueTest; looked up by TestInit
	// identifier of the work-in-progress storage, set by MinimalInit
	wipstorageid string
}

// MinimalInit does the bare minimum to initialise aws services: it
// applies defaults for Region and Logger, opens an AWS session and
// creates the EC2, S3 and SQS clients along with the S3 downloader
// and uploader. It does not look up any queue URLs.
func (a *AwsConn) MinimalInit() error {
	if a.Region == "" {
		a.Region = defaultAwsRegion
	}
	if a.Logger == nil {
		a.Logger = log.New(os.Stdout, "", 0)
	}

	cfg := &aws.Config{Region: aws.String(a.Region)}
	sess, err := session.NewSession(cfg)
	a.sess = sess
	if err != nil {
		return fmt.Errorf("Failed to set up aws session: %s", err)
	}

	// Every client shares the one session.
	a.ec2svc = ec2.New(sess)
	a.s3svc = s3.New(sess)
	a.sqssvc = sqs.New(sess)
	a.downloader = s3manager.NewDownloader(sess)
	a.uploader = s3manager.NewUploader(sess)

	a.wipstorageid = storageWip

	return nil
}

// Init initialises aws services, also finding the urls needed to
// address SQS queues directly. It runs MinimalInit first and then
// resolves each pipeline queue in turn, stopping at the first lookup
// that fails.
func (a *AwsConn) Init() error {
	if err := a.MinimalInit(); err != nil {
		return err
	}

	// Each entry pairs the human readable label used in log and error
	// messages with the queue name and the field that receives its URL.
	lookups := []struct {
		label string
		name  string
		dest  *string
	}{
		{"preprocess", queuePreProc, &a.prequrl},
		{"preprocess no wipe", queuePreNoWipe, &a.prenwqurl},
		{"wipeonly", queueWipeOnly, &a.wipequrl},
		{"OCR Page", queueOcrPage, &a.ocrpgqurl},
		{"analyse", queueAnalyse, &a.analysequrl},
	}

	for _, q := range lookups {
		a.Logger.Println("Getting " + q.label + " queue URL")
		out, err := a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
			QueueName: aws.String(q.name),
		})
		if err != nil {
			return fmt.Errorf("Error getting %s queue URL: %s", q.label, err)
		}
		*q.dest = *out.QueueUrl
	}

	return nil
}

// TestInit initialises extra aws services needed for running tests,
// by looking up the URL of the test queue. It uses the logger and SQS
// client that MinimalInit sets up, so MinimalInit or Init must have
// been called first.
func (a *AwsConn) TestInit() error {
	a.Logger.Println("Getting test queue URL")
	result, err := a.sqssvc.GetQueueUrl(&sqs.GetQueueUrlInput{
		QueueName: aws.String(queueTest),
	})
	if err != nil {
		// No trailing newline: error strings are printed by the caller,
		// and the other queue lookup errors in this file have none.
		return fmt.Errorf("Error getting test queue URL: %s", err)
	}
	a.testqurl = *result.QueueUrl
	return nil
}

// CheckQueue asks the queue at url for a single message, waiting up
// to 20 seconds for one to arrive. timeout is passed on as the
// visibility timeout for the received message. If no message was
// received an empty Qmsg and a nil error are returned.
func (a *AwsConn) CheckQueue(url string, timeout int64) (Qmsg, error) {
	resp, err := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
		QueueUrl:            aws.String(url),
		MaxNumberOfMessages: aws.Int64(1),
		VisibilityTimeout:   aws.Int64(timeout),
		WaitTimeSeconds:     aws.Int64(20),
	})
	if err != nil {
		return Qmsg{}, err
	}
	if len(resp.Messages) == 0 {
		return Qmsg{}, nil
	}

	first := resp.Messages[0]
	msg := Qmsg{
		Id:     *first.MessageId,
		Handle: *first.ReceiptHandle,
		Body:   *first.Body,
	}
	a.Logger.Println("Message received:", msg.Body)
	return msg, nil
}

// LogAndPurgeQueue prints the body of each message it receives from
// the queue to the log and then deletes that message. It keeps
// requesting batches of up to 10 messages until a request returns
// none, and stops early on the first error.
func (a *AwsConn) LogAndPurgeQueue(url string) error {
	for {
		batch, err := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
			QueueUrl:            aws.String(url),
			MaxNumberOfMessages: aws.Int64(10),
			VisibilityTimeout:   aws.Int64(300),
		})
		if err != nil {
			return err
		}
		if len(batch.Messages) == 0 {
			return nil
		}

		for _, m := range batch.Messages {
			a.Logger.Println(*m.Body)
			if _, err = a.sqssvc.DeleteMessage(&sqs.DeleteMessageInput{
				QueueUrl:      aws.String(url),
				ReceiptHandle: m.ReceiptHandle,
			}); err != nil {
				return err
			}
		}
	}
}

// LogQueue prints the body of each message it receives from a queue
// to the log. Messages are requested in batches of up to 10 with a
// visibility timeout of 300, and are not deleted. It returns once a
// request yields no messages.
func (a *AwsConn) LogQueue(url string) error {
	for {
		batch, err := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
			QueueUrl:            aws.String(url),
			MaxNumberOfMessages: aws.Int64(10),
			VisibilityTimeout:   aws.Int64(300),
		})
		if err != nil {
			return err
		}
		if len(batch.Messages) == 0 {
			return nil
		}

		for _, m := range batch.Messages {
			a.Logger.Println(*m.Body)
		}
	}
}

// RemovePrefixesFromQueue removes any messages in a queue whose
// body starts with the specified prefix. Messages are requested in
// batches of up to 10 until a request yields none; received messages
// that do not match are skipped rather than deleted.
func (a *AwsConn) RemovePrefixesFromQueue(url string, prefix string) error {
	for {
		batch, err := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
			QueueUrl:            aws.String(url),
			MaxNumberOfMessages: aws.Int64(10),
			VisibilityTimeout:   aws.Int64(300),
		})
		if err != nil {
			return err
		}
		if len(batch.Messages) == 0 {
			return nil
		}

		for _, m := range batch.Messages {
			if strings.HasPrefix(*m.Body, prefix) {
				a.Logger.Printf("Removing %s from queue\n", *m.Body)
				if _, err = a.sqssvc.DeleteMessage(&sqs.DeleteMessageInput{
					QueueUrl:      aws.String(url),
					ReceiptHandle: m.ReceiptHandle,
				}); err != nil {
					return err
				}
			}
		}
	}
}

// QueueHeartbeat updates the visibility timeout of a message. This
// ensures that the message remains "in flight", meaning that it
// cannot be seen by other processes, but if this process fails the
// timeout will expire and it will go back to being available for
// any other process to retrieve and process.
//
// SQS only allows messages to be "in flight" for up to 12 hours, so
// if the update is rejected with an InvalidParameterValue error this
// attempts to find the message on the queue again and returns it, as
// the handle will have changed.
//
// When the update simply succeeds, an empty Qmsg and a nil error are
// returned; a non-empty Qmsg is only returned when the message had to
// be received again.
func (a *AwsConn) QueueHeartbeat(msg Qmsg, qurl string, duration int64) (Qmsg, error) {
	_, err := a.sqssvc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{
		ReceiptHandle:     aws.String(msg.Handle),
		QueueUrl:          aws.String(qurl),
		VisibilityTimeout: aws.Int64(duration),
	})
	if err == nil {
		return Qmsg{}, nil
	}

	// Only an InvalidParameterValue error is taken to mean that the
	// maximum in-flight time has been exceeded; anything else is
	// reported as is.
	aerr, isAwsErr := err.(awserr.Error)
	if !isAwsErr || aerr.Code() != "InvalidParameterValue" {
		return Qmsg{}, fmt.Errorf("Heartbeat error updating queue duration: %s", err)
	}

	// First try to set the visibility timeout to zero to immediately
	// make the message available to receive. This is best effort, so
	// the result is deliberately ignored.
	_, _ = a.sqssvc.ChangeMessageVisibility(&sqs.ChangeMessageVisibilityInput{
		ReceiptHandle:     aws.String(msg.Handle),
		QueueUrl:          aws.String(qurl),
		VisibilityTimeout: aws.Int64(0),
	})

	attempts := int(duration) * 5
	for n := 0; n < attempts; n++ {
		received, recvErr := a.sqssvc.ReceiveMessage(&sqs.ReceiveMessageInput{
			MaxNumberOfMessages: aws.Int64(10),
			VisibilityTimeout:   aws.Int64(duration),
			WaitTimeSeconds:     aws.Int64(1),
			QueueUrl:            aws.String(qurl),
		})
		if recvErr != nil {
			return Qmsg{}, fmt.Errorf("Heartbeat error looking for message to update heartbeat: %s", recvErr)
		}

		for _, m := range received.Messages {
			if *m.MessageId != msg.Id {
				continue
			}
			return Qmsg{
				Id:     *m.MessageId,
				Handle: *m.ReceiptHandle,
				Body:   *m.Body,
			}, nil
		}

		// Wait a second before trying again if the ReceiveMessage
		// call succeeded but didn't contain our message (otherwise
		// the WaitTimeSeconds will have applied and we will already
		// have waited a second)
		if len(received.Messages) > 0 {
			time.Sleep(time.Second)
		}
	}

	return Qmsg{}, errors.New("Heartbeat error failed to find message to update heartbeat")
}

// GetQueueDetails gets the approximate number of available and in
// progress messages for a queue. These are returned as strings, in
// that order: available first, then in progress.
//
// An error is returned if the request fails or if the response does
// not contain both attributes.
func (a *AwsConn) GetQueueDetails(url string) (string, string, error) {
	numAvailable := "ApproximateNumberOfMessages"
	numInProgress := "ApproximateNumberOfMessagesNotVisible"
	attrs, err := a.sqssvc.GetQueueAttributes(&sqs.GetQueueAttributesInput{
		AttributeNames: []*string{&numAvailable, &numInProgress},
		QueueUrl:       &url,
	})
	if err != nil {
		return "", "", fmt.Errorf("Failed to get queue attributes: %s", err)
	}

	// A key that is absent from the map yields a nil pointer, so check
	// before dereferencing rather than panicking.
	available := attrs.Attributes[numAvailable]
	inProgress := attrs.Attributes[numInProgress]
	if available == nil || inProgress == nil {
		return "", "", fmt.Errorf("Failed to get queue attributes: response is missing %s or %s", numAvailable, numInProgress)
	}
	return *available, *inProgress, nil
}

// PreQueueId returns the URL of the preprocess queue, as looked up
// by Init. It is empty until Init has succeeded in finding it.
func (a *AwsConn) PreQueueId() string {
	return a.prequrl
}

// PreNoWipeQueueId returns the URL of the preprocess no wipe queue,
// as looked up by Init.
func (a *AwsConn) PreNoWipeQueueId() string {
	return a.prenwqurl
}

// WipeQueueId returns the URL of the wipeonly queue, as looked up
// by Init.
func (a *AwsConn) WipeQueueId() string {
	return a.wipequrl
}

// OCRPageQueueId returns the URL of the OCR Page queue, as looked up
// by Init.
func (a *AwsConn) OCRPageQueueId() string {
	return a.ocrpgqurl
}

// AnalyseQueueId returns the URL of the analyse queue, as looked up
// by Init.
func (a *AwsConn) AnalyseQueueId() string {
	return a.analysequrl
}

// WIPStorageId returns the identifier of the work in progress
// storage, which MinimalInit sets to storageWip.
func (a *AwsConn) WIPStorageId() string {
	return a.wipstorageid
}

// TestQueueId returns the URL of the test queue, as looked up by
// TestInit.
func (a *AwsConn) TestQueueId() string {
	return a.testqurl
}

// ListObjects returns the keys of all objects in bucket whose key
// starts with prefix, walking every page of results. Keys gathered
// before a failure are returned together with the error.
func (a *AwsConn) ListObjects(bucket string, prefix string) ([]string, error) {
	var keys []string

	// collect gathers the keys from one page and asks for the next.
	collect := func(page *s3.ListObjectsV2Output, last bool) bool {
		for _, obj := range page.Contents {
			keys = append(keys, *obj.Key)
		}
		return true
	}

	input := &s3.ListObjectsV2Input{
		Bucket: aws.String(bucket),
		Prefix: aws.String(prefix),
	}
	err := a.s3svc.ListObjectsV2Pages(input, collect)
	return keys, err
}

// ListObjectsWithMeta returns the name and last modified date of all
// objects in bucket whose key starts with prefix, walking every page
// of results. Entries gathered before a failure are returned together
// with the error.
func (a *AwsConn) ListObjectsWithMeta(bucket string, prefix string) ([]ObjMeta, error) {
	var metas []ObjMeta

	// collect gathers the entries from one page and asks for the next.
	collect := func(page *s3.ListObjectsV2Output, last bool) bool {
		for _, obj := range page.Contents {
			metas = append(metas, ObjMeta{
				Name: *obj.Key,
				Date: *obj.LastModified,
			})
		}
		return true
	}

	input := &s3.ListObjectsV2Input{
		Bucket: aws.String(bucket),
		Prefix: aws.String(prefix),
	}
	err := a.s3svc.ListObjectsV2Pages(input, collect)
	return metas, err
}

// ListObjectWithMeta lists the name and last modified date of the
// first object with the specified prefix. Only one key is requested
// and only the first page of results is read. If the request
// succeeds but yields no object, an error is returned.
func (a *AwsConn) ListObjectWithMeta(bucket string, prefix string) (ObjMeta, error) {
	var found ObjMeta

	// takeFirstPage records what is on the first page and then stops
	// the pagination by returning false.
	takeFirstPage := func(page *s3.ListObjectsV2Output, last bool) bool {
		for _, obj := range page.Contents {
			found = ObjMeta{Name: *obj.Key, Date: *obj.LastModified}
		}
		return false
	}

	input := &s3.ListObjectsV2Input{
		Bucket:  aws.String(bucket),
		Prefix:  aws.String(prefix),
		MaxKeys: aws.Int64(1),
	}
	err := a.s3svc.ListObjectsV2Pages(input, takeFirstPage)

	nothingFound := found.Name == "" && found.Date.IsZero()
	if err == nil && nothingFound {
		return found, fmt.Errorf("No object could be found for %s", prefix)
	}
	return found, err
}

// ListObjectPrefixes returns the common prefixes reported for bucket
// when its keys are grouped with "/" as the delimiter, walking every
// page of results.
func (a *AwsConn) ListObjectPrefixes(bucket string) ([]string, error) {
	var found []string

	// collect gathers the prefixes from one page and asks for the next.
	collect := func(page *s3.ListObjectsV2Output, last bool) bool {
		for _, cp := range page.CommonPrefixes {
			found = append(found, *cp.Prefix)
		}
		return true
	}

	input := &s3.ListObjectsV2Input{
		Bucket:    aws.String(bucket),
		Delimiter: aws.String("/"),
	}
	err := a.s3svc.ListObjectsV2Pages(input, collect)
	return found, err
}

// DeleteObjects deletes the objects with the given keys from bucket.
//
// s3.DeleteObjects can only take up to 1000 keys at a time, so the
// keys are sent in consecutive batches of at most that size. No
// request is made for an empty batch, so an empty keys slice returns
// nil without contacting S3.
//
// Only request-level errors are reported: deletion stops at the first
// failing batch and its error is returned. The response body of each
// request is not inspected.
func (a *AwsConn) DeleteObjects(bucket string, keys []string) error {
	const maxBatch = 1000

	for start := 0; start < len(keys); start += maxBatch {
		end := start + maxBatch
		if end > len(keys) {
			end = len(keys)
		}

		objs := make([]*s3.ObjectIdentifier, 0, end-start)
		for _, k := range keys[start:end] {
			objs = append(objs, &s3.ObjectIdentifier{Key: aws.String(k)})
		}

		_, err := a.s3svc.DeleteObjects(&s3.DeleteObjectsInput{
			Bucket: aws.String(bucket),
			Delete: &s3.Delete{
				Objects: objs,
				Quiet:   aws.Bool(true),
			},
		})
		if err != nil {
			return err
		}
	}
	return nil
}

// CreateBucket creates a new S3 bucket. A bucket that already exists
// is not treated as an error; it is only noted in the log.
func (a *AwsConn) CreateBucket(name string) error {
	_, err := a.s3svc.CreateBucket(&s3.CreateBucketInput{
		Bucket: aws.String(name),
	})
	if err == nil {
		return nil
	}

	if aerr, ok := err.(awserr.Error); ok {
		switch aerr.Code() {
		case s3.ErrCodeBucketAlreadyExists, s3.ErrCodeBucketAlreadyOwnedByYou:
			a.Logger.Println("Bucket already exists:", name)
			return nil
		}
	}
	return fmt.Errorf("Error creating bucket %s: %v", name, err)
}

// CreateQueue creates a new SQS queue
// Note the queue attributes are currently hardcoded; it may make sense
// to specify them as arguments in the future.
func (a *AwsConn) CreateQueue(name string) error {
	attributes := map[string]*string{
		"VisibilityTimeout":             aws.String("120"),     // 2 minutes
		"MessageRetentionPeriod":        aws.String("1209600"), // 14 days; max allowed by sqs
		"ReceiveMessageWaitTimeSeconds": aws.String("20"),
	}
	_, err := a.sqssvc.CreateQueue(&sqs.CreateQueueInput{
		QueueName:  aws.String(name),
		Attributes: attributes,
	})
	if err == nil {
		return nil
	}

	// Note the QueueAlreadyExists code is only emitted if an existing queue
	// has different attributes than the one that was being created. SQS just
	// quietly ignores the CreateQueue request if it is identical to an
	// existing queue.
	if aerr, ok := err.(awserr.Error); ok && aerr.Code() == sqs.ErrCodeQueueNameExists {
		return errors.New("Error: Queue already exists but has different attributes:" + name)
	}
	return fmt.Errorf("Error creating queue %s: %v", name, err)
}

// AddToQueue sends msg as the body of a new message to the queue at
// url.
func (a *AwsConn) AddToQueue(url string, msg string) error {
	input := &sqs.SendMessageInput{
		QueueUrl:    aws.String(url),
		MessageBody: aws.String(msg),
	}
	_, err := a.sqssvc.SendMessage(input)
	return err
}

// DelFromQueue deletes the message identified by the receipt handle
// from the queue at url.
func (a *AwsConn) DelFromQueue(url string, handle string) error {
	input := &sqs.DeleteMessageInput{
		QueueUrl:      aws.String(url),
		ReceiptHandle: aws.String(handle),
	}
	_, err := a.sqssvc.DeleteMessage(input)
	return err
}

// Download fetches the object key from bucket and writes it to the
// local file at path, creating the file or truncating an existing
// one.
//
// If the download fails, or the file cannot be closed cleanly
// afterwards, the local file is removed and the error is returned, so
// a nil error means the file was written and closed successfully.
func (a *AwsConn) Download(bucket string, key string, path string) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}

	_, err = a.downloader.Download(f,
		&s3.GetObjectInput{
			Bucket: aws.String(bucket),
			Key:    &key,
		})

	// Close explicitly rather than with defer: a failure to close a
	// file that has just been written must not be ignored. The download
	// error, if any, takes precedence over the close error.
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		// Best effort clean up of the incomplete file; the original
		// error is what matters to the caller.
		_ = os.Remove(path)
	}
	return err
}

// Upload sends the contents of the local file at path to bucket,
// storing it under key.
func (a *AwsConn) Upload(bucket string, key string, path string) error {
	src, err := os.Open(path)
	if err != nil {
		return err
	}
	defer src.Close()

	input := &s3manager.UploadInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
		Body:   src,
	}
	_, err = a.uploader.Upload(input)
	return err
}

// GetLogger returns the Logger that this connection logs to.
func (a *AwsConn) GetLogger() *log.Logger {
	return a.Logger
}

// instanceDetailsFromPage flattens one page of DescribeInstances
// output into a list of InstanceDetails, one per instance across all
// of the page's reservations.
func instanceDetailsFromPage(page *ec2.DescribeInstancesOutput) []InstanceDetails {
	var found []InstanceDetails
	for _, res := range page.Reservations {
		for _, inst := range res.Instances {
			// The name lives in the tags; the last "Name" tag wins.
			var name string
			for _, tag := range inst.Tags {
				if *tag.Key == "Name" {
					name = *tag.Value
				}
			}

			det := InstanceDetails{Name: name}
			// The public IP and spot request ID are only copied when set.
			if inst.PublicIpAddress != nil {
				det.Ip = *inst.PublicIpAddress
			}
			if inst.SpotInstanceRequestId != nil {
				det.Spot = *inst.SpotInstanceRequestId
			}
			det.Type = *inst.InstanceType
			det.Id = *inst.InstanceId
			det.LaunchTime = inst.LaunchTime.String()
			det.State = *inst.State.Name

			found = append(found, det)
		}
	}

	return found
}

// GetInstanceDetails returns the details of every instance reported
// by DescribeInstances, walking all pages of results. Details
// gathered before a failure are returned together with the error.
func (a *AwsConn) GetInstanceDetails() ([]InstanceDetails, error) {
	var all []InstanceDetails

	// collect gathers one page and continues until the last page.
	collect := func(page *ec2.DescribeInstancesOutput, lastPage bool) bool {
		all = append(all, instanceDetailsFromPage(page)...)
		return !lastPage
	}

	err := a.ec2svc.DescribeInstancesPages(&ec2.DescribeInstancesInput{}, collect)
	return all, err
}

// StartInstances requests n one-time spot instances, using the
// package's spotProfile, spotImage, spotType and spotSg settings for
// the launch specification.
func (a *AwsConn) StartInstances(n int) error {
	launchSpec := &ec2.RequestSpotLaunchSpecification{
		IamInstanceProfile: &ec2.IamInstanceProfileSpecification{
			Arn: aws.String(spotProfile),
		},
		ImageId:          aws.String(spotImage),
		InstanceType:     aws.String(spotType),
		SecurityGroupIds: []*string{aws.String(spotSg)},
	}
	request := &ec2.RequestSpotInstancesInput{
		InstanceCount:       aws.Int64(int64(n)),
		LaunchSpecification: launchSpec,
		Type:                aws.String("one-time"),
	}

	_, err := a.ec2svc.RequestSpotInstances(request)
	return err
}

// Log records an item with the Logger. Arguments are handled as
// with fmt.Println.
func (a *AwsConn) Log(v ...interface{}) {
	a.Logger.Println(v...)
}

// MkPipeline sets up necessary buckets and queues for the pipeline.
// Buckets are created first, then queues; it stops at the first
// failure and returns that error.
// TODO: also set up the necessary security group and iam stuff
func (a *AwsConn) MkPipeline() error {
	for _, b := range []string{storageWip} {
		if err := a.CreateBucket(b); err != nil {
			return err
		}
	}

	queueNames := []string{
		queuePreProc,
		queuePreNoWipe,
		queueWipeOnly,
		queueAnalyse,
		queueOcrPage,
		queueTest,
	}
	for _, q := range queueNames {
		if err := a.CreateQueue(q); err != nil {
			return err
		}
	}

	return nil
}