saikiranmansa commited on
Commit
8e496ea
·
verified ·
1 Parent(s): 8ef6ee4

Upload 4 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ weatherAUS.csv filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from sklearn.preprocessing import LabelEncoder, StandardScaler
7
+
8
+ st.markdown("<h1 style='text-align: center; font-size: 48px; color: #6699CC;'>Rain Tomorrow Prediction</h1>", unsafe_allow_html=True)
9
+
10
+ # Function to create cyclical features
11
+ def create_date_features(df, date_column='Date'):
12
+ df = df.copy()
13
+ df[date_column] = pd.to_datetime(df[date_column])
14
+
15
+ # Extract basic components
16
+ df['year'] = df[date_column].dt.year
17
+ month = df[date_column].dt.month
18
+ day = df[date_column].dt.day
19
+
20
+ # Create cyclical features
21
+ df['month_sin'] = np.sin(2 * np.pi * month/12)
22
+ df['month_cos'] = np.cos(2 * np.pi * month/12)
23
+ df['day_sin'] = np.sin(2 * np.pi * day/31)
24
+ df['day_cos'] = np.cos(2 * np.pi * day/31)
25
+
26
+ return df
27
+
28
+ # Load the dataset
29
+ @st.cache_data
30
+ def load_dataset():
31
+ df = pd.read_csv('weatherAUS.csv')
32
+ return create_date_features(df)
33
+
34
+ # Cache function to convert DataFrame to CSV
35
+ @st.cache_data
36
+ def convert_df(df):
37
+ return df.to_csv(index=False).encode("utf-8")
38
+
39
+ # Define the neural network model
40
+ class Enhanced_ANN_Model(nn.Module):
41
+ def __init__(self, input_dim):
42
+ super(Enhanced_ANN_Model, self).__init__()
43
+ self.fc1 = nn.Linear(input_dim, 128)
44
+ self.bn1 = nn.BatchNorm1d(128)
45
+ self.fc2 = nn.Linear(128, 64)
46
+ self.bn2 = nn.BatchNorm1d(64)
47
+ self.fc3 = nn.Linear(64, 32)
48
+ self.bn3 = nn.BatchNorm1d(32)
49
+ self.fc4 = nn.Linear(32, 1)
50
+
51
+ def forward(self, x):
52
+ x = self.fc1(x)
53
+ x = self.bn1(x)
54
+ x = torch.relu(x)
55
+ x = self.fc2(x)
56
+ x = self.bn2(x)
57
+ x = torch.relu(x)
58
+ x = self.fc3(x)
59
+ x = self.bn3(x)
60
+ x = torch.relu(x)
61
+ x = self.fc4(x)
62
+ return x
63
+
64
+ # Load pre-trained model
65
+ @st.cache_resource
66
+ def load_model():
67
+ input_dim = 26 # Changed to 26 features to match the trained model
68
+ model = Enhanced_ANN_Model(input_dim)
69
+
70
+ try:
71
+ state_dict = torch.load("model_weights.pth", map_location=torch.device('cpu'))
72
+ if isinstance(state_dict, dict):
73
+ model.load_state_dict(state_dict)
74
+ else:
75
+ model = state_dict
76
+ model.eval()
77
+ return model
78
+ except Exception as e:
79
+ st.markdown(f"<p style='color: #0000FF;'>Error loading model: {str(e)}</p>", unsafe_allow_html=True)
80
+ return None
81
+
82
+ # Load dataset
83
+ try:
84
+ df = load_dataset()
85
+
86
+ # Display dataset preview
87
+ st.markdown("<h3 style='color: #6699CC;'>Dataset Preview:</h3>", unsafe_allow_html=True)
88
+ st.dataframe(df.head())
89
+
90
+ # Base required columns
91
+ base_columns = ['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine',
92
+ 'WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm',
93
+ 'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
94
+ 'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm',
95
+ 'Temp9am', 'Temp3pm', 'RainToday']
96
+
97
+ # Add date-derived features
98
+ required_columns = base_columns + ['month_sin', 'month_cos', 'day_sin', 'day_cos', 'year']
99
+
100
+ if not all(col in df.columns for col in required_columns):
101
+ missing_columns = ', '.join(set(required_columns) - set(df.columns))
102
+ st.markdown(f"<p style='color: #6699CC;'>Missing required columns: {missing_columns}</p>", unsafe_allow_html=True)
103
+ else:
104
+ # Label Encoding for categorical columns
105
+ label_encoders = {}
106
+ categorical_cols = ['Location', 'WindGustDir', 'WindDir9am', 'WindDir3pm', 'RainToday']
107
+ for col in categorical_cols:
108
+ le = LabelEncoder()
109
+ df[col] = df[col].fillna('missing')
110
+ df[col] = le.fit_transform(df[col].astype(str))
111
+ label_encoders[col] = le
112
+
113
+ # Standard Scaling for numerical features
114
+ scaler = StandardScaler()
115
+ numerical_cols = [col for col in required_columns if col not in categorical_cols]
116
+ df[numerical_cols] = df[numerical_cols].fillna(df[numerical_cols].mean())
117
+ df[numerical_cols] = scaler.fit_transform(df[numerical_cols])
118
+
119
+ # Select a row for prediction
120
+ st.markdown("<h3 style='color: #6699CC;'>Select a Row for Prediction:</h3>", unsafe_allow_html=True)
121
+ st.markdown("""
122
+ <style>
123
+ .stSelectbox label {
124
+ color: #ff6347; /* Set your desired color here */
125
+ }
126
+ </style>
127
+ """, unsafe_allow_html=True)
128
+
129
+ # Selectbox widget
130
+ selected_row_index = st.selectbox("Select a Row Index", options=range(len(df)), index=0)
131
+ predict_button = st.button("Predict Weather")
132
+
133
+ if predict_button:
134
+ model = load_model()
135
+ if model is not None:
136
+ # Get all required columns for prediction
137
+ row_to_use = df.iloc[selected_row_index][required_columns]
138
+
139
+ # Ensure all values are float32
140
+ row_tensor = torch.tensor(row_to_use.values.astype(np.float32)).unsqueeze(0)
141
+
142
+ # Make prediction
143
+ with torch.no_grad():
144
+ prediction = model(row_tensor).item()
145
+
146
+ # Apply sigmoid to get probability
147
+ prediction = torch.sigmoid(torch.tensor(prediction)).item()
148
+
149
+ # Display results
150
+ st.markdown("<h3 style='color: #32a852;'>Row selected for prediction:</h3>", unsafe_allow_html=True)
151
+ st.write(row_to_use)
152
+
153
+ result = "Rain Expected" if prediction >= 0.5 else "No Rain Expected"
154
+ probability = prediction * 100
155
+
156
+ st.markdown(f"<h3 style='color: #32a852;'>Rain Prediction Result: {result}</h3>", unsafe_allow_html=True)
157
+ st.markdown(f"<h3 style='color: #32a852;'>Probability of Rain: {probability:.2f}%</h3>", unsafe_allow_html=True)
158
+
159
+ # Show original date for reference
160
+ original_date = df.iloc[selected_row_index]['Date']
161
+ st.markdown(f"<h3 style='color: #32a852;'>Date: {original_date}</h3>", unsafe_allow_html=True)
162
+
163
+ # Provide download option
164
+ result_df = row_to_use.to_frame().T
165
+ result_df['Rain Prediction'] = result
166
+ result_df['Rain Probability'] = f"{probability:.2f}%"
167
+ result_df['Date'] = original_date
168
+ result_csv = convert_df(result_df)
169
+ st.download_button(
170
+ label="Download Prediction Result",
171
+ data=result_csv,
172
+ file_name="Rain_Prediction_Result.csv",
173
+ mime="text/csv",
174
+ )
175
+
176
+ except Exception as e:
177
+ st.markdown(f"<p style='color: #32a852;'>An error occurred: {str(e)}</p>", unsafe_allow_html=True)
model_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b04f6443504323945fb3790a4185750a93dcc900a2f5bfe0ec9359fe0826b0ae
3
+ size 66099
rain_prediction_notebook.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
weatherAUS.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:573fd715cd69fcacc4df32024d823b450ae3edaae7e8ff2eeb623adbed424014
3
+ size 14094055