Update README.md
README.md

````diff
@@ -30,12 +30,14 @@ from PIL import Image
 
 import torch
 from transformers import AutoProcessor
-from heron.models.git_llm.
+from heron.models.git_llm.git_gpt_neox import GitGPTNeoXForCausalLM
 
 device_id = 0
 
 # prepare a pretrained model
-model =
+model = GitGPTNeoXForCausalLM.from_pretrained(
+    'turing-motors/heron-chat-git-ELYZA-fast-7b-v0', torch_dtype=torch.float16
+)
 model.eval()
 model.to(f"cuda:{device_id}")
 
@@ -68,7 +70,7 @@ with torch.no_grad():
 out = model.generate(**inputs, max_length=256, do_sample=False, temperature=0., eos_token_id=eos_token_id_list)
 
 # print result
-print(processor.tokenizer.batch_decode(out))
+print(processor.tokenizer.batch_decode(out)[0])
 ```
````
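For reference, below is a minimal sketch of the full README snippet as it reads after this change, assembled from the added lines plus the context visible in the hunk headers (`from PIL import Image`, `with torch.no_grad():`). The model loading, `model.eval()`/`model.to(...)`, the `generate` call, and the final `print` come from the diff; the processor checkpoint, image URL, prompt format, half-precision cast of `pixel_values`, and `eos_token_id_list` construction are assumptions filled in for illustration, not part of this commit.

```python
import requests

import torch
from PIL import Image
from transformers import AutoProcessor
from heron.models.git_llm.git_gpt_neox import GitGPTNeoXForCausalLM

device_id = 0

# prepare a pretrained model (from the diff)
model = GitGPTNeoXForCausalLM.from_pretrained(
    'turing-motors/heron-chat-git-ELYZA-fast-7b-v0', torch_dtype=torch.float16
)
model.eval()
model.to(f"cuda:{device_id}")

# prepare a processor (assumption: loaded from the same checkpoint)
processor = AutoProcessor.from_pretrained('turing-motors/heron-chat-git-ELYZA-fast-7b-v0')

# prepare inputs (assumption: placeholder image URL and prompt, not part of the diff)
url = "https://example.com/sample.jpg"  # hypothetical image URL
image = Image.open(requests.get(url, stream=True).raw)
text = "##human: この画像には何が写っていますか？\n##gpt: "  # "What is shown in this image?"

inputs = processor(text=text, images=image, return_tensors="pt")
inputs = {k: v.to(f"cuda:{device_id}") for k, v in inputs.items()}
inputs["pixel_values"] = inputs["pixel_values"].half()  # assumption: match the fp16 weights

# assumption: stop generation at the tokenizer's EOS token
eos_token_id_list = [processor.tokenizer.eos_token_id]

# generate (greedy decoding, from the diff context)
with torch.no_grad():
    out = model.generate(**inputs, max_length=256, do_sample=False, temperature=0., eos_token_id=eos_token_id_list)

# print result (from the diff)
print(processor.tokenizer.batch_decode(out)[0])
```

`batch_decode(out)` returns a list with one decoded string per generated sequence, so the `[0]` added in this diff prints the text itself rather than a single-element list.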