diff options
author | bloodstalker <thabogre@gmail.com> | 2018-09-03 16:12:29 +0000 |
---|---|---|
committer | bloodstalker <thabogre@gmail.com> | 2018-09-03 16:12:29 +0000 |
commit | f9356b9a938a8198d0c94ac4f5cabd829863176c (patch) | |
tree | eb9085ef5947be662a4e9dda0884b945b98be92b /cnn.py | |
parent | update (diff) | |
download | seer-f9356b9a938a8198d0c94ac4f5cabd829863176c.tar.gz seer-f9356b9a938a8198d0c94ac4f5cabd829863176c.zip |
update
Diffstat (limited to 'cnn.py')
-rwxr-xr-x | cnn.py | 10 |
1 files changed, 5 insertions, 5 deletions
@@ -217,7 +217,7 @@ def lstm_type_cnn_1(symbol_str, kind): model.compile(loss='mse', optimizer='adam') model.fit(training_datas, training_labels, batch_size=batch_size,validation_data=(validation_datas,validation_labels), epochs = epochs, callbacks=[CSVLogger(output_file_name+'.csv', append=True),ModelCheckpoint(output_file_name+'-{epoch:02d}-{val_loss:.5f}.hdf5', monitor='val_loss', verbose=1,mode='min')]) -def load_cnn_type_1(symbol_str): +def load_cnn_type_1(symbol_str, vis_year, vis_month): df, original_df, time_stamps = getData(symbol_str) Scaler(df, original_df, time_stamps, symbol_str) """ @@ -260,7 +260,7 @@ def load_cnn_type_1(symbol_str): # model.add(LeakyReLU()) model.add(Dropout(0.25)) model.add(Conv1D( strides=4, filters=nb_features, kernel_size=16)) - model.load_weights("cnn/" + symbol_str + "_CNN_2_relu-76-0.00056.hdf5") + model.load_weights("cnn/" + symbol_str + "_CNN_2_relu-76-0.00036.hdf5") model.compile(loss='mse', optimizer='adam') predicted = model.predict(validation_datas) @@ -287,8 +287,8 @@ def load_cnn_type_1(symbol_str): prediction_df['times'] = validation_output_times prediction_df['value'] = predicted_inverted - prediction_df = prediction_df.loc[(prediction_df["times"].dt.year == 2017 )&(prediction_df["times"].dt.month > 7 ),: ] - ground_true_df = ground_true_df.loc[(ground_true_df["times"].dt.year == 2017 )&(ground_true_df["times"].dt.month > 7 ),:] + prediction_df = prediction_df.loc[(prediction_df["times"].dt.year == vis_year )&(prediction_df["times"].dt.month > vis_month ),: ] + ground_true_df = ground_true_df.loc[(ground_true_df["times"].dt.year == vis_year )&(ground_true_df["times"].dt.month > vis_month ),:] plt.figure(figsize=(20,10)) plt.plot(ground_true_df.times,ground_true_df.value, label = 'Actual') @@ -302,7 +302,7 @@ def premain(argparser): #here #cnn_type_1("ETH") #lstm_type_cnn_1("ETH", "GRU") - load_cnn_type_1("ETH") + load_cnn_type_1("ETH", 2018, 4) def main(): argparser = Argparser() |