package main
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"flag"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"os/signal"
"sync"
"time"
"github.com/go-redis/redis/v8"
"github.com/gorilla/mux"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
)
// Command-line flags and package-level state shared across the program.
var (
// flagPort is the TCP port the server listens on.
flagPort = flag.String("port", "8009", "determines the port the server will listen on")
// flagInterval is the delay, in seconds, between price checks.
// NOTE(review): nothing in the visible part of this file reads it — confirm it is still needed.
flagInterval = flag.Float64("interval", 10, "In seconds, the delay between checking prices")
// redisDB is the Redis database number handed to the client in main.
redisDB = flag.Int64("redisdb", 1, "determines the db number")
// rdb is the shared Redis client; main assigns it after the flags are parsed.
rdb *redis.Client
// redisAddress is the address of the Redis instance.
redisAddress = flag.String("redisaddress", "redis:6379", "determines the address of the redis instance")
// redisPassword is the password used to connect to Redis; empty by default.
redisPassword = flag.String("redispassword", "", "determines the password of the redis db")
)
const (
// SERVER_DEPLOYMENT_TYPE is the NAME of the environment variable that
// startServer reads (via os.Getenv) to pick the TLS certificate pair.
SERVER_DEPLOYMENT_TYPE = "SERVER_DEPLOYMENT_TYPE"
// coingeckoAPIURLv3 is the base URL that CoinGecko request paths are appended to.
coingeckoAPIURLv3 = "https://api.coingecko.com/api/v3"
)
// HttpHandlerFunc is the function signature of an HTTP request handler:
// it receives the response writer and the incoming request.
type HttpHandlerFunc func(http.ResponseWriter, *http.Request)
// HttpHandler pairs a route with the function that serves it.
type HttpHandler struct {
// name is the path the handler is registered under; startServer passes it
// to the router's HandleFunc.
name string
// function is the handler registered for name.
function HttpHandlerFunc
}
// priceChanStruct is the message a price fetcher sends on its price channel.
type priceChanStruct struct {
// name is the coin name the price was requested for.
name string
// price is the fetched price; getPriceFromCoinGecko sends 0 when the fetch failed.
price float64
}
// errorChanStruct is the message a price fetcher sends on its error channel
// to report whether the fetch succeeded.
type errorChanStruct struct {
// hasError is true when the fetch failed.
hasError bool
// err is the failure cause; getPriceFromCoinGecko sends nil together with hasError == false.
err error
}
// addSecureHeaders sets a fixed group of security-related response headers on
// the ResponseWriter that w points to. Each header is set rather than added,
// so a value already present under the same name is replaced.
//
// Reference: OWASP REST Security Cheat Sheet,
// https://cheatsheetseries.owasp.org/cheatsheets/REST_Security_Cheat_Sheet.html
func addSecureHeaders(w *http.ResponseWriter) {
	secureHeaders := [...][2]string{
		{"Cache-Control", "no-store"},
		{"Content-Security-Policy", "default-src https;"},
		{"Strict-Transport-Security", "max-age=63072000;"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS"},
	}

	header := (*w).Header()
	for _, pair := range secureHeaders {
		header.Set(pair[0], pair[1])
	}
}
// getPriceFromBinance is an unimplemented placeholder; judging by its name and
// signature it is presumably meant to fetch the price of name, quoted in unit,
// from Binance and report it on priceChan/errChan — TODO confirm.
//
// NOTE(review): the body is empty. It never calls wg.Done and never sends on
// either channel, so a caller that did wg.Add(1) for it would block forever
// in wg.Wait.
func getPriceFromBinance(name, unit string,
wg *sync.WaitGroup,
priceChan chan<- priceChanStruct,
errChan chan<- errorChanStruct) {
}
// getPriceFromKu is an unimplemented placeholder; judging by its name and
// signature it is presumably meant to fetch the price of name from KuCoin and
// report it on priceChan/errChan — TODO confirm. The second parameter is
// spelled uni here (unit in the sibling fetchers).
//
// NOTE(review): the body is empty. It never calls wg.Done and never sends on
// either channel, so a caller that did wg.Add(1) for it would block forever
// in wg.Wait.
func getPriceFromKu(name, uni string,
wg *sync.WaitGroup,
priceChan chan<- priceChanStruct,
errChan chan<- errorChanStruct) {
}
// coinGeckoHTTPClient is shared by every CoinGecko request. Unlike
// http.DefaultClient it carries a timeout, so a stalled upstream cannot block
// the caller indefinitely.
var coinGeckoHTTPClient = &http.Client{Timeout: 10 * time.Second}

// getPriceFromCoinGecko looks up the price of the coin name, quoted in unit,
// through CoinGecko's /simple/price endpoint and reports the outcome on the
// two channels.
//
// Exactly one value is sent on priceChan and exactly one on errChan per call:
// on success the price together with errorChanStruct{hasError: false}; on any
// failure a price of 0 together with the error. The sends block until the
// channels have room, so callers should pass buffered channels or receive
// concurrently. wg.Done is called when the function returns.
func getPriceFromCoinGecko(
	name, unit string,
	wg *sync.WaitGroup,
	priceChan chan<- priceChanStruct,
	errChan chan<- errorChanStruct) {
	defer wg.Done()

	price, err := fetchCoinGeckoPrice(name, unit)
	if err != nil {
		// A zerolog event is only written once Msg (or Send) is called.
		log.Error().Err(err).Str("name", name).Str("unit", unit).
			Msg("fetching price from CoinGecko")
		priceChan <- priceChanStruct{name: name, price: 0}
		errChan <- errorChanStruct{hasError: true, err: err}
		return
	}

	priceChan <- priceChanStruct{name: name, price: price}
	errChan <- errorChanStruct{hasError: false, err: nil}
}

// fetchCoinGeckoPrice performs the HTTP round trip for getPriceFromCoinGecko
// and extracts the price stored under response[name][unit]. It returns an
// error when the request fails, the status is not 200, the body is not the
// expected JSON shape, or the requested name/unit pair is absent.
func fetchCoinGeckoPrice(name, unit string) (float64, error) {
	path := coingeckoAPIURLv3 + "/simple/price?ids=" + url.QueryEscape(name) +
		"&vs_currencies=" + url.QueryEscape(unit)
	log.Debug().Str("url", path).Msg("requesting price from CoinGecko")

	resp, err := coinGeckoHTTPClient.Get(path)
	if err != nil {
		// resp is nil here; returning avoids dereferencing it below.
		return 0, fmt.Errorf("requesting price: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return 0, fmt.Errorf("requesting price: unexpected status %s", resp.Status)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return 0, fmt.Errorf("reading price response: %w", err)
	}
	log.Info().Msg(string(body))

	var prices map[string]map[string]float64
	if err := json.Unmarshal(body, &prices); err != nil {
		return 0, fmt.Errorf("decoding price response: %w", err)
	}

	// Indexing a nil inner map is safe and simply reports ok == false.
	price, ok := prices[name][unit]
	if !ok {
		return 0, fmt.Errorf("no %q price for %q in response", unit, name)
	}
	return price, nil
}
// arbHandler answers GET requests carrying the query parameters "name" (coin)
// and "unit" (quote currency) with a JSON object holding the keys "name",
// "price", "unit", "err" and "isSuccessful".
//
// Responses:
//   - 200 with isSuccessful == true when the price was fetched;
//   - 400 when name or unit is missing;
//   - 405 for any method other than GET;
//   - 502 when the price lookup reported an error;
//   - 500 when the lookup produced no result at all.
//
// Unknown query parameters are logged and otherwise ignored.
func arbHandler(w http.ResponseWriter, r *http.Request) {
	// Headers have to be in place before anything is written to w.
	addSecureHeaders(&w)

	if r.Method != http.MethodGet {
		w.Header().Set("Allow", http.MethodGet)
		http.Error(w, "Method is not supported.", http.StatusMethodNotAllowed)
		return
	}
	w.Header().Set("Content-Type", "application/json")

	query := r.URL.Query()
	for key := range query {
		if key != "name" && key != "unit" {
			log.Error().Err(errors.New("Got unexpected parameter.")).
				Str("parameter", key).Msg("ignoring query parameter")
		}
	}
	name, unit := query.Get("name"), query.Get("unit")
	if name == "" || unit == "" {
		writeArbResponse(w, http.StatusBadRequest, name, unit, 0,
			errors.New("query parameters name and unit are required"))
		return
	}

	// The fetcher is called synchronously; capacity-1 buffers let it complete
	// its single send on each channel without a concurrent receiver.
	priceChan := make(chan priceChanStruct, 1)
	errChan := make(chan errorChanStruct, 1)
	var wg sync.WaitGroup
	wg.Add(1)
	getPriceFromCoinGecko(name, unit, &wg, priceChan, errChan)
	wg.Wait()

	// Non-blocking receives: a fetcher that failed to report must not hang
	// the request.
	select {
	case status := <-errChan:
		if status.hasError {
			fetchErr := status.err
			if fetchErr == nil {
				fetchErr = errors.New("price lookup failed")
			}
			log.Error().Err(fetchErr).Str("name", name).Str("unit", unit).
				Msg("price lookup failed")
			writeArbResponse(w, http.StatusBadGateway, name, unit, 0, fetchErr)
			return
		}
	default:
		missing := errors.New("price lookup reported no status")
		log.Error().Err(missing).Msg("price lookup produced no result")
		writeArbResponse(w, http.StatusInternalServerError, name, unit, 0, missing)
		return
	}

	select {
	case price := <-priceChan:
		writeArbResponse(w, http.StatusOK, price.name, unit, price.price, nil)
	default:
		missing := errors.New("price lookup returned no price")
		log.Error().Err(missing).Msg("price lookup produced no result")
		writeArbResponse(w, http.StatusInternalServerError, name, unit, 0, missing)
	}
}

// writeArbResponse writes status followed by the JSON payload of arbHandler.
// "err" carries err's message (empty when err is nil) and "isSuccessful" is
// true exactly when err is nil.
func writeArbResponse(w http.ResponseWriter, status int, name, unit string, price float64, err error) {
	errText := ""
	if err != nil {
		errText = err.Error()
	}
	w.WriteHeader(status)
	payload := map[string]interface{}{
		"name":         name,
		"price":        price,
		"unit":         unit,
		"err":          errText,
		"isSuccessful": err == nil,
	}
	if encErr := json.NewEncoder(w).Encode(payload); encErr != nil {
		log.Error().Err(encErr).Msg("encoding response")
	}
}
// setupLogging configures the zerolog package-level settings for the program.
func setupLogging() {
// Use zerolog.TimeFormatUnix as the format for time fields in log events.
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
}
// startServer registers handlers on a new router, serves them over TLS 1.3 on
// 0.0.0.0:port, and blocks until the process receives os.Interrupt. It then
// shuts the server down, waiting at most gracefulWait for in-flight requests.
//
// serverDeploymentType is the NAME of an environment variable. Its value
// selects the certificate pair: "deployment" uses /certs/fullchain1.pem and
// /certs/privkey1.pem, "test" uses /certs/server.cert and /certs/server.key.
// Any other value, or a server that fails for a reason other than a requested
// shutdown, terminates the process through log.Fatal.
func startServer(gracefulWait time.Duration,
	handlers []HttpHandler,
	serverDeploymentType string, port string) {
	var certPath, keyPath string
	switch kind := os.Getenv(serverDeploymentType); kind {
	case "deployment":
		certPath = "/certs/fullchain1.pem"
		keyPath = "/certs/privkey1.pem"
	case "test":
		certPath = "/certs/server.cert"
		keyPath = "/certs/server.key"
	default:
		// Report the variable's value, not just its name. Msg is required:
		// without it zerolog neither writes the event nor exits.
		log.Fatal().
			Err(fmt.Errorf("unknown deployment kind: %q (read from $%s)", kind, serverDeploymentType)).
			Msg("cannot select TLS certificates")
	}

	r := mux.NewRouter()
	for _, h := range handlers {
		r.HandleFunc(h.name, h.function)
	}

	srv := &http.Server{
		Addr:         "0.0.0.0:" + port,
		WriteTimeout: time.Second * 15,
		ReadTimeout:  time.Second * 15,
		Handler:      r,
		TLSConfig: &tls.Config{
			MinVersion: tls.VersionTLS13,
		},
	}

	go func() {
		// ListenAndServeTLS always returns a non-nil error; ErrServerClosed
		// only means Shutdown was called and is not a failure.
		err := srv.ListenAndServeTLS(certPath, keyPath)
		if err != nil && !errors.Is(err, http.ErrServerClosed) {
			log.Fatal().Err(err).Msg("HTTPS server stopped unexpectedly")
		}
	}()

	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt)
	defer signal.Stop(c)
	<-c

	ctx, cancel := context.WithTimeout(context.Background(), gracefulWait)
	defer cancel()
	if err := srv.Shutdown(ctx); err != nil {
		log.Error().Err(err).Msg("graceful shutdown did not complete")
		return
	}
	log.Info().Msg("gracefully shut down the server")
}
// main is the program entry point. In order, it parses the command-line
// flags, creates the shared Redis client from them (closed when main
// returns), configures logging, and runs the server with arbHandler
// registered under /arb. startServer blocks until the server is shut down.
func main() {
	gracefulWait := flag.Duration("gracefulwait", 15*time.Second, "the duration to wait during the graceful shutdown")
	flag.Parse()

	redisOptions := redis.Options{
		Addr:     *redisAddress,
		Password: *redisPassword,
		DB:       int(*redisDB),
	}
	rdb = redis.NewClient(&redisOptions)
	defer rdb.Close()

	setupLogging()

	routes := []HttpHandler{
		{name: "/arb", function: arbHandler},
	}
	startServer(*gracefulWait, routes, SERVER_DEPLOYMENT_TYPE, *flagPort)
}