akkshay commited on
Commit
f54d8e4
·
1 Parent(s): c25a1b8

Wrote ReadMe file to track instructions to run inference

Browse files
Files changed (1) hide show
  1. README.md +75 -3
README.md CHANGED
@@ -1,3 +1,75 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hyde LLaMa 2 7B Legal
2
+
3
+ ## Model Details
4
+
5
+ **Backbone Model:** meta-llama/Llama-2-7b-chat
6
+
7
+ **Input:** The model takes text input only.
8
+
9
+ **Output:** The model generates text output only.
10
+
11
+ ### Inference
12
+
13
+ ```python
14
def hyde_gen(
    topic: str,
    model: object,
    tokenizer: object,
    device: object,
):
    """Generate hypothetical legal facts about *topic* with a causal LM.

    Args:
        topic: Subject to write legal facts about.
        model: A loaded ``transformers`` causal language model.
        tokenizer: The tokenizer matching ``model``.
        device: Device the model's inputs should be moved to.

    Returns:
        str: The generated text with the prompt stripped off the front.
    """
    prompt = (
        f"Write legal facts about the following topic:\n{topic}\n"
    )
    len_prompt = len(prompt)

    output = model.generate(
        **tokenizer(
            prompt,
            return_tensors="pt",
            return_token_type_ids=False,
        ).to(device),
        max_new_tokens=300,
        early_stopping=True,
        do_sample=True,
        top_k=10,
        top_p=0.98,
        no_repeat_ngram_size=3,
        eos_token_id=2,  # Llama-2 </s> token id
        repetition_penalty=1.1,
        num_beams=3,
    )

    # BUG FIX: decode() without skip_special_tokens prepends special tokens
    # (e.g. "<s> "), so slicing the raw decoded string by len(prompt)
    # mis-aligns the cut and leaks prompt characters into the result.
    # Skip special tokens so the character offset matches the prompt.
    return tokenizer.decode(output[0], skip_special_tokens=True)[len_prompt:]
42
+
43
def hyde_infer(input_topic):
    """Load the HyDE Llama-2 7B model and generate legal facts for a topic.

    Args:
        input_topic: Topic string to generate hypothetical legal facts for.

    Returns:
        str: The model's generated text (prompt removed).
    """
    # Imports are local so this README snippet is copy-paste runnable.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    model_pth = "akkshay/hyde-llama-7b"
    model = AutoModelForCausalLM.from_pretrained(
        model_pth,
        # BUG FIX: {"": 0} hard-coded GPU 0 and crashed on CUDA-less machines;
        # map everything onto the device computed above instead.
        device_map={"": device},
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    # use_fast=True is required: this repo ships an HF `tokenizers`
    # FastTokenizer, not a sentencepiece model.
    tokenizer = AutoTokenizer.from_pretrained(model_pth, use_fast=True)

    model.eval()
    model.config.use_cache = True  # enable the KV cache for faster generation
    tokenizer.pad_token = tokenizer.eos_token  # Llama-2 has no pad token by default
    output = hyde_gen(
        topic=input_topic,
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    return output
66
+
67
+
68
if __name__ == "__main__":
    # Smoke-test the pipeline on a sample topic.
    generated_fact = hyde_infer("VW emissions scandal")
    print(generated_fact)
71
+ ```
72
+
73
+ Since HyDE Llama 2 uses the `FastTokenizer` provided by the HF `tokenizers` library (not the sentencepiece package), you must pass the `use_fast=True` option when initializing the tokenizer.
74
+
75
+ Lastly, Apple M1/M2 chips do not support BF16 compute, so use the CPU instead.