xzuyn committed · verified
Commit 66e7118 · Parent(s): caa9e11

Update README.md

Files changed (1)
  1. README.md (+6 -21)
README.md CHANGED

````diff
@@ -12,7 +12,7 @@ Model was trained with a `max_length` of `4096`, but the base model supports `81
 ### Example Code
 ```py
 import torch
-from transformers import AutoTokenizer, LlamaForSequenceClassification
+from transformers import pipeline
 import json
 from tqdm import tqdm
 
@@ -45,34 +45,19 @@ def load_json_or_jsonl(file_path):
     return None
 
 
-tokenizer = AutoTokenizer.from_pretrained(
-    "PJMixers/Danube3-ClassTest-v0.1-500M"
-)
-model = LlamaForSequenceClassification.from_pretrained(
-    "PJMixers/Danube3-ClassTest-v0.1-500M",
+pipe = pipeline(
+    task="text-classification",
+    model="PJMixers/Danube3-ClassTest-v0.1-500M",
     device_map="cuda",
-    torch_dtype=torch.bfloat16,
-    attn_implementation="sdpa",
+    torch_dtype=torch.bfloat16
 )
-
 data = load_json_or_jsonl(
     "./PrefMix-Classifier-Data-validation.json"
 )
 
 passes, fails = 0, 0
 for sample in tqdm(data):
-    input_text = sample["input_text"]
-    true_label = sample["labels"]
-
-    inputs = tokenizer(
-        input_text,
-        return_tensors="pt"
-    ).to("cuda")
-
-    with torch.no_grad():
-        generated_label = model(**inputs).logits.argmax().item()
-
-    if generated_label == true_label:
+    if int(pipe(sample["input_text"])[0]["label"]) == sample["labels"]:
         passes += 1
     else:
         fails += 1
````
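
For reference, a sketch of the full example as it reads after this commit, assembled from the hunks above. The body of `load_json_or_jsonl` is not part of the diff, so the loader shown here is an assumed reconstruction that handles both `.json` and `.jsonl` files, and the final accuracy print is an added convenience; the `int(...)` cast assumes the model's label names are numeric strings (e.g. `"0"`, `"1"`), as the committed code implies.

```py
import torch
from transformers import pipeline
import json
from tqdm import tqdm


def load_json_or_jsonl(file_path):
    # Assumed reconstruction: the diff only shows this function's final
    # `return None`, so the loading logic below is a plausible sketch.
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            if file_path.endswith(".jsonl"):
                return [json.loads(line) for line in f if line.strip()]
            return json.load(f)
    except FileNotFoundError:
        return None


pipe = pipeline(
    task="text-classification",
    model="PJMixers/Danube3-ClassTest-v0.1-500M",
    device_map="cuda",
    torch_dtype=torch.bfloat16
)

data = load_json_or_jsonl(
    "./PrefMix-Classifier-Data-validation.json"
)

passes, fails = 0, 0
for sample in tqdm(data):
    # pipe(...) returns [{"label": ..., "score": ...}]; int() assumes the
    # model's id2label values are numeric strings.
    if int(pipe(sample["input_text"])[0]["label"]) == sample["labels"]:
        passes += 1
    else:
        fails += 1

# Added for convenience (not part of the committed example).
print(f"Accuracy: {passes / (passes + fails):.4f}")
```

The `pipeline` call folds tokenization, device placement, `torch.no_grad()`, and the argmax over the logits into a single object, which is why the new loop body collapses to one line.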