From f9356b9a938a8198d0c94ac4f5cabd829863176c Mon Sep 17 00:00:00 2001 From: bloodstalker Date: Mon, 3 Sep 2018 20:42:29 +0430 Subject: update --- cnn.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cnn.py b/cnn.py index 7cb1493..23e450f 100755 --- a/cnn.py +++ b/cnn.py @@ -217,7 +217,7 @@ def lstm_type_cnn_1(symbol_str, kind): model.compile(loss='mse', optimizer='adam') model.fit(training_datas, training_labels, batch_size=batch_size,validation_data=(validation_datas,validation_labels), epochs = epochs, callbacks=[CSVLogger(output_file_name+'.csv', append=True),ModelCheckpoint(output_file_name+'-{epoch:02d}-{val_loss:.5f}.hdf5', monitor='val_loss', verbose=1,mode='min')]) -def load_cnn_type_1(symbol_str): +def load_cnn_type_1(symbol_str, vis_year, vis_month): df, original_df, time_stamps = getData(symbol_str) Scaler(df, original_df, time_stamps, symbol_str) """ @@ -260,7 +260,7 @@ def load_cnn_type_1(symbol_str): # model.add(LeakyReLU()) model.add(Dropout(0.25)) model.add(Conv1D( strides=4, filters=nb_features, kernel_size=16)) - model.load_weights("cnn/" + symbol_str + "_CNN_2_relu-76-0.00056.hdf5") + model.load_weights("cnn/" + symbol_str + "_CNN_2_relu-76-0.00036.hdf5") model.compile(loss='mse', optimizer='adam') predicted = model.predict(validation_datas) @@ -287,8 +287,8 @@ def load_cnn_type_1(symbol_str): prediction_df['times'] = validation_output_times prediction_df['value'] = predicted_inverted - prediction_df = prediction_df.loc[(prediction_df["times"].dt.year == 2017 )&(prediction_df["times"].dt.month > 7 ),: ] - ground_true_df = ground_true_df.loc[(ground_true_df["times"].dt.year == 2017 )&(ground_true_df["times"].dt.month > 7 ),:] + prediction_df = prediction_df.loc[(prediction_df["times"].dt.year == vis_year )&(prediction_df["times"].dt.month > vis_month ),: ] + ground_true_df = ground_true_df.loc[(ground_true_df["times"].dt.year == vis_year )&(ground_true_df["times"].dt.month > vis_month ),:] plt.figure(figsize=(20,10)) plt.plot(ground_true_df.times,ground_true_df.value, label = 'Actual') @@ -302,7 +302,7 @@ def premain(argparser): #here #cnn_type_1("ETH") #lstm_type_cnn_1("ETH", "GRU") - load_cnn_type_1("ETH") + load_cnn_type_1("ETH", 2018, 4) def main(): argparser = Argparser() -- cgit v1.2.3