Update app.py
Browse files
app.py
CHANGED
@@ -20,8 +20,6 @@ n_head = 6
|
|
20 |
n_layer = 6
|
21 |
dropout = 0.2
|
22 |
|
23 |
-
# Define paths
|
24 |
-
weights_path = os.path.join('era_v2_assignment_19_model.pt')
|
25 |
|
26 |
torch.manual_seed(1337)
|
27 |
|
@@ -207,7 +205,7 @@ class GPTLanguageModel(nn.Module):
|
|
207 |
model = GPTLanguageModel()
|
208 |
|
209 |
# Load the saved state dictionary into the model
|
210 |
-
model.load_state_dict(torch.load('era_v2_assignment_19_model.pth'))
|
211 |
|
212 |
# Set the model to evaluation mode
|
213 |
model.eval()
|
|
|
20 |
n_layer = 6
|
21 |
dropout = 0.2
|
22 |
|
|
|
|
|
23 |
|
24 |
torch.manual_seed(1337)
|
25 |
|
|
|
205 |
model = GPTLanguageModel()
|
206 |
|
207 |
# Load the saved state dictionary into the model
|
208 |
+
model.load_state_dict(torch.load('era_v2_assignment_19_model.pth', map_location=torch.device('cpu')))
|
209 |
|
210 |
# Set the model to evaluation mode
|
211 |
model.eval()
|