aboutsummaryrefslogtreecommitdiffstats
path: root/tfann.py
blob: aef226ac8528fe8b7b7cc294bc30cc8cd44d551c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#!/usr/bin/python3
# _*_ coding=utf-8 _*_
# Original source: https://nicholastsmith.wordpress.com/2017/11/13/cryptocurrency-price-prediction-using-deep-learning-in-tensorflow/

import argparse
import code
import readline
import signal
import sys
from TFANN import ANNR
import numpy as np
import os
import pandas as pd
import urllib.request
import matplotlib.pyplot as mpl

def SigHandler_SIGINT(signum, frame):
    """Ctrl-C handler: print a newline so the shell prompt starts cleanly, then exit with status 0.

    ``signum`` and ``frame`` are required by the ``signal.signal`` handler
    signature but are not used.
    """
    del signum, frame
    print()
    raise SystemExit(0)

class Argparser(object):
    """Parse this script's command-line flags; the result is stored in ``self.args``."""

    def __init__(self):
        # (flag, keyword arguments for add_argument), registered in this order.
        options = (
            ("--string", dict(type=str, help="string")),
            ("--bool", dict(action="store_true", help="bool", default=False)),
            ("--dbg", dict(action="store_true", help="debug", default=False)),
        )
        cli = argparse.ArgumentParser()
        for flag, kwargs in options:
            cli.add_argument(flag, **kwargs)
        self.args = cli.parse_args()

def GetAPIUrl(cur, sts = 1420070400):
    """Return the Poloniex ``returnChartData`` URL for the ``USDT_<cur>`` pair.

    ``cur`` must be a str (``:s`` spec) and ``sts``, the ``start`` query value,
    an int (``:d`` spec). The end bound and the 7200-second period are fixed.
    """
    template = ('https://poloniex.com/public?command=returnChartData'
                '&currencyPair=USDT_{pair:s}&start={start:d}'
                '&end=9999999999&period=7200')
    return template.format(pair=cur, start=sts)

def GetCurDF(cur, fp):
    """Download chart data for ``USDT_<cur>`` and cache it as CSV at ``fp``.

    Parameters
    ----------
    cur : str
        Currency ticker (e.g. ``'BTC'``) substituted into the API URL.
    fp : str
        Path of the CSV cache file. The downloaded frame is written there
        (without the index) so a later ``pd.read_csv(fp)`` yields the same
        columns instead of refetching from the network.

    Returns
    -------
    pandas.DataFrame
        The parsed chart data, with ``date`` converted to integer seconds.
    """
    import io
    # The context manager closes the connection even if read() raises.
    with urllib.request.urlopen(GetAPIUrl(cur)) as resp:
        raw = resp.read()
    # Wrap in StringIO: passing literal JSON text to read_json is deprecated (pandas >= 2.1).
    df = pd.read_json(io.StringIO(raw.decode()))
    # read_json's default date handling parses the 'date' column to datetime64[ns];
    # integer-divide the nanosecond values back to seconds.
    df['date'] = df['date'].astype(np.int64) // 1000000000
    # Write the cache only after parsing succeeded, so a bad response is not cached.
    df.to_csv(fp, sep = ',', index = False)
    print(df.head())
    return df

class PastSampler:
    """Cut a time series into overlapping (past window, following window) pairs.

    Each pair takes ``N`` consecutive rows as the sample and the ``K`` rows
    immediately after them as the target.
    """

    def __init__(self, N, K):
        self.N = N
        self.K = K

    def transform(self, A, Y = None):
        """Return ``(samples, targets)`` built from every window of ``N + K`` consecutive rows of ``A``.

        ``A`` is indexed along axis 0 (time) and must have at least 2 dimensions.
        ``Y`` is accepted but ignored. For ``A`` of shape ``(T, c, ...)``, the
        samples have shape ``(T-N-K+1, N*c, ...)`` and the targets
        ``(T-N-K+1, K*c, ...)``.
        """
        win = self.N + self.K
        n_windows = A.shape[0] - win + 1
        # Row j holds the time indices j, j+1, ..., j+win-1.
        rows = np.arange(n_windows)[:, None] + np.arange(win)[None, :]
        trailing = A.shape[2:]
        stacked = A[rows].reshape(-1, win * A.shape[1], *trailing)
        split_at = self.N * A.shape[1]  # width of the sample part of each window
        samples = stacked[:, :split_at]
        targets = stacked[:, split_at:]
        return samples, targets

def _load_price_frames(datPath, cl):
    """Return one DataFrame per ticker in ``cl``, reading ``<datPath>/<ticker>.csv`` when present."""
    D = []
    for ci in cl:
        dfp = os.path.join(datPath, ci + '.csv')
        try:
            df = pd.read_csv(dfp, sep = ',')
        except FileNotFoundError:
            # Cache miss: fetch from the API; GetCurDF is handed the cache path.
            df = GetCurDF(ci, dfp)
        D.append(df)
    return D


def _plot_layer_activations(cnnr, B, nt):
    """Show per-layer activations from ``cnnr.PredictFull`` for the first ``nt`` samples of ``B``."""
    PF = cnnr.PredictFull(B[:nt])
    # (index into PF, panel title). Presumably PF holds the input followed by
    # the output of each entry of the layer spec, so 2 and 4 are the post-ReLU
    # layers and 5 the final conv -- TODO confirm against TFANN.
    panels = [(0, 'Input'), (2, 'Layer 1'), (4, 'Layer 2'), (5, 'Output')]
    for i in range(nt):
        fig, ax = mpl.subplots(1, len(panels), figsize = (16 / 1.25, 10 / 1.25))
        for axi, (li, title) in zip(ax, panels):
            axi.plot(PF[li][i])
            axi.set_title(title)
        fig.text(0.5, 0.06, 'Time', ha='center')
        fig.text(0.06, 0.5, 'Activation', va='center', rotation='vertical')
        mpl.show()


