Spaces:
Running
Running
ThanaritKanjanametawat
commited on
Commit
·
582b2f2
1
Parent(s):
2287a5c
change the device to cpu only 3
Browse files- ModelDriver.py +3 -2
ModelDriver.py
CHANGED
@@ -2,6 +2,7 @@ from transformers import RobertaTokenizer, RobertaForSequenceClassification, Rob
|
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
|
|
|
5 |
device = torch.device("cpu")
|
6 |
class MLP(nn.Module):
|
7 |
def __init__(self, input_dim):
|
@@ -27,7 +28,7 @@ def extract_features(text):
|
|
27 |
def RobertaSentinelOpenGPTInference(input_text):
|
28 |
features = extract_features(input_text)
|
29 |
loaded_model = MLP(768).to(device)
|
30 |
-
loaded_model.load_state_dict(torch.load("MLPDictStates/RobertaSentinelOpenGPT.pth"))
|
31 |
|
32 |
# Define the tokenizer and model for feature extraction
|
33 |
with torch.no_grad():
|
@@ -40,7 +41,7 @@ def RobertaSentinelOpenGPTInference(input_text):
|
|
40 |
def RobertaSentinelCSAbstractInference(input_text):
|
41 |
features = extract_features(input_text)
|
42 |
loaded_model = MLP(768).to(device)
|
43 |
-
loaded_model.load_state_dict(torch.load("MLPDictStates/RobertaSentinelCSAbstract.pth"))
|
44 |
|
45 |
# Define the tokenizer and model for feature extraction
|
46 |
with torch.no_grad():
|
|
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
|
5 |
+
|
6 |
device = torch.device("cpu")
|
7 |
class MLP(nn.Module):
|
8 |
def __init__(self, input_dim):
|
|
|
28 |
def RobertaSentinelOpenGPTInference(input_text):
|
29 |
features = extract_features(input_text)
|
30 |
loaded_model = MLP(768).to(device)
|
31 |
+
loaded_model.load_state_dict(torch.load("MLPDictStates/RobertaSentinelOpenGPT.pth", map_location=device))
|
32 |
|
33 |
# Define the tokenizer and model for feature extraction
|
34 |
with torch.no_grad():
|
|
|
41 |
def RobertaSentinelCSAbstractInference(input_text):
|
42 |
features = extract_features(input_text)
|
43 |
loaded_model = MLP(768).to(device)
|
44 |
+
loaded_model.load_state_dict(torch.load("MLPDictStates/RobertaSentinelCSAbstract.pth", map_location=device))
|
45 |
|
46 |
# Define the tokenizer and model for feature extraction
|
47 |
with torch.no_grad():
|