aboutsummaryrefslogblamecommitdiffstats
path: root/arbiter/arbiter.go
blob: 9cb3ad069a61c82d1a3d3372b27e289e23a2e7d9 (plain) (tree)
1
2
3
4
5
6
7
8
9




                    
                       


                
                   
                  
                 

                   
              


















                                                                                                                 
                                                                   

 
















                                                             









                                                                                        

























































                                                                        





                                                                              













































                                                                                


 






                                                        


                                             

                            
                                                




                                               



                                                                    


                                            
                                                                    

                                                          
                                                                     


                                                       
                                                                                                                     















                                                                                











                                                                                                                            


                                                                                  
 
package main

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"sync"
	"time"

	"github.com/go-redis/redis/v8"
	"github.com/gorilla/mux"
	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)

// Command-line flags and package-level state. The flag values are only valid
// after flag.Parse has run (main calls it first thing).
var (
	// flagPort is the TCP port the HTTPS server binds to.
	flagPort      = flag.String("port", "8009", "determines the port the server will listen on")
	// flagInterval is the price-polling delay in seconds.
	// NOTE(review): nothing in this file reads it yet — confirm whether it is still needed.
	flagInterval  = flag.Float64("interval", 10, "In seconds, the delay between checking prices")
	// redisDB is the redis database index handed to the client in main.
	redisDB       = flag.Int64("redisdb", 1, "determines the db number")
	// rdb is the shared redis client; main creates it and closes it on exit.
	rdb           *redis.Client
	// redisAddress is the host:port of the redis instance.
	redisAddress  = flag.String("redisaddress", "redis:6379", "determines the address of the redis instance")
	// redisPassword is the redis password; the default is the empty string.
	redisPassword = flag.String("redispassword", "", "determines the password of the redis db")
)

const (
	// SERVER_DEPLOYMENT_TYPE is the NAME of the environment variable that
	// startServer reads to choose a TLS certificate pair; the values it
	// recognises are "deployment" and "test".
	SERVER_DEPLOYMENT_TYPE = "SERVER_DEPLOYMENT_TYPE"
	// coingeckoAPIURLv3 is the base URL that getPriceFromCoinGecko prefixes to
	// its request path.
	coingeckoAPIURLv3      = "https://api.coingecko.com/api/v3"
)

// HttpHandlerFunc is the signature of a request handler that can be stored in
// an HttpHandler and registered on the router by startServer.
type HttpHandlerFunc func(http.ResponseWriter, *http.Request)

// HttpHandler pairs a route with the function that serves it.
type HttpHandler struct {
	// name is the route path passed to the router, e.g. "/arb".
	name     string
	// function is the handler invoked for requests matching name.
	function HttpHandlerFunc
}

// priceChanStruct is the message a price fetcher sends on its price channel.
type priceChanStruct struct {
	// name is the asset identifier the price was requested for.
	name  string
	// price is the quoted price; getPriceFromCoinGecko sends 0 when the lookup failed.
	price float64
}

// errorChanStruct is the message a price fetcher sends on its error channel
// to report whether the lookup succeeded.
type errorChanStruct struct {
	// hasError is true when the lookup failed.
	hasError bool
	// err is the failure cause; getPriceFromCoinGecko sets it to nil on success.
	err      error
}

