ArxAlfa commited on
Commit
8f7a0cd
1 Parent(s): a650593

Add DNN model and training code

Browse files
Files changed (3) hide show
  1. app.py +45 -8
  2. docker-compose.yml +3 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,8 @@
1
- from fastapi import FastAPI, UploadFile, File
 
 
2
  import numpy as np
3
- from sklearn.neural_network import MLPRegressor
4
  from sklearn.model_selection import KFold
5
  from sklearn.metrics import mean_squared_error
6
  import csv
@@ -8,8 +10,24 @@ import io
8
 
9
  from joblib import load, dump
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # Load the model
12
- model = load("model.joblib")
13
 
14
  # Create a new FastAPI app instance
15
  app = FastAPI(docs_url="/", redoc_url="/new_redoc")
@@ -24,8 +42,11 @@ def generate(
24
  yearBuilt: float,
25
  ):
26
  global model
27
- prediction = model.predict([[squareFeet, bedrooms, bathrooms, yearBuilt]])
28
- return {"output": prediction[0]}
 
 
 
29
 
30
 
31
  @app.post("/train")
@@ -48,6 +69,14 @@ async def train(file: UploadFile = File(...)):
48
  y = data_np[:, -1]
49
  y = np.ravel(y)
50
 
 
 
 
 
 
 
 
 
51
  # Fit the model
52
  kf = KFold(n_splits=4)
53
  accuracies = []
@@ -56,10 +85,18 @@ async def train(file: UploadFile = File(...)):
56
  X_train, X_test = X[train_index], X[test_index]
57
  y_train, y_test = y[train_index], y[test_index]
58
 
59
- model.fit(X_train, y_train)
 
 
 
 
 
 
 
 
60
 
61
- predictions = model.predict(X_test)
62
- rmse = np.sqrt(mean_squared_error(y_test, predictions))
63
  accuracies.append(rmse)
64
 
65
  average_rmse = sum(accuracies) / len(accuracies)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
  import numpy as np
5
+ from fastapi import FastAPI, UploadFile, File
6
  from sklearn.model_selection import KFold
7
  from sklearn.metrics import mean_squared_error
8
  import csv
 
10
 
11
  from joblib import load, dump
12
 
13
+
14
+ # Define the DNN model
15
+ class DNN(nn.Module):
16
+ def __init__(self, input_size, hidden_size, output_size):
17
+ super(DNN, self).__init__()
18
+ self.fc1 = nn.Linear(input_size, hidden_size)
19
+ self.relu = nn.ReLU()
20
+ self.fc2 = nn.Linear(hidden_size, output_size)
21
+
22
+ def forward(self, x):
23
+ x = self.fc1(x)
24
+ x = self.relu(x)
25
+ x = self.fc2(x)
26
+ return x
27
+
28
+
29
  # Load the model
30
+ model = DNN(input_size=4, hidden_size=16, output_size=1)
31
 
32
  # Create a new FastAPI app instance
33
  app = FastAPI(docs_url="/", redoc_url="/new_redoc")
 
42
  yearBuilt: float,
43
  ):
44
  global model
45
+ input_data = torch.tensor(
46
+ [[squareFeet, bedrooms, bathrooms, yearBuilt]], dtype=torch.float32
47
+ )
48
+ prediction = model(input_data)
49
+ return {"output": prediction.item()}
50
 
51
 
52
  @app.post("/train")
 
69
  y = data_np[:, -1]
70
  y = np.ravel(y)
71
 
72
+ # Convert data to torch tensors
73
+ X = torch.tensor(X, dtype=torch.float32)
74
+ y = torch.tensor(y, dtype=torch.float32)
75
+
76
+ # Define loss function and optimizer
77
+ criterion = nn.MSELoss()
78
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
79
+
80
  # Fit the model
81
  kf = KFold(n_splits=4)
82
  accuracies = []
 
85
  X_train, X_test = X[train_index], X[test_index]
86
  y_train, y_test = y[train_index], y[test_index]
87
 
88
+ optimizer.zero_grad()
89
+
90
+ # Forward pass
91
+ outputs = model(X_train)
92
+ loss = criterion(outputs, y_train.unsqueeze(1))
93
+
94
+ # Backward pass and optimization
95
+ loss.backward()
96
+ optimizer.step()
97
 
98
+ predictions = model(X_test)
99
+ rmse = np.sqrt(mean_squared_error(y_test, predictions.detach().numpy()))
100
  accuracies.append(rmse)
101
 
102
  average_rmse = sum(accuracies) / len(accuracies)
docker-compose.yml CHANGED
@@ -6,6 +6,9 @@ services:
6
  - "7860:7860"
7
  deploy:
8
  resources:
 
 
 
9
  reservations:
10
  devices:
11
  - driver: nvidia
 
6
  - "7860:7860"
7
  deploy:
8
  resources:
9
+ limits:
10
+ memory: 512M
11
+ shm_size: 2G
12
  reservations:
13
  devices:
14
  - driver: nvidia
requirements.txt CHANGED
@@ -28,3 +28,4 @@ uvicorn==0.17.6
28
  uvloop==0.19.0
29
  watchgod==0.8.2
30
  websockets==12.0
 
 
28
  uvloop==0.19.0
29
  watchgod==0.8.2
30
  websockets==12.0
31
+ torch