{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":" New york taxi GRU_LSTM monthly_preprocessing .ipynb","provenance":[{"file_id":"15dcbg9FseoYPBRIO1F0dtjnP3P1Puniq","timestamp":1631964684465},{"file_id":"1q4otCUdBI8Y0yRo9Tp59GS-8NJpdBpjN","timestamp":1631960919368},{"file_id":"1sW5Y9qlaJuMCfPHwrWBG9qcBqQXuxi-4","timestamp":1631562729044}],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","metadata":{"id":"ANisF3yfliXJ"},"source":["import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","from sklearn.model_selection import train_test_split\n","from sklearn.metrics import mean_squared_error,r2_score\n","from sklearn import linear_model\n","\n","import torch\n","import torch.nn as nn\n","from torch.autograd import Variable\n","import folium\n","from folium import FeatureGroup, LayerControl, Map, Marker\n","from folium.plugins import HeatMap\n","from folium.plugins import TimestampedGeoJson\n","from folium.plugins import MarkerCluster"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ViBRhX2slsYR","executionInfo":{"status":"ok","timestamp":1641730076996,"user_tz":-60,"elapsed":18222,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}},"outputId":"9e0bc7e6-4a39-44c1-d413-36bfd92b27df"},"source":["from google.colab import drive\n","drive.mount(\"/content/gdrive\")"],"execution_count":4,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/gdrive\n"]}]},{"cell_type":"code","metadata":{"id":"1R8YVA4MuWL6","executionInfo":{"status":"ok","timestamp":1641730101346,"user_tz":-60,"elapsed":24033,"user":{"displayName":"Faheem Ahmed 
Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}}},"source":[
"# Data rows in the July 2016 file carry two more (empty) fields than the header has names.\n",
"# With the default index_col, pandas promotes the surplus leading fields to the index and\n",
"# shifts every column name onto the wrong data (e.g. 'trip_distance' ends up holding the\n",
"# 'N'/'Y' store_and_fwd_flag, so the later 'demand' column becomes strings).\n",
"# index_col=False keeps the columns aligned with the header.\n",
"df_july = pd.read_csv('/content/gdrive/My Drive/Taxi Demand Prediction/Datasets/yellow_tripdata_2016-07.csv', index_col=False)"
],"execution_count":5,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":444},"id":"pw9zxJTBuwTp","executionInfo":{"status":"ok","timestamp":1641730105177,"user_tz":-60,"elapsed":433,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}},"outputId":"e1722d54-536b-4827-ffae-009269039791"},"source":["df_july.head()"],"execution_count":6,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n","
\n","
\n","
\n","\n","
\n"," \n"," \n"," | \n"," | \n"," VendorID | \n"," tpep_pickup_datetime | \n"," tpep_dropoff_datetime | \n"," passenger_count | \n"," trip_distance | \n"," RatecodeID | \n"," store_and_fwd_flag | \n"," PULocationID | \n"," DOLocationID | \n"," payment_type | \n"," fare_amount | \n"," extra | \n"," mta_tax | \n"," tip_amount | \n"," tolls_amount | \n"," improvement_surcharge | \n"," total_amount | \n","
\n"," \n"," \n"," \n"," 1 | \n"," 2016-07-10 06:56:05 | \n"," 2016-07-10 06:59:53 | \n"," 1 | \n"," 0.50 | \n"," 1 | \n"," N | \n"," 263 | \n"," 236 | \n"," 1 | \n"," 4.5 | \n"," 1.0 | \n"," 0.5 | \n"," 2.70 | \n"," 0.0 | \n"," 0.3 | \n"," 9.00 | \n"," NaN | \n"," NaN | \n","
\n"," \n"," 2 | \n"," 2016-07-10 10:50:18 | \n"," 2016-07-10 10:55:21 | \n"," 5 | \n"," 1.34 | \n"," 1 | \n"," N | \n"," 142 | \n"," 163 | \n"," 1 | \n"," 6.0 | \n"," 0.0 | \n"," 0.5 | \n"," 1.36 | \n"," 0.0 | \n"," 0.3 | \n"," 8.16 | \n"," NaN | \n"," NaN | \n","
\n"," \n"," 2016-07-10 10:50:18 | \n"," 2016-07-10 11:08:38 | \n"," 1 | \n"," 9.48 | \n"," 1 | \n"," N | \n"," 74 | \n"," 66 | \n"," 1 | \n"," 27.0 | \n"," 0.0 | \n"," 0.5 | \n"," 0.00 | \n"," 0.0 | \n"," 0.3 | \n"," 27.80 | \n"," NaN | \n"," NaN | \n","
\n"," \n"," 1 | \n"," 2016-07-10 10:50:19 | \n"," 2016-07-10 10:55:14 | \n"," 1 | \n"," 1.00 | \n"," 1 | \n"," N | \n"," 264 | \n"," 264 | \n"," 2 | \n"," 5.5 | \n"," 0.0 | \n"," 0.5 | \n"," 0.00 | \n"," 0.0 | \n"," 0.3 | \n"," 6.30 | \n"," NaN | \n"," NaN | \n","
\n"," \n"," 2016-07-10 10:50:19 | \n"," 2016-07-10 10:55:47 | \n"," 1 | \n"," 0.90 | \n"," 1 | \n"," N | \n"," 48 | \n"," 68 | \n"," 2 | \n"," 5.5 | \n"," 0.0 | \n"," 0.5 | \n"," 0.00 | \n"," 0.0 | \n"," 0.3 | \n"," 6.30 | \n"," NaN | \n"," NaN | \n","
\n"," \n","
\n","
\n","
\n"," \n"," \n","\n"," \n","
\n","
\n"," "],"text/plain":[" VendorID ... total_amount\n","1 2016-07-10 06:56:05 2016-07-10 06:59:53 ... NaN\n","2 2016-07-10 10:50:18 2016-07-10 10:55:21 ... NaN\n"," 2016-07-10 10:50:18 2016-07-10 11:08:38 ... NaN\n","1 2016-07-10 10:50:19 2016-07-10 10:55:14 ... NaN\n"," 2016-07-10 10:50:19 2016-07-10 10:55:47 ... NaN\n","\n","[5 rows x 17 columns]"]},"metadata":{},"execution_count":6}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"SiDk1qvhdmNL","executionInfo":{"status":"ok","timestamp":1641730109263,"user_tz":-60,"elapsed":422,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}},"outputId":"4a4b84f3-be7c-47c1-8e73-55ce7983716d"},"source":["df_july.shape[0]"],"execution_count":7,"outputs":[{"output_type":"execute_result","data":{"text/plain":["10294080"]},"metadata":{},"execution_count":7}]},{"cell_type":"code","metadata":{"id":"XFEZZxDzmBvc","executionInfo":{"status":"ok","timestamp":1641730113564,"user_tz":-60,"elapsed":421,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}}},"source":["df_july = df_july.drop(['RatecodeID','payment_type','fare_amount','extra','mta_tax','tip_amount','tolls_amount','improvement_surcharge','total_amount'],1)"],"execution_count":8,"outputs":[]},{"cell_type":"code","metadata":{"id":"dPtcgGGKmZkD","executionInfo":{"status":"ok","timestamp":1641730117893,"user_tz":-60,"elapsed":1371,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}}},"source":["df_july.dropna(how = 'any', inplace = True)"],"execution_count":9,"outputs":[]},{"cell_type":"code","metadata":{"id":"8tAjCysJmiMv"},"source":["# taxi_carry = 
np.array(np.where(df_jan['passenger_count']>4))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"pXfZQENCmoWt","executionInfo":{"status":"ok","timestamp":1641730119835,"user_tz":-60,"elapsed":341,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}}},"source":["df_july['pickup_datetime'] = pd.to_datetime(df_july['tpep_pickup_datetime'])"],"execution_count":10,"outputs":[]},
{"cell_type":"code","metadata":{"id":"djmnkLqTvC1t","executionInfo":{"status":"ok","timestamp":1641730121569,"user_tz":-60,"elapsed":363,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}}},"source":[
"# Demand proxy per trip: passengers x trip distance (a product, not a ratio).\n",
"# Fail fast on non-numeric inputs: if the CSV columns are misaligned on read, trip_distance can\n",
"# hold the 'N'/'Y' flag, and int * str would silently produce strings instead of numbers.\n",
"for _col in ('passenger_count', 'trip_distance'):\n",
"    assert pd.api.types.is_numeric_dtype(df_july[_col]), f'{_col} is not numeric (dtype {df_july[_col].dtype})'\n",
"df_july['demand'] = df_july['passenger_count'] * df_july['trip_distance']"
],"execution_count":11,"outputs":[]},
{"cell_type":"code","metadata":{"id":"aY_12dWLz4-Q","executionInfo":{"status":"ok","timestamp":1641730403763,"user_tz":-60,"elapsed":280220,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}}},"source":["df_july.to_csv('/content/gdrive/My Drive/Taxi Demand Prediction/Datasets/yellow_tripdata_2016-07_processed.csv')"],"execution_count":12,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":337},"id":"oeK5ASrbvPzC","executionInfo":{"status":"ok","timestamp":1641730427738,"user_tz":-60,"elapsed":428,"user":{"displayName":"Faheem Ahmed 
Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}},"outputId":"ddc37257-c753-4bb8-ac11-2bbd9eb3e756"},"source":["df_july.head()"],"execution_count":13,"outputs":[{"output_type":"execute_result","data":{"text/html":["\n"," \n","
\n","
\n","\n","
\n"," \n"," \n"," | \n"," | \n"," VendorID | \n"," tpep_pickup_datetime | \n"," tpep_dropoff_datetime | \n"," passenger_count | \n"," trip_distance | \n"," store_and_fwd_flag | \n"," PULocationID | \n"," DOLocationID | \n"," pickup_datetime | \n"," demand | \n","
\n"," \n"," \n"," \n"," 1 | \n"," 2016-07-10 06:56:05 | \n"," 2016-07-10 06:59:53 | \n"," 1 | \n"," 0.50 | \n"," 1 | \n"," N | \n"," 236 | \n"," 1 | \n"," 4.5 | \n"," 1970-01-01 00:00:00.000000001 | \n"," N | \n","
\n"," \n"," 2 | \n"," 2016-07-10 10:50:18 | \n"," 2016-07-10 10:55:21 | \n"," 5 | \n"," 1.34 | \n"," 1 | \n"," N | \n"," 163 | \n"," 1 | \n"," 6.0 | \n"," 1970-01-01 00:00:00.000000005 | \n"," N | \n","
\n"," \n"," 2016-07-10 10:50:18 | \n"," 2016-07-10 11:08:38 | \n"," 1 | \n"," 9.48 | \n"," 1 | \n"," N | \n"," 66 | \n"," 1 | \n"," 27.0 | \n"," 1970-01-01 00:00:00.000000001 | \n"," N | \n","
\n"," \n"," 1 | \n"," 2016-07-10 10:50:19 | \n"," 2016-07-10 10:55:14 | \n"," 1 | \n"," 1.00 | \n"," 1 | \n"," N | \n"," 264 | \n"," 2 | \n"," 5.5 | \n"," 1970-01-01 00:00:00.000000001 | \n"," N | \n","
\n"," \n"," 2016-07-10 10:50:19 | \n"," 2016-07-10 10:55:47 | \n"," 1 | \n"," 0.90 | \n"," 1 | \n"," N | \n"," 68 | \n"," 2 | \n"," 5.5 | \n"," 1970-01-01 00:00:00.000000001 | \n"," N | \n","
\n"," \n","
\n","
\n","
\n"," \n"," \n","\n"," \n","
\n","
\n"," "],"text/plain":[" VendorID ... demand\n","1 2016-07-10 06:56:05 2016-07-10 06:59:53 ... N\n","2 2016-07-10 10:50:18 2016-07-10 10:55:21 ... N\n"," 2016-07-10 10:50:18 2016-07-10 11:08:38 ... N\n","1 2016-07-10 10:50:19 2016-07-10 10:55:14 ... N\n"," 2016-07-10 10:50:19 2016-07-10 10:55:47 ... N\n","\n","[5 rows x 10 columns]"]},"metadata":{},"execution_count":13}]},{"cell_type":"code","metadata":{"id":"njbijJnem1Fw","executionInfo":{"status":"ok","timestamp":1641730437526,"user_tz":-60,"elapsed":882,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}}},"source":["X = df_july.drop(['VendorID','passenger_count','trip_distance','store_and_fwd_flag'], axis = 1)\n","\n","y = df_july['demand']"],"execution_count":14,"outputs":[]},{"cell_type":"code","metadata":{"id":"oeY2TOs7nBn0","executionInfo":{"status":"ok","timestamp":1641730441043,"user_tz":-60,"elapsed":404,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}}},"source":["cont_cols = ['PULocationID',\n"," 'DOLocationID']\n"," \n","conts_data = np.stack([df_july[col].values for col in cont_cols], 1)"],"execution_count":15,"outputs":[]},{"cell_type":"code","metadata":{"id":"Wv-1z8hRnDRB","executionInfo":{"status":"ok","timestamp":1641730443420,"user_tz":-60,"elapsed":2,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}}},"source":["X = torch.tensor(conts_data, dtype = 
torch.float)"],"execution_count":16,"outputs":[]},{"cell_type":"code","metadata":{"id":"JcD8AwCBnFLF","colab":{"base_uri":"https://localhost:8080/","height":170},"executionInfo":{"status":"error","timestamp":1641730481914,"user_tz":-60,"elapsed":572,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}},"outputId":"d3a8e3e9-c24a-4fda-ed9c-bf71762ed9c1"},"source":["y = torch.tensor(df_july['demand'].values, dtype=torch.float).reshape(-1,1)"],"execution_count":18,"outputs":[{"output_type":"error","ename":"TypeError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdf_july\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'demand'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.object_. 
The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool."]}]},
{"cell_type":"code","metadata":{"id":"tUdNi7JqnG1V"},"source":[
"class LSTM1(nn.Module):\n",
"    \"\"\"Stacked-LSTM regressor: LSTM -> two ReLU dense layers -> one output per sample.\n",
"\n",
"    :param input_size: number of features per time step\n",
"    :param hidden_size: width of the LSTM hidden state\n",
"    :param num_layers: number of stacked LSTM layers; forward() reads hn[1], so this must be >= 2\n",
"    :param seq_length: stored on the instance; not used by forward()\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_size, hidden_size, num_layers, seq_length=1):\n",
"        super(LSTM1, self).__init__()\n",
"        self.input_size = input_size  # input size\n",
"        self.hidden_size = hidden_size  # hidden state\n",
"        self.num_layers = num_layers  # number of layers\n",
"        self.seq_length = seq_length  # sequence length\n",
"\n",
"        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True,\n",
"                            dropout=0.1)\n",
"        self.fc_1 = nn.Linear(hidden_size, 16)  # fully connected 1\n",
"        self.fc_2 = nn.Linear(16, 8)  # fully connected 2\n",
"        self.fc = nn.Linear(8, 1)  # fully connected last layer\n",
"\n",
"        self.dropout = nn.Dropout(0.1)\n",
"        self.relu = nn.ReLU()\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"\n",
"        Predict one value per sample.\n",
"\n",
"        :param x: input features, (batch, features) or (batch, seq, features)\n",
"        :return: prediction results, shape (batch, 1)\n",
"        \"\"\"\n",
"        if x.dim() == 2:\n",
"            # Treat each row as its own length-1 sequence. unsqueeze(0) would fold the whole\n",
"            # batch into ONE sequence and return a single (1, 1) prediction that silently\n",
"            # broadcasts against (batch, 1) targets in the loss.\n",
"            x = x.unsqueeze(1)\n",
"        h_0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, dtype=x.dtype, device=x.device)  # hidden state\n",
"        c_0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, dtype=x.dtype, device=x.device)  # internal state\n",
"        output, (hn, cn) = self.lstm(x, (h_0, c_0))  # lstm with input, hidden, and internal state\n",
"\n",
"        # Index hn directly: converting it through hn.detach().numpy() would cut the autograd\n",
"        # graph, so the LSTM weights would never receive gradients. For num_layers == 2,\n",
"        # hn[1] and hn[-1] are the same (last) layer.\n",
"        hn_o = hn[-1].view(-1, self.hidden_size)\n",
"        hn_1 = hn[1].view(-1, self.hidden_size)\n",
"\n",
"        out = self.relu(self.fc_1(self.relu(hn_o + hn_1)))\n",
"        out = self.relu(self.fc_2(out))\n",
"        out = self.dropout(out)\n",
"        out = self.fc(out)\n",
"        return out\n"
],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"id":"W_jQ8Stjnw_3"},"source":[
"class GRUModel(nn.Module):\n",
"    \"\"\"Stacked-GRU regressor: the last time step's output feeds one linear layer.\"\"\"\n",
"\n",
"    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, dropout_prob):\n",
"        super(GRUModel, self).__init__()\n",
"\n",
"        # Defining the number of layers and the nodes in each layer\n",
"        self.layer_dim = layer_dim\n",
"        self.hidden_dim = hidden_dim\n",
"\n",
"        # GRU layers\n",
"        self.gru = nn.GRU(\n",
"            input_dim, hidden_dim, layer_dim, batch_first=True, dropout=dropout_prob\n",
"        )\n",
"\n",
"        # Fully connected layer\n",
"        self.fc = nn.Linear(hidden_dim, output_dim)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"\n",
"        Predict output_dim values per sample.\n",
"\n",
"        :param x: input features, (batch, features) or (batch, seq, features)\n",
"        :return: prediction results, shape (batch, output_dim)\n",
"        \"\"\"\n",
"        if x.dim() == 2:\n",
"            # nn.GRU needs 3-D input; give each row its own length-1 sequence instead of\n",
"            # folding the whole batch into one sequence (which yields a single prediction).\n",
"            x = x.unsqueeze(1)\n",
"        # Initializing hidden state for first input with zeros (it needs no gradient)\n",
"        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, dtype=x.dtype, device=x.device)\n",
"\n",
"        # Forward propagation by passing in the input and hidden state into the model\n",
"        out, _ = self.gru(x, h0)\n",
"\n",
"        # Keep only the last time step: (batch, seq, hidden) -> (batch, hidden)\n",
"        out = out[:, -1, :]\n",
"\n",
"        # Convert the final state to our desired output shape (batch_size, output_dim)\n",
"        out = self.fc(out)\n",
"\n",
"        return out"
],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"N6lXYb_DobtI","executionInfo":{"status":"ok","timestamp":1631964358537,"user_tz":-120,"elapsed":204,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}},"outputId":"3e2c059c-ca42-4a7c-86fd-97af312485ae"},"source":["X.shape[1]"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["4"]},"metadata":{},"execution_count":27}]},{"cell_type":"code","metadata":{"id":"bXcFCk_snJhb"},"source":["model = LSTM1(X.shape[1], 16, 2)\n","model_2 = GRUModel(X.shape[1], 16, 2, 1, 0.1)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Ille-yPTnL4B"},"source":["criterion = nn.SmoothL1Loss()\n","optimizer = 
torch.optim.Adam(model.parameters(), lr = 0.01)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"6ijsOfMinOfG"},"source":["batch_size = 10000\n","test_size = int(batch_size * .2)\n","\n","X_train = X[:batch_size-test_size]\n","X_test = X[batch_size-test_size:batch_size]\n","y_train = y[:batch_size-test_size]\n","y_test = y[batch_size-test_size:batch_size]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tg7wxTvwnQQQ","executionInfo":{"status":"ok","timestamp":1631964369747,"user_tz":-120,"elapsed":207,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}},"outputId":"7303b89f-d7b1-4974-cb16-3bf567c85b51"},"source":["print(X_train.shape)\n","print(y_train.shape)\n","print(X_test.shape)\n","print(y_test.shape)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["torch.Size([8000, 4])\n","torch.Size([8000, 1])\n","torch.Size([2000, 4])\n","torch.Size([2000, 1])\n"]}]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"pqv5BBqgnSVf","executionInfo":{"status":"ok","timestamp":1631964381364,"user_tz":-120,"elapsed":9294,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}},"outputId":"5f5e467b-614c-4bad-cd26-9ba05c270fdb"},"source":[
"import time\n",
"start_time = time.time()\n",
"\n",
"epochs = 10\n",
"print_every = 25  # log epochs 1, 26, 51, ... (the final epoch is always printed below)\n",
"losses = []  # one plain float per epoch: plottable, and keeps no autograd graph alive\n",
"\n",
"model.train()  # make sure dropout layers are active while training\n",
"for i in range(1, epochs + 1):\n",
"    y_pred = model(X_train)\n",
"    loss = criterion(y_pred, y_train)  # SmoothL1 (Huber) loss, not RMSE\n",
"    losses.append(loss.item())\n",
"\n",
"    if i % print_every == 1:\n",
"        print(f'epoch: {i:3} loss: {loss.item():10.8f}')\n",
"\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"\n",
"print(f'epoch: {i:3} loss: {loss.item():10.8f}')\n",
"print(f'\\nDuration: {time.time() - start_time:.0f} seconds')"
],"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py:921: UserWarning: Using a target size (torch.Size([8000, 1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n"," return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)\n"]},{"output_type":"stream","name":"stdout","text":["epoch: 1 loss: 4.69094992\n","epoch: 10 loss: 4.21567631\n","\n","Duration: 9 seconds\n"]}]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":279},"id":"fTeXMiWynVE_","executionInfo":{"status":"ok","timestamp":1631964386029,"user_tz":-120,"elapsed":680,"user":{"displayName":"Faheem Ahmed Abbasi","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GhgtOF28sy7gFBpnhma6GfiWjvqUh1r3ss-fhJtog=s64","userId":"04247060777971177506"}},"outputId":"456d74bf-dd1e-46c1-ca62-f8bf7e088c46"},"source":["plt.plot(range(epochs), losses)\n","plt.ylabel('SmoothL1 Loss')\n","plt.xlabel('epoch');"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"\n","text/plain":["