Update README.md
README.md
@@ -21,13 +21,30 @@ It achieves the following results on the evaluation set:
 
 ## How to Run Inference
 
-Make sure you have git-lfs, and access to gemma-2b on huggingface.
 ```
-
-
-
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+
+model_id = "google/gemma-2b"
+peft_model_id = "apfurman/gemma-dolly-agriculture"
+
+# make sure you have access to gemma-2b as well
+model = AutoModelForCausalLM.from_pretrained(model_id, token="YOUR_TOKEN_HERE")
+model.load_adapter(peft_model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id, token="YOUR_TOKEN_HERE")
+
+def ask(prompt):
+    inputs = tokenizer(prompt, return_tensors="pt").input_ids
+    with torch.inference_mode():
+        tokens = model.generate(
+            inputs,
+            pad_token_id=128001,
+            eos_token_id=128001,
+            max_new_tokens=200,
+            repetition_penalty=1.5,
+        )
+
+    return tokenizer.decode(tokens[0], skip_special_tokens=True)
 ```
-replace "cpu" with "gpu" if you want to run on gpu.
 
 ## Intended uses & limitations
 
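Note that the added snippet calls `torch.inference_mode()` but never imports `torch`, so it will raise a `NameError` as written. A minimal usage sketch, run in the same session after the snippet above (the example question and printing are illustrative, not from the model card):

```
import torch  # needed because ask() calls torch.inference_mode(); the README snippet does not import it

# hypothetical example prompt; any agriculture-related question is handled the same way
answer = ask("What soil pH range is best for growing wheat?")
print(answer)
```

To run on a GPU, both the model and the `inputs` tensor built inside `ask` would need to be moved to the same device (e.g. with `.to("cuda")`), in the spirit of the cpu/gpu note in the removed text.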