// addSecureHeaders stamps the response with the hardening headers suggested by
// the OWASP REST Security Cheat Sheet:
// https://cheatsheetseries.owasp.org/cheatsheets/REST_Security_Cheat_Sheet.html
//
// Every header is written with Set, so a value already present under the same
// name is replaced. Call it before the response header is flushed.
func addSecureHeaders(w *http.ResponseWriter) {
	secure := [...][2]string{
		{"Cache-Control", "no-store"},
		{"Content-Security-Policy", "default-src https;"},
		{"Strict-Transport-Security", "max-age=63072000;"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS"},
	}

	hdr := (*w).Header()
	for _, pair := range secure {
		hdr.Set(pair[0], pair[1])
	}
}

// getPriceFromBinance is the placeholder for a Binance price fetcher. Its
// signature mirrors getPriceFromCoinGecko, but the body is empty.
//
// NOTE(review): as written it never calls wg.Done and never sends on
// priceChan or errChan, so a caller that did wg.Add(1) and then waits on wg
// (or receives from either channel) would block forever. Not implemented yet.
func getPriceFromBinance(name, unit string,
	wg *sync.WaitGroup,
	priceChan chan<- priceChanStruct,
	errChan chan<- errorChanStruct) {

}

// getPriceFromKu is the placeholder for a KuCoin price fetcher. Its signature
// mirrors getPriceFromCoinGecko, but the body is empty.
//
// NOTE(review): as written it never calls wg.Done and never sends on
// priceChan or errChan, so a caller that did wg.Add(1) and then waits on wg
// (or receives from either channel) would block forever. Not implemented yet.
// The second parameter is spelled "uni"; presumably "unit" was meant.
func getPriceFromKu(name, uni string,
	wg *sync.WaitGroup,
	priceChan chan<- priceChanStruct,
	errChan chan<- errorChanStruct) {

}

// getPriceFromCoinGecko looks up the current price of the CoinGecko asset id
// name, quoted in the currency unit, through the "/simple/price" endpoint.
//
// It calls wg.Done when it returns. On every path, success or failure, it
// sends exactly one value on priceChan followed by exactly one value on
// errChan, so a caller using 1-slot buffered channels is never blocked by a
// second send. On failure the price is 0 and the errorChanStruct carries the
// cause. The sends block until a receiver or buffer space is available.
func getPriceFromCoinGecko(
	name, unit string,
	wg *sync.WaitGroup,
	priceChan chan<- priceChanStruct,
	errChan chan<- errorChanStruct) {
	defer wg.Done()

	// Bound the outbound call so a stalled upstream cannot pin the caller
	// indefinitely (http.Get uses a client without any timeout).
	const requestTimeout = 10 * time.Second

	// report publishes the single outcome of this call on both channels.
	report := func(price float64, err error) {
		if err != nil {
			log.Error().Err(err).Str("name", name).Str("unit", unit).
				Msg("coingecko price lookup failed")
		}
		priceChan <- priceChanStruct{name: name, price: price}
		errChan <- errorChanStruct{hasError: err != nil, err: err}
	}

	params := "/simple/price?ids=" + url.QueryEscape(name) + "&" +
		"vs_currencies=" + url.QueryEscape(unit)
	path := coingeckoAPIURLv3 + params
	log.Debug().Str("url", path).Msg("querying coingecko")

	client := http.Client{Timeout: requestTimeout}
	resp, err := client.Get(path)
	if err != nil {
		// resp is nil here; returning avoids dereferencing it below.
		report(0, fmt.Errorf("querying coingecko: %w", err))
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		report(0, fmt.Errorf("coingecko answered with status %s", resp.Status))
		return
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		report(0, fmt.Errorf("reading coingecko response: %w", err))
		return
	}

	// The endpoint answers {"<id>": {"<currency>": <price>}}; decoding into a
	// typed map turns any other shape into an error instead of a panic.
	var quotes map[string]map[string]float64
	if err := json.Unmarshal(body, &quotes); err != nil {
		report(0, fmt.Errorf("decoding coingecko response: %w", err))
		return
	}

	price, ok := quotes[name][unit]
	if !ok {
		report(0, fmt.Errorf("coingecko response has no price for %q in %q", name, unit))
		return
	}

	log.Info().Msg(string(body))

	report(price, nil)
}

// arbHandler serves GET requests of the form ?name=<asset id>&unit=<currency>
// and answers with a JSON object holding the price fetched from CoinGecko.
//
// The response object always has the keys "name", "price", "unit", "err" and
// "isSuccessful". Status codes:
//   - 200: lookup succeeded; "err" is empty and "isSuccessful" is true.
//   - 400: "name" or "unit" is missing from the query string.
//   - 405: the request method is not GET (plain-text body).
//   - 502: the upstream lookup failed; "err" holds the reason.
//
// Query parameters other than "name" and "unit" are logged and ignored.
func arbHandler(w http.ResponseWriter, r *http.Request) {
	addSecureHeaders(&w)

	if r.Method != http.MethodGet {
		w.Header().Set("Allow", http.MethodGet)
		http.Error(w, "Method is not supported.", http.StatusMethodNotAllowed)
		// Stop here: the rejection has already been written.
		return
	}
	w.Header().Set("Content-Type", "application/json")

	var name, unit string
	for key, values := range r.URL.Query() {
		switch key {
		case "name":
			name = values[0]
		case "unit":
			unit = values[0]
		default:
			log.Warn().Str("parameter", key).Msg("got unexpected parameter")
		}
	}

	// respond writes the JSON envelope with the given status code.
	respond := func(status int, price float64, lookupErr error) {
		errText := ""
		if lookupErr != nil {
			errText = lookupErr.Error()
		}
		w.WriteHeader(status)
		if err := json.NewEncoder(w).Encode(map[string]interface{}{
			"name":         name,
			"price":        price,
			"unit":         unit,
			"err":          errText,
			"isSuccessful": lookupErr == nil,
		}); err != nil {
			log.Error().Err(err).Msg("could not write the price response")
		}
	}

	if name == "" || unit == "" {
		respond(http.StatusBadRequest, 0, errors.New("query parameters name and unit are required"))
		return
	}

	// The fetcher is called synchronously, so each channel needs one buffer
	// slot for the fetcher to deposit its result before we read it.
	priceChan := make(chan priceChanStruct, 1)
	errChan := make(chan errorChanStruct, 1)
	var wg sync.WaitGroup
	wg.Add(1)
	getPriceFromCoinGecko(name, unit, &wg, priceChan, errChan)
	wg.Wait()

	// Non-blocking receives: a fetcher that published nothing is reported as
	// a failed lookup rather than hanging or terminating the process.
	var lookupErr error
	select {
	case status := <-errChan:
		if status.hasError {
			lookupErr = status.err
			if lookupErr == nil {
				lookupErr = errors.New("price lookup failed")
			}
		}
	default:
		lookupErr = errors.New("price lookup reported no status")
	}

	var price priceChanStruct
	select {
	case price = <-priceChan:
	default:
		if lookupErr == nil {
			lookupErr = errors.New("price lookup reported no price")
		}
	}

	if lookupErr != nil {
		log.Error().Err(lookupErr).Str("name", name).Str("unit", unit).
			Msg("could not serve price request")
		respond(http.StatusBadGateway, 0, lookupErr)
		return
	}

	respond(http.StatusOK, price.price, nil)
}

// setupLogging configures the global zerolog settings for the process. It
// currently only switches the timestamp field format to
// zerolog.TimeFormatUnix (Unix time instead of a formatted date string).
func setupLogging() {
	zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
}

// startServer runs the HTTPS server and blocks until the process receives
// os.Interrupt, then shuts the server down gracefully.
//
// Parameters:
//   - gracefulWait: upper bound on how long in-flight requests may take to
//     finish once shutdown has begun.
//   - handlers: routes to register; each name is used as the route path.
//   - serverDeploymentType: NAME of the environment variable that selects the
//     certificate pair. Value "deployment" uses /certs/fullchain1.pem and
//     /certs/privkey1.pem; "test" uses /certs/server.cert and
//     /certs/server.key. Any other value is fatal.
//   - port: TCP port to listen on, on all interfaces.
//
// The server requires TLS 1.3 or newer. The process exits with a fatal log
// entry if the deployment type is unknown or the listener fails for any reason
// other than the graceful shutdown itself.
func startServer(gracefulWait time.Duration,
	handlers []HttpHandler,
	serverDeploymentType string, port string) {
	// Resolve the certificate pair up front so that a misconfiguration stops
	// the process before anything is started.
	var certPath, keyPath string
	switch kind := os.Getenv(serverDeploymentType); kind {
	case "deployment":
		certPath = "/certs/fullchain1.pem"
		keyPath = "/certs/privkey1.pem"
	case "test":
		certPath = "/certs/server.cert"
		keyPath = "/certs/server.key"
	default:
		// zerolog only emits (and, for Fatal, exits) once Msg or Send is called.
		log.Fatal().
			Err(fmt.Errorf("unknown deployment kind: %q", kind)).
			Str("variable", serverDeploymentType).
			Msg("cannot select a TLS certificate")
	}

	r := mux.NewRouter()
	for _, h := range handlers {
		r.HandleFunc(h.name, h.function)
	}

	cfg := &tls.Config{
		MinVersion: tls.VersionTLS13,
	}
	srv := &http.Server{
		Addr:         "0.0.0.0:" + port,
		WriteTimeout: time.Second * 15,
		ReadTimeout:  time.Second * 15,
		Handler:      r,
		TLSConfig:    cfg,
	}

	go func() {
		// ListenAndServeTLS always returns a non-nil error; ErrServerClosed is
		// the normal result of the Shutdown call below and must not be fatal.
		err := srv.ListenAndServeTLS(certPath, keyPath)
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			log.Fatal().Err(err).Msg("server stopped unexpectedly")
		}
	}()

	// Buffered so the signal is not dropped if it arrives before the receive.
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt)
	<-c

	ctx, cancel := context.WithTimeout(context.Background(), gracefulWait)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		log.Error().Err(err).Msg("graceful shutdown did not complete")
		return
	}
	log.Info().Msg("gracefully shut down the server")
}

// main wires the process together: it parses the command-line flags, opens
// the shared redis client, configures logging, and then hands control to
// startServer with the single "/arb" route. startServer blocks until the
// server is shut down, after which the deferred redis Close runs.
func main() {
	gracefulWait := flag.Duration("gracefulwait", 15*time.Second, "the duration to wait during the graceful shutdown")
	flag.Parse()

	redisOpts := redis.Options{
		Addr:     *redisAddress,
		Password: *redisPassword,
		DB:       int(*redisDB),
	}
	rdb = redis.NewClient(&redisOpts)
	defer rdb.Close()

	setupLogging()

	routes := []HttpHandler{
		{name: "/arb", function: arbHandler},
	}
	startServer(*gracefulWait, routes, SERVER_DEPLOYMENT_TYPE, *flagPort)
}