diff options
author | terminaldweller <thabogre@gmail.com> | 2023-02-28 11:39:13 +0000 |
---|---|---|
committer | terminaldweller <thabogre@gmail.com> | 2023-02-28 11:39:13 +0000 |
commit | 60194abb983f4d011fd5ddbc275ead49afce8bf6 (patch) | |
tree | 2b94862987568ecaff6c260c33c0219b6c0581b6 | |
parent | added socks5 support for arbiter (diff) | |
download | hived-60194abb983f4d011fd5ddbc275ead49afce8bf6.tar.gz hived-60194abb983f4d011fd5ddbc275ead49afce8bf6.zip |
golang lint fixes
-rw-r--r-- | arbiter/arbiter.go | 377 | ||||
-rw-r--r-- | hived/go.mod | 4 | ||||
-rw-r--r-- | hived/hived.go | 825 | ||||
-rw-r--r-- | hived/hived_test.go | 23 |
4 files changed, 772 insertions, 457 deletions
diff --git a/arbiter/arbiter.go b/arbiter/arbiter.go index 5168b3a..9edf68a 100644 --- a/arbiter/arbiter.go +++ b/arbiter/arbiter.go @@ -7,7 +7,7 @@ import ( "errors" "flag" "fmt" - "io/ioutil" + "io" "net" "net/http" "net/url" @@ -25,23 +25,25 @@ import ( ) var ( - flagPort = flag.String("port", "8009", "determines the port the server will listen on") - flagInterval = flag.Float64("interval", 10, "In seconds, the delay between checking prices") - redisDB = flag.Int64("redisdb", 1, "determines the db number") - rdb *redis.Client - redisAddress = flag.String("redisaddress", "redis:6379", "determines the address of the redis instance") - redisPassword = flag.String("redispassword", "", "determines the password of the redis db") + errBadLogic = errors.New("we should not be here") + errUnexpectedParam = errors.New("got unexpected parameter") + errUnknownDeployment = errors.New("unknown deployment kind") ) const ( - SERVER_DEPLOYMENT_TYPE = "SERVER_DEPLOYMENT_TYPE" - coingeckoAPIURLv3 = "https://api.coingecko.com/api/v3" - coincapAPIURLv2 = "https://api.coincap.io/v2" + serverDeploymentType = "SERVER_DEPLOYMENT_TYPE" + coingeckoAPIURLv3 = "https://api.coingecko.com/api/v3" + coincapAPIURLv2 = "https://api.coincap.io/v2" + getTimeout = 5 + httpClientTimeout = 5 + serverTLSReadTimeout = 15 + serverTLSWriteTimeout = 15 + defaultGracefulShutdown = 15 ) // https://docs.coincap.io/ type CoinCapAssetGetResponseData struct { - Id string `json:"id"` + ID string `json:"id"` Rank string `json:"rank"` Symbol string `json:"symbol"` Name string `json:"name"` @@ -54,16 +56,24 @@ type CoinCapAssetGetResponseData struct { Vwap24Hr string `json:"vwap24Hr"` } +type priceResponseData struct { + Name string `json:"name"` + Price float64 `json:"price"` + Unit string `json:"unit"` + Err string `json:"err"` + IsSuccessful bool `json:"isSuccessful"` +} + type CoinCapAssetGetResponse struct { Data CoinCapAssetGetResponseData `json:"data"` TimeStamp int64 `json:"timestamp"` } -type HttpHandlerFunc func(http.ResponseWriter, *http.Request) +type HTTPHandlerFunc func(http.ResponseWriter, *http.Request) -type HttpHandler struct { +type HTTPHandler struct { name string - function HttpHandlerFunc + function HTTPHandlerFunc } type priceChanStruct struct { @@ -81,171 +91,241 @@ func GetProxiedClient() (*http.Client, error) { if proxyURL == "" { proxyURL = os.Getenv("HTTPS_PROXY") } + dialer, err := proxy.SOCKS5("tcp", proxyURL, nil, proxy.Direct) if err != nil { - return nil, err + return nil, fmt.Errorf("[GetProxiedClient] : %w", err) } + dialContext := func(ctx context.Context, network, address string) (net.Conn, error) { - return dialer.Dial(network, address) + netConn, err := dialer.Dial(network, address) + if err == nil { + return netConn, nil + } + + return netConn, fmt.Errorf("[dialContext] : %w", err) } transport := &http.Transport{ DialContext: dialContext, DisableKeepAlives: true, } - client := &http.Client{Transport: transport} + client := &http.Client{ + Transport: transport, + Timeout: httpClientTimeout * time.Second, + CheckRedirect: nil, + Jar: nil, + } return client, nil } // OWASP: https://cheatsheetseries.owasp.org/cheatsheets/REST_Security_Cheat_Sheet.html -func addSecureHeaders(w *http.ResponseWriter) { - (*w).Header().Set("Cache-Control", "no-store") - (*w).Header().Set("Content-Security-Policy", "default-src https;") - (*w).Header().Set("Strict-Transport-Security", "max-age=63072000;") - (*w).Header().Set("X-Content-Type-Options", "nosniff") - (*w).Header().Set("X-Frame-Options", "DENY") - (*w).Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS") +func addSecureHeaders(writer *http.ResponseWriter) { + (*writer).Header().Set("Cache-Control", "no-store") + (*writer).Header().Set("Content-Security-Policy", "default-src https;") + (*writer).Header().Set("Strict-Transport-Security", "max-age=63072000;") + (*writer).Header().Set("X-Content-Type-Options", "nosniff") + (*writer).Header().Set("X-Frame-Options", "DENY") + (*writer).Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS") } -// binance -func getPriceFromBinance(name, unit string, - wg *sync.WaitGroup, - priceChan chan<- priceChanStruct, - errChan chan<- errorChanStruct) { +// get price from binance. +// func getPriceFromBinance(name, unit string, +// wg *sync.WaitGroup, +// priceChan chan<- priceChanStruct, +// errChan chan<- errorChanStruct) { -} +// } -// kucoin -func getPriceFromKu(name, uni string, - wg *sync.WaitGroup, - priceChan chan<- priceChanStruct, - errChan chan<- errorChanStruct) { +// get price from kucoin. +// func getPriceFromKu(name, uni string, +// wg *sync.WaitGroup, +// priceChan chan<- priceChanStruct, +// errChan chan<- errorChanStruct) { -} +// } func getPriceFromCoinGecko( + ctx context.Context, name, unit string, wg *sync.WaitGroup, priceChan chan<- priceChanStruct, - errChan chan<- errorChanStruct) { + errChan chan<- errorChanStruct, +) { defer wg.Done() + priceFloat := 0. + params := "/simple/price?ids=" + url.QueryEscape(name) + "&" + "vs_currencies=" + url.QueryEscape(unit) path := coingeckoAPIURLv3 + params - fmt.Println(path) - // resp, err := http.Get(path) + client, err := GetProxiedClient() if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} + priceChan <- priceChanStruct{name: name, price: priceFloat} + errChan <- errorChanStruct{hasError: true, err: err} + + log.Error().Err(err) + + return + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil) + if err != nil { + priceChan <- priceChanStruct{name: name, price: priceFloat} errChan <- errorChanStruct{hasError: true, err: err} + log.Error().Err(err) + return } - resp, err := client.Get(path) + resp, err := client.Do(req) if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} + priceChan <- priceChanStruct{name: name, price: priceFloat} errChan <- errorChanStruct{hasError: true, err: err} + log.Error().Err(err) - fmt.Println(err) + return } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} + priceChan <- priceChanStruct{name: name, price: priceFloat} errChan <- errorChanStruct{hasError: true, err: err} + log.Error().Err(err) } jsonBody := make(map[string]interface{}) + err = json.Unmarshal(body, &jsonBody) if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} + priceChan <- priceChanStruct{name: name, price: priceFloat} errChan <- errorChanStruct{hasError: true, err: err} + log.Error().Err(err) } - price := jsonBody[name].(map[string]interface{})[unit].(float64) + price, isOk := jsonBody[name].(map[string]interface{}) + if !isOk { + priceChan <- priceChanStruct{name: name, price: priceFloat} + errChan <- errorChanStruct{hasError: true, err: err} + + log.Error().Err(err) + + return + } log.Info().Msg(string(body)) - priceChan <- priceChanStruct{name: name, price: price} + priceFloat, isOk = price[unit].(float64) + if !isOk { + priceChan <- priceChanStruct{name: name, price: priceFloat} + errChan <- errorChanStruct{hasError: true, err: err} + + log.Error().Err(err) + + return + } + + priceChan <- priceChanStruct{name: name, price: priceFloat} errChan <- errorChanStruct{hasError: false, err: nil} } func getPriceFromCoinCap( + ctx context.Context, name, unit string, wg *sync.WaitGroup, priceChan chan<- priceChanStruct, - errChan chan<- errorChanStruct) { + errChan chan<- errorChanStruct, +) { defer wg.Done() + priceFloat := 0. + params := "/assets/" + url.QueryEscape(name) path := coincapAPIURLv2 + params - fmt.Println(path) + client, err := GetProxiedClient() if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} + priceChan <- priceChanStruct{name: name, price: priceFloat} + errChan <- errorChanStruct{hasError: true, err: err} + + log.Error().Err(err) + + return + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil) + if err != nil { + priceChan <- priceChanStruct{name: name, price: priceFloat} errChan <- errorChanStruct{hasError: true, err: err} + log.Error().Err(err) + return } - // resp, err := http.Get(path) - resp, err := client.Get(path) + + resp, err := client.Do(req) if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} + priceChan <- priceChanStruct{name: name, price: priceFloat} errChan <- errorChanStruct{hasError: true, err: err} + log.Error().Err(err) + return } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} + priceChan <- priceChanStruct{name: name, price: priceFloat} errChan <- errorChanStruct{hasError: true, err: err} + log.Error().Err(err) } - fmt.Println(string(body)) var coinCapAssetGetResponse CoinCapAssetGetResponse - // jsonBody := make(map[string]interface{}) - // err = json.Unmarshal(body, &jsonBody) + err = json.Unmarshal(body, &coinCapAssetGetResponse) if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} + priceChan <- priceChanStruct{name: name, price: priceFloat} errChan <- errorChanStruct{hasError: true, err: err} + log.Error().Err(err) } - // price := jsonBody[name].(map[string]interface{})[unit].(float64) - price, err := strconv.ParseFloat(coinCapAssetGetResponse.Data.PriceUsd, 64) + priceFloat, err = strconv.ParseFloat(coinCapAssetGetResponse.Data.PriceUsd, 64) if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} + priceChan <- priceChanStruct{name: name, price: priceFloat} errChan <- errorChanStruct{hasError: true, err: err} + log.Error().Err(err) } - fmt.Println(price) log.Info().Msg(string(body)) - priceChan <- priceChanStruct{name: name, price: price} + priceChan <- priceChanStruct{name: name, price: priceFloat} errChan <- errorChanStruct{hasError: false, err: nil} } func arbHandler(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "application/json") - if r.Method != "GET" { + + if r.Method != http.MethodGet { http.Error(w, "Method is not supported.", http.StatusNotFound) } + addSecureHeaders(&w) var name string + var unit string + params := r.URL.Query() for key, value := range params { switch key { @@ -254,24 +334,31 @@ func arbHandler(w http.ResponseWriter, r *http.Request) { case "unit": unit = value[0] default: - log.Error().Err(errors.New("Got unexpected parameter.")) + log.Error().Err(errUnexpectedParam) } } priceChan := make(chan priceChanStruct, 1) errChan := make(chan errorChanStruct, 1) - var wg sync.WaitGroup - wg.Add(1) - getPriceFromCoinGecko(name, unit, &wg, priceChan, errChan) - wg.Wait() + + var waitGroup sync.WaitGroup + + ctx, cancel := context.WithTimeout(context.Background(), getTimeout*time.Second) + defer cancel() + + waitGroup.Add(1) + + //nolint:contextcheck + getPriceFromCoinGecko(ctx, name, unit, &waitGroup, priceChan, errChan) + waitGroup.Wait() select { case err := <-errChan: - if err.hasError != false { + if err.hasError { log.Error().Err(err.err) } default: - log.Error().Err(errors.New("We shouldnt be here")) + log.Error().Err(errBadLogic) } var price priceChanStruct @@ -279,27 +366,44 @@ func arbHandler(w http.ResponseWriter, r *http.Request) { case priceCh := <-priceChan: price = priceCh default: - log.Fatal().Err(errors.New("We shouldnt be here")) + log.Error().Err(errBadLogic) } - json.NewEncoder(w).Encode(map[string]interface{}{ - "name": price.name, - "price": price.price, - "unit": unit, - "err": "", - "isSuccessful": true, - }) + responseData := priceResponseData{ + Name: price.name, + Price: price.price, + Unit: "USD", + Err: "", + IsSuccessful: true, + } + + jsonResp, err := json.Marshal(responseData) + if err != nil { + cancel() + //nolint:gocritic + log.Fatal().Err(err) + } + + _, err = w.Write(jsonResp) + if err != nil { + cancel() + log.Fatal().Err(err) + } } func coincapHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") - if r.Method != "GET" { + if r.Method != http.MethodGet { http.Error(w, "Method is not supported.", http.StatusNotFound) } + + w.Header().Add("Content-Type", "application/json") + addSecureHeaders(&w) var name string + var unit string + params := r.URL.Query() for key, value := range params { switch key { @@ -308,24 +412,31 @@ func coincapHandler(w http.ResponseWriter, r *http.Request) { case "unit": unit = value[0] default: - log.Error().Err(errors.New("Got unexpected parameter.")) + log.Error().Err(errUnexpectedParam) } } priceChan := make(chan priceChanStruct, 1) errChan := make(chan errorChanStruct, 1) - var wg sync.WaitGroup - wg.Add(1) - getPriceFromCoinCap(name, unit, &wg, priceChan, errChan) - wg.Wait() + + var waitGroup sync.WaitGroup + + waitGroup.Add(1) + + ctx, cancel := context.WithTimeout(context.Background(), getTimeout*time.Second) + defer cancel() + + //nolint:contextcheck + getPriceFromCoinCap(ctx, name, unit, &waitGroup, priceChan, errChan) + waitGroup.Wait() select { case err := <-errChan: - if err.hasError != false { + if err.hasError { log.Error().Err(err.err) } default: - log.Error().Err(errors.New("We shouldnt be here")) + log.Error().Err(errBadLogic) } var price priceChanStruct @@ -333,16 +444,29 @@ func coincapHandler(w http.ResponseWriter, r *http.Request) { case priceCh := <-priceChan: price = priceCh default: - log.Fatal().Err(errors.New("We shouldnt be here")) + log.Error().Err(errBadLogic) } - json.NewEncoder(w).Encode(map[string]interface{}{ - "name": price.name, - "price": price.price, - "unit": "USD", - "err": "", - "isSuccessful": true, - }) + responseData := priceResponseData{ + Name: price.name, + Price: price.price, + Unit: "USD", + Err: "", + IsSuccessful: true, + } + + jsonResp, err := json.Marshal(responseData) + if err != nil { + cancel() + //nolint:gocritic + log.Fatal().Err(err) + } + + _, err = w.Write(jsonResp) + if err != nil { + cancel() + log.Fatal().Err(err) + } } func setupLogging() { @@ -350,37 +474,42 @@ func setupLogging() { } func startServer(gracefulWait time.Duration, - handlers []HttpHandler, - serverDeploymentType string, port string) { - r := mux.NewRouter() + handlers []HTTPHandler, + serverDeploymentType string, port string, +) { + route := mux.NewRouter() cfg := &tls.Config{ MinVersion: tls.VersionTLS13, } + srv := &http.Server{ Addr: "0.0.0.0:" + port, - WriteTimeout: time.Second * 15, - ReadTimeout: time.Second * 15, - Handler: r, + WriteTimeout: time.Second * serverTLSWriteTimeout, + ReadTimeout: time.Second * serverTLSReadTimeout, + Handler: route, TLSConfig: cfg, } for i := 0; i < len(handlers); i++ { - r.HandleFunc(handlers[i].name, handlers[i].function) + route.HandleFunc(handlers[i].name, handlers[i].function) } go func() { var certPath, keyPath string - if os.Getenv(serverDeploymentType) == "deployment" { + + switch os.Getenv(serverDeploymentType) { + case "deployment": certPath = "/certs/fullchain1.pem" keyPath = "/certs/privkey1.pem" - } else if os.Getenv(serverDeploymentType) == "test" { + case "test": certPath = "/certs/server.cert" keyPath = "/certs/server.key" - } else { - log.Fatal().Err(errors.New(fmt.Sprintf("unknown deployment kind: %s", serverDeploymentType))) + default: + log.Error().Err(errUnknownDeployment) } + if err := srv.ListenAndServeTLS(certPath, keyPath); err != nil { - log.Fatal().Err(err) + log.Error().Err(err) } }() @@ -388,16 +517,35 @@ func startServer(gracefulWait time.Duration, signal.Notify(c, os.Interrupt) <-c + ctx, cancel := context.WithTimeout(context.Background(), gracefulWait) defer cancel() - srv.Shutdown(ctx) + + if err := srv.Shutdown(ctx); err != nil { + log.Error().Err(err) + } + log.Info().Msg("gracefully shut down the server") } func main() { var gracefulWait time.Duration - flag.DurationVar(&gracefulWait, "gracefulwait", time.Second*15, "the duration to wait during the graceful shutdown") + + var rdb *redis.Client + + flag.DurationVar( + &gracefulWait, + "gracefulwait", + time.Second*defaultGracefulShutdown, + "the duration to wait during the graceful shutdown", + ) + + flagPort := flag.String("port", "8009", "determines the port the server will listen on") + redisDB := flag.Int64("redisdb", 1, "determines the db number") + redisAddress := flag.String("redisaddress", "redis:6379", "determines the address of the redis instance") + redisPassword := flag.String("redispassword", "", "determines the password of the redis db") flag.Parse() + rdb = redis.NewClient(&redis.Options{ Addr: *redisAddress, Password: *redisPassword, @@ -406,10 +554,11 @@ func main() { defer rdb.Close() setupLogging() - var handlerFuncs = []HttpHandler{ + + handlerFuncs := []HTTPHandler{ {name: "/crypto/v1/arb/gecko", function: arbHandler}, {name: "/crypto/v1/arb/coincap", function: coincapHandler}, } - startServer(gracefulWait, handlerFuncs, SERVER_DEPLOYMENT_TYPE, *flagPort) + startServer(gracefulWait, handlerFuncs, serverDeploymentType, *flagPort) } diff --git a/hived/go.mod b/hived/go.mod index 126eedc..bbd4fa3 100644 --- a/hived/go.mod +++ b/hived/go.mod @@ -8,7 +8,9 @@ require ( github.com/gorilla/mux v1.8.0 github.com/rs/zerolog v1.20.0 github.com/terminaldweller/grpc v1.0.3 + golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb google.golang.org/grpc v1.42.0 + google.golang.org/protobuf v1.27.1 ) require ( @@ -18,9 +20,7 @@ require ( go.opentelemetry.io/otel v0.17.0 // indirect go.opentelemetry.io/otel/metric v0.17.0 // indirect go.opentelemetry.io/otel/trace v0.17.0 // indirect - golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb // indirect golang.org/x/sys v0.0.0-20210112080510-489259a85091 // indirect golang.org/x/text v0.3.3 // indirect google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 // indirect - google.golang.org/protobuf v1.27.1 // indirect ) diff --git a/hived/hived.go b/hived/hived.go index 927e633..42fb3b0 100644 --- a/hived/hived.go +++ b/hived/hived.go @@ -1,17 +1,13 @@ package main import ( - "bytes" "context" - "crypto/hmac" - "crypto/sha512" "crypto/tls" - "encoding/hex" "encoding/json" "errors" "flag" "fmt" - "io/ioutil" + "net" "net/http" "net/url" "os" @@ -26,41 +22,89 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" pb "github.com/terminaldweller/grpc/telebot/v1" + "golang.org/x/net/proxy" "google.golang.org/grpc" -) - -var ( - flagPort = flag.String("port", "8008", "determined the port the sercice runs on") - alertsCheckInterval = flag.Int64("alertinterval", 600., "in seconds, the amount of time between alert checks") - redisAddress = flag.String("redisaddress", "redis:6379", "determines the address of the redis instance") - redisPassword = flag.String("redispassword", "", "determines the password of the redis db") - redisDB = flag.Int64("redisdb", 0, "determines the db number") - botChannelID = flag.Int64("botchannelid", 146328407, "determines the channel id the telgram bot should send messages to") - cacheDuration = flag.Float64("cacheDuration", 300_000, "determines the price cache validity duration in miliseconds") - rdb *redis.Client + "google.golang.org/protobuf/types/known/timestamppb" ) const ( cryptocomparePriceURL = "https://min-api.cryptocompare.com/data/price?" - coingeckoAPIURLv3 = "https://api.coingecko.com/api/v3" - changellyURL = "https://api.changelly.com" - TELEGRAM_BOT_TOKEN_ENV_VAR = "TELEGRAM_BOT_TOKEN" - CHANGELLY_API_KEY_ENV_VAR = "CHANGELLY_API_KEY" - CHANGELLY_API_SECRET_ENV_VAR = "CHANGELLY_API_SECRET" - SERVER_DEPLOYMENT_TYPE = "SERVER_DEPLOYMENT_TYPE" + telegramBotTokenEnvVar = "TELEGRAM_BOT_TOKEN" //nolint: gosec + serverDeploymentType = "SERVER_DEPLOYMENT_TYPE" + httpClientTimeout = 5 + getTimeout = 5 + serverTLSReadTimeout = 15 + serverTLSWriteTimeout = 15 + defaultGracefulShutdown = 15 + redisContextTimeout = 2 + pingTimeout = 5 + alertCheckIntervalDefault = 600 + redisCacheDurationMultiplier = 1_000_000 + cacheDurationdefault = 300_000 + telegramTimeout = 10 + // coingeckoAPIURLv3 = "https://api.coingecko.com/api/v3" ) +var ( + cacheDuration = flag.Float64( + "cacheDuration", + cacheDurationdefault, + "determines the price cache validity duration in miliseconds", + ) + rdb *redis.Client + errUnknownParam = errors.New("unknown parameters for endpoint") + errIncompParams = errors.New("incomplete set of parameters") + errBadLogic = errors.New("bad logic") + errFailedTypeAssertion = errors.New("type assertion failed") + errFailedUnmarshall = errors.New("failed to unmarshall JSON") + errUnknownDeploymentKind = errors.New("unknown deployment kind") +) + +func GetProxiedClient() (*http.Client, error) { + proxyURL := os.Getenv("ALL_PROXY") + if proxyURL == "" { + proxyURL = os.Getenv("HTTPS_PROXY") + } + + dialer, err := proxy.SOCKS5("tcp", proxyURL, nil, proxy.Direct) + if err != nil { + return nil, fmt.Errorf("[GetProxiedClient] : %w", err) + } + + dialContext := func(ctx context.Context, network, address string) (net.Conn, error) { + netConn, err := dialer.Dial(network, address) + if err == nil { + return netConn, nil + } + + return netConn, fmt.Errorf("[dialContext] : %w", err) + } + + transport := &http.Transport{ + DialContext: dialContext, + DisableKeepAlives: true, + } + client := &http.Client{ + Transport: transport, + Timeout: httpClientTimeout * time.Second, + CheckRedirect: nil, + Jar: nil, + } + + return client, nil +} + // OWASP: https://cheatsheetseries.owasp.org/cheatsheets/REST_Security_Cheat_Sheet.html -func addSecureHeaders(w *http.ResponseWriter) { - (*w).Header().Set("Cache-Control", "no-store") - (*w).Header().Set("Content-Security-Policy", "default-src https;") - (*w).Header().Set("Strict-Transport-Security", "max-age=63072000;") - (*w).Header().Set("X-Content-Type-Options", "nosniff") - (*w).Header().Set("X-Frame-Options", "DENY") - (*w).Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS") +func addSecureHeaders(writer *http.ResponseWriter) { + (*writer).Header().Set("Cache-Control", "no-store") + (*writer).Header().Set("Content-Security-Policy", "default-src https;") + (*writer).Header().Set("Strict-Transport-Security", "max-age=63072000;") + (*writer).Header().Set("X-Content-Type-Options", "nosniff") + (*writer).Header().Set("X-Frame-Options", "DENY") + (*writer).Header().Set("Access-Control-Allow-Methods", "GET,POST,PUT,DELETE,OPTIONS") } -func sendToTg(address, msg string, channelId int64) { +func sendToTg(address, msg string, channelID int64) { conn, err := grpc.Dial(address, grpc.WithInsecure()) if err != nil { log.Fatal().Err(err) @@ -69,15 +113,21 @@ func sendToTg(address, msg string, channelId int64) { c := pb.NewNotificationServiceClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), telegramTimeout*time.Second) defer cancel() - r, err := c.Notify(ctx, &pb.NotificationRequest{NotificationText: msg, ChannelId: channelId}) + response, err := c.Notify( + ctx, + &pb.NotificationRequest{ + NotificationText: msg, + ChannelId: channelID, + RequestTime: timestamppb.Now(), + }) if err != nil { - log.Fatal().Err(err) + log.Error().Err(err) } - log.Info().Msg(fmt.Sprintf("%v", r)) + log.Info().Msg(fmt.Sprintf("%v", response)) } type priceChanStruct struct { @@ -90,82 +140,99 @@ type errorChanStruct struct { err error } -type APISource int +// type APISource int const ( CryptoCompareSource = iota - CoinGeckoSource - CoinCapSource + // CoinGeckoSource + // CoinCapSource ) -// TODO-add more sources -// TODO-do a round robin +// TODO-add more sources. +// TODO-do a round robin. func chooseGetPriceSource() int { return CryptoCompareSource } -func getPrice(name, unit string, - wg *sync.WaitGroup, +func getPrice(ctx context.Context, + name, unit string, + waitGroup *sync.WaitGroup, priceChan chan<- priceChanStruct, - errChan chan<- errorChanStruct) { - - // check price cache - ctx := context.Background() + errChan chan<- errorChanStruct, +) { val, err := rdb.Get(ctx, name+"_price").Float64() + if err != nil { - fmt.Println("price cache miss") source := chooseGetPriceSource() if source == CryptoCompareSource { - getPriceFromCryptoCompare(name, unit, wg, priceChan, errChan) + getPriceFromCryptoCompare(ctx, name, unit, waitGroup, priceChan, errChan) } } else { - fmt.Println("price cache hit ", val) priceChan <- priceChanStruct{name: name, price: val} errChan <- errorChanStruct{hasError: false, err: nil} - wg.Done() + waitGroup.Done() } } +func getPriceFromCryptoCompareErrorHandler( + err error, + name string, + priceChan chan<- priceChanStruct, + errChan chan<- errorChanStruct, +) { + defaultPrice := 0. + priceChan <- priceChanStruct{name: name, price: defaultPrice} + errChan <- errorChanStruct{hasError: true, err: err} + + log.Error().Err(err) +} + func getPriceFromCryptoCompare( + ctx context.Context, name, unit string, wg *sync.WaitGroup, priceChan chan<- priceChanStruct, - errChan chan<- errorChanStruct) { + errChan chan<- errorChanStruct, +) { defer wg.Done() params := "fsym=" + url.QueryEscape(name) + "&" + "tsyms=" + url.QueryEscape(unit) path := cryptocomparePriceURL + params - resp, err := http.Get(path) + + client, err := GetProxiedClient() if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} - errChan <- errorChanStruct{hasError: true, err: err} - log.Error().Err(err) + getPriceFromCryptoCompareErrorHandler(err, name, priceChan, errChan) + return } - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, path, nil) if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} - errChan <- errorChanStruct{hasError: true, err: err} - log.Error().Err(err) + getPriceFromCryptoCompareErrorHandler(err, name, priceChan, errChan) + + return } - jsonBody := make(map[string]float64) - err = json.Unmarshal(body, &jsonBody) + resp, err := client.Do(req) if err != nil { - priceChan <- priceChanStruct{name: name, price: 0.} - errChan <- errorChanStruct{hasError: true, err: err} - log.Error().Err(err) + getPriceFromCryptoCompareErrorHandler(err, name, priceChan, errChan) + + return } + defer resp.Body.Close() - log.Info().Msg(string(body)) + jsonBody := make(map[string]float64) + + err = json.NewDecoder(resp.Body).Decode(&jsonBody) + if err != nil { + getPriceFromCryptoCompareErrorHandler(err, name, priceChan, errChan) + } // add a price cache - ctx := context.Background() - err = rdb.Set(ctx, name+"_price", jsonBody[unit], time.Duration(*cacheDuration*1000000)).Err() + err = rdb.Set(ctx, name+"_price", jsonBody[unit], time.Duration(*cacheDuration*redisCacheDurationMultiplier)).Err() + if err != nil { log.Error().Err(err) } @@ -174,16 +241,20 @@ func getPriceFromCryptoCompare( errChan <- errorChanStruct{hasError: false, err: nil} } -func PriceHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") - if r.Method != "GET" { - http.Error(w, "Method is not supported.", http.StatusNotFound) +func PriceHandler(writer http.ResponseWriter, request *http.Request) { + writer.Header().Add("Content-Type", "application/json") + + if request.Method != http.MethodGet { + http.Error(writer, "Method is not supported.", http.StatusNotFound) } - addSecureHeaders(&w) + + addSecureHeaders(&writer) var name string + var unit string - params := r.URL.Query() + + params := request.URL.Query() for key, value := range params { switch key { case "name": @@ -191,35 +262,46 @@ func PriceHandler(w http.ResponseWriter, r *http.Request) { case "unit": unit = value[0] default: - log.Error().Err(errors.New("bad parameters for the crypto endpoint.")) + log.Error().Err(errUnknownParam) } } if name == "" || unit == "" { - json.NewEncoder(w).Encode(map[string]interface{}{ + err := json.NewEncoder(writer).Encode(map[string]interface{}{ "err": "query parameters must include name and unit", - "isSuccessful": false}) - log.Error().Err(errors.New("query parameters must include name and unit.")) + "isSuccessful": false, + }) + if err != nil { + log.Error().Err(errIncompParams) + http.Error(writer, "internal server error", http.StatusInternalServerError) + } + return } - var wg sync.WaitGroup + var waitGroup sync.WaitGroup + priceChan := make(chan priceChanStruct, 1) errChan := make(chan errorChanStruct, 1) + defer close(errChan) defer close(priceChan) - wg.Add(1) - // TODO- check cache - go getPrice(name, unit, &wg, priceChan, errChan) - wg.Wait() + waitGroup.Add(1) + + ctx, cancel := context.WithTimeout(request.Context(), getTimeout*time.Second) + defer cancel() + + go getPrice(ctx, name, unit, &waitGroup, priceChan, errChan) + + waitGroup.Wait() select { case err := <-errChan: - if err.hasError != false { + if err.hasError { log.Error().Err(err.err) } default: - log.Error().Err(errors.New("this shouldn't have happened'")) + log.Error().Err(errBadLogic) } var price priceChanStruct @@ -227,28 +309,39 @@ func PriceHandler(w http.ResponseWriter, r *http.Request) { case priceCh := <-priceChan: price = priceCh default: - log.Fatal().Err(errors.New("this shouldnt have happened")) + log.Error().Err(errBadLogic) } - json.NewEncoder(w).Encode(map[string]interface{}{ + err := json.NewEncoder(writer).Encode(map[string]interface{}{ "name": price.name, "price": price.price, "unit": unit, "err": "", - "isSuccessful": true}) + "isSuccessful": true, + }) + if err != nil { + log.Error().Err(err) + http.Error(writer, "internal server error", http.StatusInternalServerError) + } } func PairHandler(w http.ResponseWriter, r *http.Request) { var err error + w.Header().Add("Content-Type", "application/json") - if r.Method != "GET" { + + if r.Method != http.MethodGet { http.Error(w, "Method is not supported.", http.StatusNotFound) } + addSecureHeaders(&w) var one string + var two string + var multiplier float64 + params := r.URL.Query() for key, value := range params { switch key { @@ -262,55 +355,71 @@ func PairHandler(w http.ResponseWriter, r *http.Request) { log.Fatal().Err(err) } default: - log.Fatal().Err(errors.New("unknown parameters for the pair endpoint.")) + log.Fatal().Err(errUnknownParam) } } if one == "" || two == "" || multiplier == 0. { - log.Error().Err(errors.New("the query must include one()),two and multiplier")) + log.Error().Err(errIncompParams) } - var wg sync.WaitGroup - priceChan := make(chan priceChanStruct, 2) - errChan := make(chan errorChanStruct, 2) + var waitGroup sync.WaitGroup + + priceChan := make(chan priceChanStruct, 2) //nolint: gomnd + errChan := make(chan errorChanStruct, 2) //nolint: gomnd + defer close(priceChan) defer close(errChan) - wg.Add(2) - go getPrice(one, "USD", &wg, priceChan, errChan) - go getPrice(two, "USD", &wg, priceChan, errChan) - wg.Wait() + ctx, cancel := context.WithTimeout(r.Context(), getTimeout*time.Second) + defer cancel() + + waitGroup.Add(2) //nolint: gomnd + + go getPrice(ctx, one, "USD", &waitGroup, priceChan, errChan) + go getPrice(ctx, two, "USD", &waitGroup, priceChan, errChan) + + waitGroup.Wait() for i := 0; i < 2; i++ { select { case err := <-errChan: - if err.hasError != false { + if err.hasError { log.Error().Err(err.err) } default: - log.Fatal().Err(errors.New("this shouldnt have happened")) + log.Error().Err(errBadLogic) } } var priceOne float64 + var priceTwo float64 + for i := 0; i < 2; i++ { select { case price := <-priceChan: if price.name == one { priceOne = price.price } + if price.name == two { priceTwo = price.price } default: - log.Fatal().Err(errors.New("this shouldnt have happened")) + log.Error().Err(errBadLogic) } } ratio := priceOne * multiplier / priceTwo + log.Info().Msg(fmt.Sprintf("%v", ratio)) - json.NewEncoder(w).Encode(map[string]interface{}{"ratio": ratio}) + + err = json.NewEncoder(w).Encode(map[string]interface{}{"ratio": ratio}) + if err != nil { + log.Error().Err(err) + http.Error(w, "internal server error", http.StatusInternalServerError) + } } type alertType struct { @@ -322,98 +431,120 @@ type alertsType struct { Alerts []alertType `json:"alerts"` } -func getAlerts() (alertsType, error) { +func getAlerts() alertsType { var alerts alertsType + ctx := context.Background() keys := rdb.SMembersMap(ctx, "alertkeys") alerts.Alerts = make([]alertType, len(keys.Val())) vals := keys.Val() - i := 0 + alertIndex := 0 + for key := range vals { alert := rdb.Get(ctx, key[6:]) expr, _ := alert.Result() - alerts.Alerts[i].Name = key - alerts.Alerts[i].Expr = expr - i++ + alerts.Alerts[alertIndex].Name = key + alerts.Alerts[alertIndex].Expr = expr + alertIndex++ } - return alerts, nil + return alerts } -func alertManager() { - for { - alerts, err := getAlerts() - if err != nil { - log.Error().Err(err) - return - } - log.Info().Msg(fmt.Sprintf("%v", alerts)) +func alertManagerWorker(alert alertType) { + expression, err := govaluate.NewEvaluableExpression(alert.Expr) + if err != nil { + log.Error().Err(err) + } - for i := range alerts.Alerts { - expression, err := govaluate.NewEvaluableExpression(alerts.Alerts[i].Expr) - if err != nil { - log.Error().Err(err) - continue - } + vars := expression.Vars() + parameters := make(map[string]interface{}, len(vars)) - vars := expression.Vars() - parameters := make(map[string]interface{}, len(vars)) + var waitGroup sync.WaitGroup - var wg sync.WaitGroup - priceChan := make(chan priceChanStruct, len(vars)) - errChan := make(chan errorChanStruct, len(vars)) - defer close(priceChan) - defer close(errChan) - wg.Add(len(vars)) + priceChan := make(chan priceChanStruct, len(vars)) + defer close(priceChan) - for i := range vars { - // TODO-get from cache - go getPrice(vars[i], "USD", &wg, priceChan, errChan) - } - wg.Wait() - - for i := 0; i < len(vars); i++ { - select { - case err := <-errChan: - if err.hasError != false { - log.Printf(err.err.Error()) - } - default: - log.Error().Err(errors.New("this shouldnt have happened")) - } - } + errChan := make(chan errorChanStruct, len(vars)) + defer close(errChan) - for i := 0; i < len(vars); i++ { - select { - case price := <-priceChan: - parameters[price.name] = price.price - default: - log.Error().Err(errors.New("this shouldnt have happened")) - } - } + ctx, cancel := context.WithTimeout(context.Background(), getTimeout*time.Second) + defer cancel() - log.Info().Msg(fmt.Sprintf("parameters: %v", parameters)) - result, err := expression.Evaluate(parameters) - if err != nil { - log.Error().Err(err) - } + waitGroup.Add(len(vars)) + + for i := range vars { + go getPrice(ctx, vars[i], "USD", &waitGroup, priceChan, errChan) + } + + waitGroup.Wait() - var resultBool bool - log.Info().Msg(fmt.Sprintf("result: %v", result)) - resultBool = result.(bool) - if resultBool == true { - token := os.Getenv(TELEGRAM_BOT_TOKEN_ENV_VAR) - msgText := "notification " + alerts.Alerts[i].Expr + " has been triggered" - tokenInt, err := strconv.ParseInt(token[1:len(token)-1], 10, 64) - if err != nil { - log.Fatal().Err(err) - } - sendToTg("telebot:8000", msgText, tokenInt) + for i := 0; i < len(vars); i++ { + select { + case err := <-errChan: + if err.hasError { + log.Printf(err.err.Error()) } + default: + log.Error().Err(errBadLogic) + } + } + + for i := 0; i < len(vars); i++ { + select { + case price := <-priceChan: + parameters[price.name] = price.price + default: + log.Error().Err(errBadLogic) + } + } + + log.Info().Msg(fmt.Sprintf("parameters: %v", parameters)) + + result, err := expression.Evaluate(parameters) + if err != nil { + log.Error().Err(err) + } + + var resultBool bool + + log.Info().Msg(fmt.Sprintf("result: %v", result)) + + resultBool, ok := result.(bool) + if !ok { + log.Error().Err(errFailedTypeAssertion) + + return + } + + if !resultBool { + return + } + + token := os.Getenv(telegramBotTokenEnvVar) + msgText := "notification " + alert.Expr + " has been triggered" + + tokenInt, err := strconv.ParseInt(token[1:len(token)-1], 10, 64) + + if err == nil { + log.Error().Err(err) + } + + sendToTg("telebot:8000", msgText, tokenInt) +} + +func alertManager(alertsCheckInterval int64) { + for { + alerts := getAlerts() + + log.Info().Msg(fmt.Sprintf("%v", alerts)) + + for alertIndex := range alerts.Alerts { + go alertManagerWorker(alerts.Alerts[alertIndex]) } - time.Sleep(time.Second * time.Duration(*alertsCheckInterval)) + time.Sleep(time.Second * time.Duration(alertsCheckInterval)) } } @@ -422,91 +553,136 @@ type addAlertJSONType struct { Expr string `json:"expr"` } -func (this AlertHandler) HandleAlertPost(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") - bodyBytes, err := ioutil.ReadAll(r.Body) +func (alertHandler AlertHandler) HandleAlertPost(writer http.ResponseWriter, request *http.Request) { + var bodyJSON addAlertJSONType + + writer.Header().Add("Content-Type", "application/json") + + err := json.NewDecoder(request.Body).Decode(&bodyJSON) if err != nil { log.Printf(err.Error()) - } - var bodyJSON addAlertJSONType - json.Unmarshal(bodyBytes, &bodyJSON) + err := json.NewEncoder(writer).Encode(map[string]interface{}{ + "isSuccessful": false, + "error": "not all parameters are valid.", + }) + if err != nil { + log.Error().Err(err) + http.Error(writer, "internal server error", http.StatusInternalServerError) + } + } if bodyJSON.Name == "" || bodyJSON.Expr == "" { - json.NewEncoder(w).Encode(map[string]interface{}{ + err := json.NewEncoder(writer).Encode(map[string]interface{}{ "isSuccessful": false, - "error": "not all parameters are valid."}) - log.Fatal().Err(errors.New("not all parameters are valid.")) + "error": "not all parameters are valid.", + }) + if err != nil { + log.Error().Err(errFailedUnmarshall) + http.Error(writer, "internal server error", http.StatusInternalServerError) + } + return } - ctx := context.Background() + ctx, cancel := context.WithTimeout(request.Context(), redisContextTimeout*time.Second) + defer cancel() + key := "alert:" + bodyJSON.Name - this.rdb.Set(ctx, bodyJSON.Name, bodyJSON.Expr, 0) - this.rdb.SAdd(ctx, "alertkeys", key) - json.NewEncoder(w).Encode(map[string]interface{}{ + alertHandler.rdb.Set(ctx, bodyJSON.Name, bodyJSON.Expr, 0) + alertHandler.rdb.SAdd(ctx, "alertkeys", key) + + err = json.NewEncoder(writer).Encode(map[string]interface{}{ "isSuccessful": true, - "error": ""}) + "error": "", + }) + + if err != nil { + log.Error().Err(err) + http.Error(writer, "internal server error", http.StatusInternalServerError) + } } -func (this AlertHandler) HandleAlertDelete(w http.ResponseWriter, r *http.Request) { - var Id string - w.Header().Add("Content-Type", "application/json") - params := r.URL.Query() +func (alertHandler AlertHandler) HandleAlertDelete(writer http.ResponseWriter, request *http.Request) { + var identifier string + + writer.Header().Add("Content-Type", "application/json") + + params := request.URL.Query() + for key, value := range params { switch key { case "key": - Id = value[0] + identifier = value[0] default: - log.Error().Err(errors.New("bad parameters for the crypto endpoint.")) + log.Error().Err(errUnknownParam) } } - if Id == "" { - json.NewEncoder(w).Encode(map[string]interface{}{ + if identifier == "" { + err := json.NewEncoder(writer).Encode(map[string]interface{}{ "isSuccessful": false, - "error": "Id parameter is not valid."}) - log.Fatal().Err(errors.New("not all parameters are valid.")) + "error": "Id parameter is not valid.", + }) + if err != nil { + log.Error().Err(err) + http.Error(writer, "internal server error", http.StatusInternalServerError) + } + return } - ctx := context.Background() + ctx, cancel := context.WithTimeout(request.Context(), redisContextTimeout*time.Second) + defer cancel() - this.rdb.Del(ctx, Id) - setKey := "alert:" + Id - this.rdb.SRem(ctx, "alertkeys", setKey) + alertHandler.rdb.Del(ctx, identifier) + setKey := "alert:" + identifier + alertHandler.rdb.SRem(ctx, "alertkeys", setKey) log.Printf(setKey) - json.NewEncoder(w).Encode(struct { + err := json.NewEncoder(writer).Encode(struct { IsSuccessful bool `json:"isSuccessful"` Err string `json:"err"` }{IsSuccessful: true, Err: ""}) + if err != nil { + log.Error().Err(err) + http.Error(writer, "internal server error", http.StatusInternalServerError) + } } -func (this AlertHandler) HandleAlertGet(w http.ResponseWriter, r *http.Request) { - var Id string - w.Header().Add("Content-Type", "application/json") - params := r.URL.Query() +func (alertHandler AlertHandler) HandleAlertGet(writer http.ResponseWriter, request *http.Request) { + var identifier string + + writer.Header().Add("Content-Type", "application/json") + + params := request.URL.Query() for key, value := range params { switch key { case "key": - Id = value[0] + identifier = value[0] default: - log.Error().Err(errors.New("bad parameters for the crypto endpoint.")) + log.Error().Err(errUnknownParam) } } - if Id == "" { - json.NewEncoder(w).Encode(map[string]interface{}{ + if identifier == "" { + err := json.NewEncoder(writer).Encode(map[string]interface{}{ "isSuccessful": false, - "error": "Id parameter is not valid."}) - log.Fatal().Err(errors.New("not all parameters are valid.")) + "error": "Id parameter is not valid.", + }) + if err != nil { + log.Error().Err(err) + http.Error(writer, "internal server error", http.StatusInternalServerError) + } + return } - ctx := context.Background() + ctx, cancel := context.WithTimeout(request.Context(), redisContextTimeout*time.Second) + defer cancel() + + redisResult := alertHandler.rdb.Get(ctx, identifier) - redisResult := this.rdb.Get(ctx, Id) redisResultString, err := redisResult.Result() if err != nil { log.Err(err) @@ -519,111 +695,67 @@ func (this AlertHandler) HandleAlertGet(w http.ResponseWriter, r *http.Request) ErrorString = err.Error() } - w.Header().Add("Content-Type", "application/json") + writer.Header().Add("Content-Type", "application/json") - json.NewEncoder(w).Encode(struct { + err = json.NewEncoder(writer).Encode(struct { IsSuccessful bool `json:"isSuccessful"` Error string `json:"error"` Key string `json:"key"` Expr string `json:"expr"` - }{IsSuccessful: true, Error: ErrorString, Key: Id, Expr: redisResultString}) -} + }{IsSuccessful: true, Error: ErrorString, Key: identifier, Expr: redisResultString}) -func alertHandler(w http.ResponseWriter, r *http.Request) { - addSecureHeaders(&w) - alertHandler := AlertHandler{rdb: rdb} - if r.Method == "POST" || r.Method == "PUT" || r.Method == "PATCH" { - alertHandler.HandleAlertPost(w, r) - } else if r.Method == "DELETE" { - alertHandler.HandleAlertDelete(w, r) - } else if r.Method == "GET" { - alertHandler.HandleAlertGet(w, r) - } else { - http.Error(w, "Method is not supported.", http.StatusNotFound) + if err != nil { + log.Error().Err(err) + http.Error(writer, "internal server error", http.StatusInternalServerError) } } -func exHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") - addSecureHeaders(&w) - if r.Method != "GET" { - http.Error(w, "Method is not supported.", http.StatusNotFound) - } - - apiKey := os.Getenv(CHANGELLY_API_KEY_ENV_VAR) - apiSecret := os.Getenv(CHANGELLY_API_SECRET_ENV_VAR) +func alertHandler(writer http.ResponseWriter, request *http.Request) { + addSecureHeaders(&writer) - body := struct { - Jsonrpc string `json:"jsonrpc"` - Id string `json:"id"` - Method string `json:"method"` - Params []string `json:"params"` - }{ - Jsonrpc: "2.0", - Id: "test", - Method: "getCurrencies", - Params: nil} + alertHandler := AlertHandler{rdb: rdb} - bodyJSON, err := json.Marshal(body) - if err != nil { - log.Error().Err(err) + switch request.Method { + case http.MethodPost: + alertHandler.HandleAlertPost(writer, request) + case http.MethodPut: + alertHandler.HandleAlertPost(writer, request) + case http.MethodPatch: + alertHandler.HandleAlertPost(writer, request) + case http.MethodDelete: + alertHandler.HandleAlertDelete(writer, request) + case http.MethodGet: + alertHandler.HandleAlertGet(writer, request) + default: + http.Error(writer, "Method is not supported.", http.StatusNotFound) } +} - secretBytes := []byte(apiSecret[1 : len(apiSecret)-1]) - mac := hmac.New(sha512.New, secretBytes) - mac.Write(bodyJSON) - - client := &http.Client{} - req, err := http.NewRequest("POST", changellyURL, bytes.NewReader(bodyJSON)) - if err != nil { - log.Error().Err(err) - } +func healthHandler(writer http.ResponseWriter, request *http.Request) { + var RedisError string - macDigest := hex.EncodeToString(mac.Sum(nil)) - req.Header.Add("Content-Type", "application/json") - req.Header.Add("api-key", apiKey[1:len(apiKey)-1]) - req.Header.Add("sign", macDigest) + var HivedError string - resp, err := client.Do(req) - if err != nil { - log.Error().Err(err) - } - defer resp.Body.Close() + var IsRedisOk bool - responseBody, err := ioutil.ReadAll(resp.Body) - log.Printf(string(responseBody)) + IsHivedOk := true - responseUnmarshalled := struct { - Jsonrpc string `json:"jsonrpc"` - Id string `json:"id"` - Result []string `json:"result"` - }{} + addSecureHeaders(&writer) + writer.Header().Add("Content-Type", "application/json") - err = json.Unmarshal(responseBody, &responseUnmarshalled) - if err != nil { - log.Error().Err(err) + if request.Method != http.MethodGet { + http.Error(writer, "Method is not supported.", http.StatusNotFound) } - json.NewEncoder(w).Encode(responseUnmarshalled) -} - -func healthHandler(w http.ResponseWriter, r *http.Request) { - var RedisError string - var HivedError string - IsHivedOk := true - var IsRedisOk bool + ctx, cancel := context.WithTimeout(request.Context(), pingTimeout*time.Second) + defer cancel() - addSecureHeaders(&w) - w.Header().Add("Content-Type", "application/json") - if r.Method != "GET" { - http.Error(w, "Method is not supported.", http.StatusNotFound) - } + pingResponse := rdb.Ping(ctx) - pingCtx := context.Background() - pingResponse := rdb.Ping(pingCtx) pingResponseResult, err := pingResponse.Result() if err != nil { log.Err(err) + IsRedisOk = false RedisError = err.Error() } else { @@ -636,63 +768,69 @@ func healthHandler(w http.ResponseWriter, r *http.Request) { } } - w.WriteHeader(http.StatusOK) + writer.WriteHeader(http.StatusOK) + + err = json.NewEncoder(writer).Encode(map[string]interface{}{ + "isHivedOk": IsHivedOk, + "hivedError": HivedError, + "isRedisOk": IsRedisOk, + "redisError": RedisError, + }) - json.NewEncoder(w).Encode(struct { - IsHivedOk bool `json:"isHivedOk"` - HivedError string `json:"hivedError"` - IsRedisOk bool `json:"isRedisOk"` - RedisError string `json:"redisError"` - }{IsHivedOk: IsHivedOk, HivedError: HivedError, IsRedisOk: IsRedisOk, RedisError: RedisError}) + if request.Method != http.MethodGet { + http.Error(writer, "internal server error", http.StatusInternalServerError) + log.Error().Err(err) + } } -func robotsHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "text/plain") - addSecureHeaders(&w) - json.NewEncoder(w).Encode(struct { - UserAgents string `json:"User-Agents"` - Disallow string `json:"Disallow"` - }{"*", "/"}) +func robotsHandler(writer http.ResponseWriter, r *http.Request) { + writer.Header().Add("Content-Type", "text/plain") + addSecureHeaders(&writer) + + _, err := writer.Write([]byte("User-Agents: *\nDisallow: /\n")) + if err != nil { + log.Error().Err(err) + } + + http.Error(writer, "internal server error", http.StatusInternalServerError) } -func startServer(gracefulWait time.Duration) { - r := mux.NewRouter() +func startServer(gracefulWait time.Duration, flagPort string) { + router := mux.NewRouter() + cfg := &tls.Config{ MinVersion: tls.VersionTLS13, - CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, PreferServerCipherSuites: true, - CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_GCM_SHA384, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - }, } + srv := &http.Server{ - Addr: "0.0.0.0:" + *flagPort, - WriteTimeout: time.Second * 15, - ReadTimeout: time.Second * 15, - Handler: r, + Addr: "0.0.0.0:" + flagPort, + WriteTimeout: time.Second * serverTLSWriteTimeout, + ReadTimeout: time.Second * serverTLSReadTimeout, + Handler: router, TLSConfig: cfg, } - r.HandleFunc("/crypto/v1/health", healthHandler) - r.HandleFunc("/crypto/v1/price", PriceHandler) - r.HandleFunc("/crypto/v1/pair", PairHandler) - r.HandleFunc("/crypto/v1/alert", alertHandler) - r.HandleFunc("/crypto/v1/ex", exHandler) - r.HandleFunc("/crypto/v1/robots.txt", robotsHandler) + + router.HandleFunc("/crypto/v1/health", healthHandler) + router.HandleFunc("/crypto/v1/price", PriceHandler) + router.HandleFunc("/crypto/v1/pair", PairHandler) + router.HandleFunc("/crypto/v1/alert", alertHandler) + router.HandleFunc("/crypto/v1/robots.txt", robotsHandler) go func() { var certPath, keyPath string - if os.Getenv(SERVER_DEPLOYMENT_TYPE) == "deployment" { + + switch os.Getenv(serverDeploymentType) { + case "deployment": certPath = "/certs/fullchain1.pem" keyPath = "/certs/privkey1.pem" - } else if os.Getenv(SERVER_DEPLOYMENT_TYPE) == "test" { + case "test": certPath = "/certs/server.cert" keyPath = "/certs/server.key" - } else { - log.Fatal().Err(errors.New(fmt.Sprintf("unknown deployment kind: %s", SERVER_DEPLOYMENT_TYPE))) + default: + log.Fatal().Err(errUnknownDeploymentKind) } + if err := srv.ListenAndServeTLS(certPath, keyPath); err != nil { log.Fatal().Err(err) } @@ -702,10 +840,15 @@ func startServer(gracefulWait time.Duration) { signal.Notify(c, os.Interrupt) <-c + ctx, cancel := context.WithTimeout(context.Background(), gracefulWait) defer cancel() - srv.Shutdown(ctx) - log.Info().Msg("gracefully shut down the server") + + if err := srv.Shutdown(ctx); err != nil { + log.Error().Err(err) + } else { + log.Info().Msg("gracefully shut down the server") + } } func setupLogging() { @@ -714,7 +857,21 @@ func setupLogging() { func main() { var gracefulWait time.Duration - flag.DurationVar(&gracefulWait, "gracefulwait", time.Second*15, "the duration to wait during the graceful shutdown") + + flagPort := flag.String("port", "8008", "determined the port the sercice runs on") + redisAddress := flag.String("redisaddress", "redis:6379", "determines the address of the redis instance") + redisPassword := flag.String("redispassword", "", "determines the password of the redis db") + redisDB := flag.Int64("redisdb", 0, "determines the db number") + alertsCheckInterval := flag.Int64( + "alertinterval", + alertCheckIntervalDefault, + "in seconds, the amount of time between alert checks") + + flag.DurationVar( + &gracefulWait, "gracefulwait", + time.Second*defaultGracefulShutdown, + "the duration to wait during the graceful shutdown") + flag.Parse() rdb = redis.NewClient(&redis.Options{ @@ -725,6 +882,8 @@ func main() { defer rdb.Close() setupLogging() - go alertManager() - startServer(gracefulWait) + + go alertManager(*alertsCheckInterval) + + startServer(gracefulWait, *flagPort) } diff --git a/hived/hived_test.go b/hived/hived_test.go index 96bdd1e..bef135d 100644 --- a/hived/hived_test.go +++ b/hived/hived_test.go @@ -3,6 +3,7 @@ package main import ( "bytes" "encoding/json" + "flag" "fmt" "io/ioutil" "net/http" @@ -16,6 +17,12 @@ const ( endpoint = "https://api.terminaldweller.com/crypto/v1" ) +var ( + redisAddress = flag.String("redisaddress", "redis:6379", "determines the address of the redis instance") + redisPassword = flag.String("redispassword", "", "determines the password of the redis db") + redisDB = flag.Int64("redisdb", 0, "determines the db number") +) + func errorHandler(recorder *httptest.ResponseRecorder, t *testing.T, err error) { if err != nil { t.Errorf(err.Error()) @@ -27,7 +34,7 @@ func errorHandler(recorder *httptest.ResponseRecorder, t *testing.T, err error) } func TestPriceHandler(t *testing.T) { - req, err := http.NewRequest("GET", endpoint+"/price?name=BTC&unit=USD", nil) + req, err := http.NewRequest(http.MethodGet, endpoint+"/price?name=BTC&unit=USD", nil) recorder := httptest.NewRecorder() PriceHandler(recorder, req) errorHandler(recorder, t, err) @@ -49,7 +56,7 @@ func TestPriceHandler(t *testing.T) { } func TestPairHandler(t *testing.T) { - req, err := http.NewRequest("GET", endpoint+"/pair?one=ETH&two=CAKE&multiplier=4.0", nil) + req, err := http.NewRequest(http.MethodGet, endpoint+"/pair?one=ETH&two=CAKE&multiplier=4.0", nil) recorder := httptest.NewRecorder() PairHandler(recorder, req) errorHandler(recorder, t, err) @@ -78,7 +85,7 @@ func TestAlertHandlerPhase1(t *testing.T) { if err != nil { fmt.Println(err.Error()) } - req, err := http.NewRequest("POST", endpoint+"/alert", bytes.NewBuffer(postData)) + req, err := http.NewRequest(http.MethodPost, endpoint+"/alert", bytes.NewBuffer(postData)) req.Header.Set("Content-Type", "application/json") recorder := httptest.NewRecorder() alertHandler := AlertHandler{rdb: rdb} @@ -108,7 +115,7 @@ func TestAlertHandlerPhase2(t *testing.T) { }) defer rdb.Close() - req, err := http.NewRequest("GET", endpoint+"/alert?key=alertTest", nil) + req, err := http.NewRequest(http.MethodGet, endpoint+"/alert?key=alertTest", nil) recorder := httptest.NewRecorder() alertHandler := AlertHandler{rdb: rdb} alertHandler.HandleAlertGet(recorder, req) @@ -142,7 +149,7 @@ func TestAlertHandlerPhase3(t *testing.T) { if err != nil { fmt.Println(err.Error()) } - req, err := http.NewRequest("PUT", endpoint+"/alert", bytes.NewBuffer(postData)) + req, err := http.NewRequest(http.MethodPut, endpoint+"/alert", bytes.NewBuffer(postData)) req.Header.Set("Content-Type", "application/json") recorder := httptest.NewRecorder() alertHandler := AlertHandler{rdb: rdb} @@ -158,7 +165,7 @@ func TestAlertHandlerPhase4(t *testing.T) { }) defer rdb.Close() - req, err := http.NewRequest("GET", endpoint+"/alert?key=alertTest", nil) + req, err := http.NewRequest(http.MethodGet, endpoint+"/alert?key=alertTest", nil) recorder := httptest.NewRecorder() alertHandler := AlertHandler{rdb: rdb} alertHandler.HandleAlertGet(recorder, req) @@ -187,7 +194,7 @@ func TestAlertHandlerPhase5(t *testing.T) { }) defer rdb.Close() - req, err := http.NewRequest("DELETE", endpoint+"/alert?key=alertTest", nil) + req, err := http.NewRequest(http.MethodDelete, endpoint+"/alert?key=alertTest", nil) recorder := httptest.NewRecorder() alertHandler := AlertHandler{rdb: rdb} alertHandler.HandleAlertGet(recorder, req) @@ -216,7 +223,7 @@ func TestAlertHandlerPhase6(t *testing.T) { }) defer rdb.Close() - req, err := http.NewRequest("GET", endpoint+"/alert?key=alertTest", nil) + req, err := http.NewRequest(http.MethodGet, endpoint+"/alert?key=alertTest", nil) recorder := httptest.NewRecorder() alertHandler := AlertHandler{rdb: rdb} alertHandler.HandleAlertGet(recorder, req) |