Spaces:
Runtime error
Runtime error
Commit
·
6ef3524
1
Parent(s):
adf171b
Update run.py
Browse files
run.py
CHANGED
@@ -33,8 +33,9 @@ def tp_tf_test(model_selector, queries_selector, prompt_selector, metric_selecto
|
|
33 |
return tokenize
|
34 |
|
35 |
model = AutoModelForSequenceClassification.from_pretrained(model_selector)
|
36 |
-
|
37 |
-
|
|
|
38 |
tokenizer = AutoTokenizer.from_pretrained(model_selector)
|
39 |
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
40 |
|
@@ -56,7 +57,10 @@ def tp_tf_test(model_selector, queries_selector, prompt_selector, metric_selecto
|
|
56 |
labels = []
|
57 |
nli_labels =[]
|
58 |
for batch in dataloader:
|
59 |
-
|
|
|
|
|
|
|
60 |
with torch.no_grad():
|
61 |
outputs = model(**data)
|
62 |
logits = outputs.logits
|
@@ -100,18 +104,4 @@ def tp_tf_test(model_selector, queries_selector, prompt_selector, metric_selecto
|
|
100 |
for result in results:
|
101 |
results[result].to_csv(results_file, mode='a', index_label = result)
|
102 |
print(results[result], '\n')
|
103 |
-
return results
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
33 |
return tokenize
|
34 |
|
35 |
model = AutoModelForSequenceClassification.from_pretrained(model_selector)
|
36 |
+
if torch.cuda.is_available():
|
37 |
+
device = torch.device("cuda")
|
38 |
+
model.to(device)
|
39 |
tokenizer = AutoTokenizer.from_pretrained(model_selector)
|
40 |
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
41 |
|
|
|
57 |
labels = []
|
58 |
nli_labels =[]
|
59 |
for batch in dataloader:
|
60 |
+
if torch.cuda.is_available():
|
61 |
+
data = {k: v.to(device) for k, v in batch.items() if k not in ['labels', 'nli_label']}
|
62 |
+
else:
|
63 |
+
data = {k: v for k, v in batch.items() if k not in ['labels', 'nli_label']}
|
64 |
with torch.no_grad():
|
65 |
outputs = model(**data)
|
66 |
logits = outputs.logits
|
|
|
104 |
for result in results:
|
105 |
results[result].to_csv(results_file, mode='a', index_label = result)
|
106 |
print(results[result], '\n')
|
107 |
+
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|