From d8ef0b60f175125d84f5c82e2ecbcaad4243035d Mon Sep 17 00:00:00 2001 From: Manthan Thakker Date: Sat, 17 May 2025 00:06:49 -0700 Subject: [PATCH] Fix evaluation functions to reshape test data --- PredictStockPricesRNN.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/PredictStockPricesRNN.py b/PredictStockPricesRNN.py index 215c050..1d23ec5 100644 --- a/PredictStockPricesRNN.py +++ b/PredictStockPricesRNN.py @@ -55,16 +55,24 @@ def train_rnn(X_train, X_test): # function to evaluate RNN def evaluate_rnn(model, X_test): + """Evaluate a trained RNN model using mean squared error.""" + # reshape test data to match training shape + X_test = X_test.reshape((X_test.shape[0], 1, 1)) + # make predictions with RNN predictions = model.predict(X_test) # calculate mean squared error - mse = np.mean((predictions - X_test)**2) + mse = np.mean((predictions - X_test) ** 2) return mse # function to make predictions with RNN def predict_with_rnn(model, X_test): + """Generate predictions from a trained RNN model.""" + # reshape test data to match training shape + X_test = X_test.reshape((X_test.shape[0], 1, 1)) + # make predictions with RNN predictions = model.predict(X_test)