aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rwxr-xr-xcnn.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/cnn.py b/cnn.py
index 7cb1493..23e450f 100755
--- a/cnn.py
+++ b/cnn.py
@@ -217,7 +217,7 @@ def lstm_type_cnn_1(symbol_str, kind):
model.compile(loss='mse', optimizer='adam')
model.fit(training_datas, training_labels, batch_size=batch_size,validation_data=(validation_datas,validation_labels), epochs = epochs, callbacks=[CSVLogger(output_file_name+'.csv', append=True),ModelCheckpoint(output_file_name+'-{epoch:02d}-{val_loss:.5f}.hdf5', monitor='val_loss', verbose=1,mode='min')])
-def load_cnn_type_1(symbol_str):
+def load_cnn_type_1(symbol_str, vis_year, vis_month):
df, original_df, time_stamps = getData(symbol_str)
Scaler(df, original_df, time_stamps, symbol_str)
"""
@@ -260,7 +260,7 @@ def load_cnn_type_1(symbol_str):
# model.add(LeakyReLU())
model.add(Dropout(0.25))
model.add(Conv1D( strides=4, filters=nb_features, kernel_size=16))
- model.load_weights("cnn/" + symbol_str + "_CNN_2_relu-76-0.00056.hdf5")
+ model.load_weights("cnn/" + symbol_str + "_CNN_2_relu-76-0.00036.hdf5")
model.compile(loss='mse', optimizer='adam')
predicted = model.predict(validation_datas)
@@ -287,8 +287,8 @@ def load_cnn_type_1(symbol_str):
prediction_df['times'] = validation_output_times
prediction_df['value'] = predicted_inverted
- prediction_df = prediction_df.loc[(prediction_df["times"].dt.year == 2017 )&(prediction_df["times"].dt.month > 7 ),: ]
- ground_true_df = ground_true_df.loc[(ground_true_df["times"].dt.year == 2017 )&(ground_true_df["times"].dt.month > 7 ),:]
+ prediction_df = prediction_df.loc[(prediction_df["times"].dt.year == vis_year )&(prediction_df["times"].dt.month > vis_month ),: ]
+ ground_true_df = ground_true_df.loc[(ground_true_df["times"].dt.year == vis_year )&(ground_true_df["times"].dt.month > vis_month ),:]
plt.figure(figsize=(20,10))
plt.plot(ground_true_df.times,ground_true_df.value, label = 'Actual')
@@ -302,7 +302,7 @@ def premain(argparser):
#here
#cnn_type_1("ETH")
#lstm_type_cnn_1("ETH", "GRU")
- load_cnn_type_1("ETH")
+ load_cnn_type_1("ETH", 2018, 4)
def main():
argparser = Argparser()