def _plot_predictions(cl, CN, C, A, PTS, HP):
    """Plot actual vs. predicted 'high' values for each ticker in ``cl`` (data already rescaled)."""
    CI = list(range(C.shape[0]))
    AI = list(range(C.shape[0] + PTS.shape[0] - HP))
    NDP = PTS.shape[0]  # Number of predicted time steps (each API period is 7200 s)
    for i, cli in enumerate(cl):
        fig, ax = mpl.subplots(figsize = (16 / 1.5, 10 / 1.5))
        # Columns are grouped per ticker, len(CN) consecutive columns each.
        hind = i * len(CN) + CN.index('high')
        ax.plot(CI[-4 * HP:], C[-4 * HP:, hind], label = 'Actual')
        # One extra point so the dashed forecast joins the last observed value.
        ax.plot(AI[-(NDP + 1):], A[-(NDP + 1):, hind], '--', label = 'Prediction')
        ax.legend(loc = 'upper left')
        ax.set_title(cli + ' (High)')
        ax.set_ylabel('USD')
        ax.set_xlabel('Time')
        ax.axes.xaxis.set_ticklabels([])
        mpl.show()


def tfann_type_1():
    """Train a 1-D CNN on crypto price data, forecast a holdout span and plot the results.

    The steps are:

    1. Load one frame per ticker, from ``CurDat/`` or from the API.
    2. Trim all frames to a common length.
    3. Scale by the mean of the training span.
    4. Fit an ``ANNR`` network that maps ``NPS`` past rows to ``NFS`` future rows.
    5. Predict the final ``HP`` rows.
    6. Show activation plots and forecast-vs-actual plots.
    """
    #%%Path to store cached currency data
    datPath = 'CurDat/'
    os.makedirs(datPath, exist_ok = True)
    #Different cryptocurrency types
    cl = ['BTC', 'LTC', 'ETH', 'XMR']
    #Columns of price data to use
    CN = ['close', 'high', 'low', 'open', 'volume']
    D = _load_price_frames(datPath, cl)
    #%%Only keep range of data that is common to all currency types (the most recent cr rows)
    cr = min(Di.shape[0] for Di in D)
    D = [Di[(Di.shape[0] - cr):] for Di in D]
    #%%Features are channels: C has shape (time, 1, len(cl) * len(CN))
    # np.hstack needs a sequence (a generator raises TypeError on NumPy >= 1.24);
    # float dtype keeps the in-place division below valid.
    C = np.hstack([Di[CN].to_numpy(dtype = float) for Di in D])[:, None, :]
    HP = 16                 #Holdout period
    A = C[0:-HP]            #Training span -- a view of C
    SV = A.mean(axis = 0)   #Scale vector
    C /= SV                 #In-place scaling of C, which also scales the view A
    #%%Make samples of temporal sequences of pricing data (channel)
    NPS, NFS = 256, 16      #Number of past and future samples
    ps = PastSampler(NPS, NFS)
    B, Y = ps.transform(A)
    #%%Architecture of the neural network
    NC = B.shape[2]
    #2 1-D conv layers with relu followed by 1-d conv output layer
    ns = [('C1d', [8, NC, NC * 2], 4), ('AF', 'relu'),
          ('C1d', [8, NC * 2, NC * 2], 2), ('AF', 'relu'),
          ('C1d', [8, NC * 2, NC], 2)]
    #Create the neural network in TensorFlow
    cnnr = ANNR(B[0].shape, ns, batchSize = 32, learnRate = 2e-5,
                maxIter = 64, reg = 1e-5, tol = 1e-2, verbose = True)
    cnnr.fit(B, Y)
    PTS = []                        #Predicted time sequences
    # Seed with the last window and its target; after the first concatenate,
    # P holds exactly the final NPS rows of A.
    P, YH = B[[-1]], Y[[-1]]
    for _ in range(HP // NFS):      #Repeat prediction, feeding outputs back in
        P = np.concatenate([P[:, NFS:], YH], axis = 1)
        YH = cnnr.predict(P)
        PTS.append(YH)
    PTS = np.hstack(PTS).transpose((1, 0, 2))
    A = np.vstack([A, PTS]) #Combine predictions with original data
    A = np.squeeze(A) * SV  #Remove unit channel dimension and rescale
    C = np.squeeze(C) * SV
    _plot_layer_activations(cnnr, B, 4)
    _plot_predictions(cl, CN, C, A, PTS, HP)

# write code here
def premain(argparser):
    """Install the Ctrl-C handler, then run the training/forecast pipeline.

    ``argparser`` is passed by ``main`` but is not read here.
    """
    del argparser
    signal.signal(signal.SIGINT, SigHandler_SIGINT)
    # The actual work.
    tfann_type_1()

def main():
    """Parse arguments and run ``premain``; with ``--dbg``, open a REPL on failure.

    In debug mode, an ``Exception`` raised by ``premain`` is reported (its class
    docstring plus the full traceback). An interactive console then starts,
    with this module's globals plus ``main``'s locals (including
    ``argparser`` and the caught exception ``e``) in its namespace.
    """
    import traceback
    argparser = Argparser()
    if argparser.args.dbg:
        try:
            premain(argparser)
        except Exception as e:
            print(e.__doc__)
            # Python 3 exceptions have no `.message`; the old `if e.message:` raised
            # AttributeError here, so the debug REPL was never reached.
            traceback.print_exc()
            variables = globals().copy()
            variables.update(locals())
            shell = code.InteractiveConsole(variables)
            shell.interact(banner="DEBUG REPL")
    else:
        premain(argparser)

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()