From dc9d2911f67d8c7da4d6b761afc7cf882f21b8a7 Mon Sep 17 00:00:00 2001
From: Nick White <git@njw.name>
Date: Tue, 24 Sep 2019 15:05:02 +0100
Subject: Move ec2 stuff out of lspipeline and into aws.go

---
 bookpipeline/aws.go                 | 41 +++++++++++++++++++
 bookpipeline/cmd/lspipeline/main.go | 78 ++++++++-----------------------------
 2 files changed, 57 insertions(+), 62 deletions(-)

diff --git a/bookpipeline/aws.go b/bookpipeline/aws.go
index f3cdbfa..063bc9f 100644
--- a/bookpipeline/aws.go
+++ b/bookpipeline/aws.go
@@ -10,6 +10,7 @@ import (
 	"github.com/aws/aws-sdk-go/aws"
 	"github.com/aws/aws-sdk-go/aws/awserr"
 	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/ec2"
 	"github.com/aws/aws-sdk-go/service/s3"
 	"github.com/aws/aws-sdk-go/service/s3/s3manager"
 	"github.com/aws/aws-sdk-go/service/sqs"
@@ -22,6 +23,10 @@ type Qmsg struct {
 	Id, Handle, Body string
 }
 
+type InstanceDetails struct {
+	Id, Name, Ip, Spot, Type, State, LaunchTime string
+}
+
 type AwsConn struct {
 	// these need to be set before running Init()
 	Region string
@@ -29,6 +34,7 @@ type AwsConn struct {
 
 	// these are used internally
 	sess                          *session.Session
+	ec2svc                        *ec2.EC2
 	s3svc                         *s3.S3
 	sqssvc                        *sqs.SQS
 	downloader                    *s3manager.Downloader
@@ -37,6 +43,7 @@ type AwsConn struct {
 	wipstorageid                  string
 }
 
+// TODO: split this up, as not everything is needed for different uses
 func (a *AwsConn) Init() error {
 	if a.Region == "" {
 		return errors.New("No Region set")
@@ -52,6 +59,7 @@ func (a *AwsConn) Init() error {
 	if err != nil {
 		return errors.New(fmt.Sprintf("Failed to set up aws session: %s", err))
 	}
+	a.ec2svc = ec2.New(a.sess)
 	a.s3svc = s3.New(a.sess)
 	a.sqssvc = sqs.New(a.sess)
 	a.downloader = s3manager.NewDownloader(a.sess)
@@ -259,3 +267,36 @@ func (a *AwsConn) Upload(bucket string, key string, path string) error {
 func (a *AwsConn) GetLogger() *log.Logger {
 	return a.Logger
 }
+
+// TODO: split pages function so it can be encapsulated by
+//       downstream and to feed a channel
+func (a *AwsConn) GetInstanceDetails() ([]InstanceDetails, error) {
+	var details []InstanceDetails
+	err := a.ec2svc.DescribeInstancesPages(&ec2.DescribeInstancesInput{}, func(page *ec2.DescribeInstancesOutput, lastPage bool) bool {
+		for _, r := range page.Reservations {
+			for _, i := range r.Instances {
+				var d InstanceDetails
+
+				for _, t := range i.Tags {
+					if *t.Key == "Name" {
+						d.Name = *t.Value
+					}
+				}
+				if i.PublicIpAddress != nil {
+					d.Ip = *i.PublicIpAddress
+				}
+				if i.SpotInstanceRequestId != nil {
+					d.Spot = *i.SpotInstanceRequestId
+				}
+				d.Type = *i.InstanceType
+				d.Id = *i.InstanceId
+				d.LaunchTime = i.LaunchTime.String()
+				d.State = *i.State.Name
+
+				details = append(details, d)
+			}
+		}
+		return !lastPage
+	})
+	return details, err
+}
diff --git a/bookpipeline/cmd/lspipeline/main.go b/bookpipeline/cmd/lspipeline/main.go
index d49b933..3cbc893 100644
--- a/bookpipeline/cmd/lspipeline/main.go
+++ b/bookpipeline/cmd/lspipeline/main.go
@@ -7,12 +7,6 @@ import (
 	"os"
 
 	"rescribe.xyz/go.git/bookpipeline"
-
-	// TODO: abstract out the aws stuff into aws.go in due course
-	"github.com/aws/aws-sdk-go/aws"
-	"github.com/aws/aws-sdk-go/aws/session"
-	"github.com/aws/aws-sdk-go/service/ec2"
-	//"github.com/aws/aws-sdk-go/service/s3"
 )
 
 const usage = `Usage: lspipeline [-v]
@@ -32,6 +26,7 @@ type LsPipeliner interface {
 	OCRQueueId() string
 	AnalyseQueueId() string
 	GetQueueDetails(url string) (string, string, error)
+	GetInstanceDetails() ([]bookpipeline.InstanceDetails, error)
 }
 
 // NullWriter is used so non-verbose logging may be discarded
@@ -41,52 +36,19 @@ func (w NullWriter) Write(p []byte) (n int, err error) {
 	return len(p), nil
 }
 
-type instanceDetails struct {
-	id, name, ip, spot, iType, state, launchTime string
-}
-
 type queueDetails struct {
 	name, numAvailable, numInProgress string
 }
 
-func ec2getInstances(svc *ec2.EC2, instances chan instanceDetails) {
-	err := svc.DescribeInstancesPages(&ec2.DescribeInstancesInput{}, parseInstances(instances))
+func getInstances(conn LsPipeliner, detailsc chan bookpipeline.InstanceDetails) {
+	details, err := conn.GetInstanceDetails()
 	if err != nil {
-		close(instances)
-		log.Println("Error with ec2 DescribeInstancePages call:", err)
+		log.Println("Error getting instance details:", err)
 	}
-}
-
-func parseInstances(details chan instanceDetails) (func(*ec2.DescribeInstancesOutput, bool) bool) {
-	return func(page *ec2.DescribeInstancesOutput, lastPage bool) bool {
-		for _, r := range page.Reservations {
-			for _, i := range r.Instances {
-				var d instanceDetails
-
-				for _, t := range i.Tags {
-					if *t.Key == "Name" {
-						d.name = *t.Value
-					}
-				}
-				if i.PublicIpAddress != nil {
-					d.ip = *i.PublicIpAddress
-				}
-				if i.SpotInstanceRequestId != nil {
-					d.spot = *i.SpotInstanceRequestId
-				}
-				d.iType = *i.InstanceType
-				d.id = *i.InstanceId
-				d.launchTime = i.LaunchTime.String()
-				d.state = *i.State.Name
-
-				details <- d
-			}
-		}
-		if lastPage {
-			close(details)
-		}
-		return !lastPage
+	for _, d := range details {
+		detailsc <- d
 	}
+	close(detailsc)
 }
 
 func getQueueDetails(conn LsPipeliner, qdetails chan queueDetails) {
@@ -132,31 +94,23 @@ func main() {
 		log.Fatalln("Failed to set up cloud connection:", err)
 	}
 
-	sess, err := session.NewSession(&aws.Config{
-		Region: aws.String("eu-west-2"),
-	})
-	if err != nil {
-		log.Fatalln("Failed to set up aws session", err)
-	}
-	ec2svc := ec2.New(sess)
-
-	instances := make(chan instanceDetails, 100)
+	instances := make(chan bookpipeline.InstanceDetails, 100)
 	queues := make(chan queueDetails)
 
-	go ec2getInstances(ec2svc, instances)
+	go getInstances(conn, instances)
 	go getQueueDetails(conn, queues)
 
 	fmt.Println("# Instances")
 	for i := range instances {
-		fmt.Printf("ID: %s, Type: %s, LaunchTime: %s, State: %s", i.id, i.iType, i.launchTime, i.state)
-		if i.name != "" {
-			fmt.Printf(", Name: %s", i.name)
+		fmt.Printf("ID: %s, Type: %s, LaunchTime: %s, State: %s", i.Id, i.Type, i.LaunchTime, i.State)
+		if i.Name != "" {
+			fmt.Printf(", Name: %s", i.Name)
 		}
-		if i.ip != "" {
-			fmt.Printf(", IP: %s", i.ip)
+		if i.Ip != "" {
+			fmt.Printf(", IP: %s", i.Ip)
 		}
-		if i.spot != "" {
-			fmt.Printf(", SpotRequest: %s", i.spot)
+		if i.Spot != "" {
+			fmt.Printf(", SpotRequest: %s", i.Spot)
 		}
 		fmt.Printf("\n")
 	}
-- 
cgit v1.2.1-24-ge1ad