Update README.md
Browse files
README.md
CHANGED
@@ -12,7 +12,7 @@ Model was trained with a `max_length` of `4096`, but the base model supports `81
|
|
12 |
### Example Code
|
13 |
```py
|
14 |
import torch
|
15 |
-
from transformers import AutoTokenizer, LlamaForSequenceClassification
|
16 |
import json
|
17 |
from tqdm import tqdm
|
18 |
|
@@ -45,34 +45,19 @@ def load_json_or_jsonl(file_path):
|
|
45 |
return None
|
46 |
|
47 |
|
48 |
-
|
49 |
-
"
|
50 |
-
|
51 |
-
model = LlamaForSequenceClassification.from_pretrained(
|
52 |
-
"PJMixers/Danube3-ClassTest-v0.1-500M",
|
53 |
device_map="cuda",
|
54 |
-
torch_dtype=torch.bfloat16,
|
55 |
-
attn_implementation="sdpa",
|
56 |
)
|
57 |
-
|
58 |
data = load_json_or_jsonl(
|
59 |
"./PrefMix-Classifier-Data-validation.json"
|
60 |
)
|
61 |
|
62 |
passes, fails = 0, 0
|
63 |
for sample in tqdm(data):
|
64 |
-
input_text = sample["input_text"]
|
65 |
-
true_label = sample["labels"]
|
66 |
-
|
67 |
-
inputs = tokenizer(
|
68 |
-
input_text,
|
69 |
-
return_tensors="pt"
|
70 |
-
).to("cuda")
|
71 |
-
|
72 |
-
with torch.no_grad():
|
73 |
-
generated_label = model(**inputs).logits.argmax().item()
|
74 |
-
|
75 |
-
if generated_label == true_label:
|
76 |
passes += 1
|
77 |
else:
|
78 |
fails += 1
|
|
|
12 |
### Example Code
|
13 |
```py
|
14 |
import torch
|
15 |
+
from transformers import pipeline
|
16 |
import json
|
17 |
from tqdm import tqdm
|
18 |
|
|
|
45 |
return None
|
46 |
|
47 |
|
48 |
+
pipe = pipeline(
|
49 |
+
task="text-classification",
|
50 |
+
model="PJMixers/Danube3-ClassTest-v0.1-500M",
|
|
|
|
|
51 |
device_map="cuda",
|
52 |
+
torch_dtype=torch.bfloat16
|
|
|
53 |
)
|
|
|
54 |
data = load_json_or_jsonl(
|
55 |
"./PrefMix-Classifier-Data-validation.json"
|
56 |
)
|
57 |
|
58 |
passes, fails = 0, 0
|
59 |
for sample in tqdm(data):
|
60 |
+
if int(pipe(sample["input_text"])[0]["label"]) == sample["labels"]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
passes += 1
|
62 |
else:
|
63 |
fails += 1
|