ThanaritKanjanametawat commited on
Commit
1bb143a
·
1 Parent(s): 1fb6014

Move everything to CPU

Browse files
Files changed (2) hide show
  1. ModelDriver.py +10 -9
  2. Test.py +1 -1
ModelDriver.py CHANGED
@@ -5,7 +5,8 @@ import torch.nn.functional as F
5
  from torch.utils.data import TensorDataset, DataLoader
6
 
7
 
8
- device = torch.device("cpu")
 
9
  class MLP(nn.Module):
10
  def __init__(self, input_dim):
11
  super(MLP, self).__init__()
@@ -62,14 +63,14 @@ def RobertaClassifierOpenGPTInference(input_text):
62
  tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
63
  model_path = "ClassifierCheckpoint/RobertaClassifierOpenGPT.pth"
64
  model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2)
65
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
66
- model = model.to(torch.device('cpu'))
67
  model.eval()
68
 
69
 
70
  tokenized_input = tokenizer(input_text, truncation=True, padding=True, max_length=512, return_tensors='pt')
71
- input_ids = tokenized_input['input_ids'].to(torch.device('cpu'))
72
- attention_mask = tokenized_input['attention_mask'].to(torch.device('cpu'))
73
 
74
  # Make a prediction
75
  with torch.no_grad():
@@ -84,14 +85,14 @@ def RobertaClassifierCSAbstractInference(input_text):
84
  tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
85
  model_path = "ClassifierCheckpoint/RobertaClassifierCSAbstract.pth"
86
  model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2)
87
- model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
88
- model = model.to(torch.device('cpu'))
89
  model.eval()
90
 
91
 
92
  tokenized_input = tokenizer(input_text, truncation=True, padding=True, max_length=512, return_tensors='pt')
93
- input_ids = tokenized_input['input_ids'].to(torch.device('cpu'))
94
- attention_mask = tokenized_input['attention_mask'].to(torch.device('cpu'))
95
 
96
  # Make a prediction
97
  with torch.no_grad():
 
5
  from torch.utils.data import TensorDataset, DataLoader
6
 
7
 
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ # device = torch.device("cpu")
10
  class MLP(nn.Module):
11
  def __init__(self, input_dim):
12
  super(MLP, self).__init__()
 
63
  tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
64
  model_path = "ClassifierCheckpoint/RobertaClassifierOpenGPT.pth"
65
  model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2)
66
+ model.load_state_dict(torch.load(model_path, map_location=device))
67
+ model = model.to(device)
68
  model.eval()
69
 
70
 
71
  tokenized_input = tokenizer(input_text, truncation=True, padding=True, max_length=512, return_tensors='pt')
72
+ input_ids = tokenized_input['input_ids'].to(device)
73
+ attention_mask = tokenized_input['attention_mask'].to(device)
74
 
75
  # Make a prediction
76
  with torch.no_grad():
 
85
  tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
86
  model_path = "ClassifierCheckpoint/RobertaClassifierCSAbstract.pth"
87
  model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=2)
88
+ model.load_state_dict(torch.load(model_path, map_location=device))
89
+ model = model.to(device)
90
  model.eval()
91
 
92
 
93
  tokenized_input = tokenizer(input_text, truncation=True, padding=True, max_length=512, return_tensors='pt')
94
+ input_ids = tokenized_input['input_ids'].to(device)
95
+ attention_mask = tokenized_input['attention_mask'].to(device)
96
 
97
  # Make a prediction
98
  with torch.no_grad():
Test.py CHANGED
@@ -20,7 +20,7 @@ Input_Text = "I want to do this data"
20
  # print(f"Confidence:", max(Probs))
21
 
22
  print("RobertaClassifierCSAbstractInference")
23
- Probs = RobertaClassifierOpenGPTInference(Input_Text)
24
  Pred = "Human Written" if not np.argmax(Probs) else "Machine Generated"
25
 
26
  print(Probs)
 
20
  # print(f"Confidence:", max(Probs))
21
 
22
  print("RobertaClassifierCSAbstractInference")
23
+ Probs = RobertaClassifierCSAbstractInference(Input_Text)
24
  Pred = "Human Written" if not np.argmax(Probs) else "Machine Generated"
25
 
26
  print(Probs)