gorkaartola commited on
Commit
6ef3524
1 Parent(s): adf171b

Update run.py

Browse files
Files changed (1) hide show
  1. run.py +8 -18
run.py CHANGED
@@ -33,8 +33,9 @@ def tp_tf_test(model_selector, queries_selector, prompt_selector, metric_selecto
33
  return tokenize
34
 
35
  model = AutoModelForSequenceClassification.from_pretrained(model_selector)
36
- device = torch.device("cuda")
37
- model.to(device)
 
38
  tokenizer = AutoTokenizer.from_pretrained(model_selector)
39
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
40
 
@@ -56,7 +57,10 @@ def tp_tf_test(model_selector, queries_selector, prompt_selector, metric_selecto
56
  labels = []
57
  nli_labels =[]
58
  for batch in dataloader:
59
- data = {k: v.to(device) for k, v in batch.items() if k not in ['labels', 'nli_label']}
 
 
 
60
  with torch.no_grad():
61
  outputs = model(**data)
62
  logits = outputs.logits
@@ -100,18 +104,4 @@ def tp_tf_test(model_selector, queries_selector, prompt_selector, metric_selecto
100
  for result in results:
101
  results[result].to_csv(results_file, mode='a', index_label = result)
102
  print(results[result], '\n')
103
- return results
104
-
105
-
106
-
107
-
108
-
109
-
110
-
111
-
112
-
113
-
114
-
115
-
116
-
117
-
 
33
  return tokenize
34
 
35
  model = AutoModelForSequenceClassification.from_pretrained(model_selector)
36
+ if torch.cuda.is_available():
37
+ device = torch.device("cuda")
38
+ model.to(device)
39
  tokenizer = AutoTokenizer.from_pretrained(model_selector)
40
  data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
41
 
 
57
  labels = []
58
  nli_labels =[]
59
  for batch in dataloader:
60
+ if torch.cuda.is_available():
61
+ data = {k: v.to(device) for k, v in batch.items() if k not in ['labels', 'nli_label']}
62
+ else:
63
+ data = {k: v for k, v in batch.items() if k not in ['labels', 'nli_label']}
64
  with torch.no_grad():
65
  outputs = model(**data)
66
  logits = outputs.logits
 
104
  for result in results:
105
  results[result].to_csv(results_file, mode='a', index_label = result)
106
  print(results[result], '\n')
107
+ return results