working v
Browse files
app.py
CHANGED
@@ -11,40 +11,37 @@ import torch
|
|
11 |
from transformers import (
|
12 |
AutoModelForCausalLM,
|
13 |
AutoTokenizer,
|
|
|
14 |
StoppingCriteria,
|
15 |
StoppingCriteriaList,
|
16 |
TextIteratorStreamer,
|
17 |
)
|
18 |
|
19 |
|
|
|
20 |
model_name = "timdettmers/guanaco-33b-merged"
|
21 |
max_new_tokens = 1536
|
22 |
|
23 |
-
# # small testing model:
|
24 |
-
model_name = "gpt2"
|
25 |
-
max_new_tokens = 128
|
26 |
-
|
27 |
auth_token = os.getenv("HF_TOKEN", None)
|
28 |
|
29 |
print(f"Starting to load the model {model_name} into memory")
|
30 |
|
31 |
m = AutoModelForCausalLM.from_pretrained(
|
32 |
model_name,
|
33 |
-
|
34 |
torch_dtype=torch.bfloat16,
|
35 |
-
|
36 |
)
|
37 |
-
|
38 |
-
tok = AutoTokenizer.from_pretrained(model_name)
|
39 |
tok.bos_token_id = 1
|
40 |
|
41 |
-
|
42 |
|
43 |
print(f"Successfully loaded the model {model_name} into memory")
|
44 |
|
45 |
|
46 |
start_message = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
|
47 |
-
|
48 |
|
49 |
|
50 |
class StopOnTokens(StoppingCriteria):
|
@@ -60,8 +57,8 @@ def convert_history_to_text(history):
|
|
60 |
[
|
61 |
"".join(
|
62 |
[
|
63 |
-
f"
|
64 |
-
f"
|
65 |
]
|
66 |
)
|
67 |
for item in history[:-1]
|
@@ -71,8 +68,8 @@ def convert_history_to_text(history):
|
|
71 |
[
|
72 |
"".join(
|
73 |
[
|
74 |
-
f"
|
75 |
-
f"
|
76 |
]
|
77 |
)
|
78 |
]
|
|
|
11 |
from transformers import (
|
12 |
AutoModelForCausalLM,
|
13 |
AutoTokenizer,
|
14 |
+
LlamaTokenizer,
|
15 |
StoppingCriteria,
|
16 |
StoppingCriteriaList,
|
17 |
TextIteratorStreamer,
|
18 |
)
|
19 |
|
20 |
|
21 |
+
# model_name = "lmsys/vicuna-7b-delta-v1.1"
|
22 |
model_name = "timdettmers/guanaco-33b-merged"
|
23 |
max_new_tokens = 1536
|
24 |
|
|
|
|
|
|
|
|
|
25 |
auth_token = os.getenv("HF_TOKEN", None)
|
26 |
|
27 |
print(f"Starting to load the model {model_name} into memory")
|
28 |
|
29 |
m = AutoModelForCausalLM.from_pretrained(
|
30 |
model_name,
|
31 |
+
load_in_8bit=True,
|
32 |
torch_dtype=torch.bfloat16,
|
33 |
+
device_map={"": 0}
|
34 |
)
|
35 |
+
tok = LlamaTokenizer.from_pretrained(model_name)
|
|
|
36 |
tok.bos_token_id = 1
|
37 |
|
38 |
+
stop_token_ids = [0]
|
39 |
|
40 |
print(f"Successfully loaded the model {model_name} into memory")
|
41 |
|
42 |
|
43 |
start_message = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
|
44 |
+
|
45 |
|
46 |
|
47 |
class StopOnTokens(StoppingCriteria):
|
|
|
57 |
[
|
58 |
"".join(
|
59 |
[
|
60 |
+
f"### Human: {item[0]}\n",
|
61 |
+
f"### Assistant: {item[1]}\n",
|
62 |
]
|
63 |
)
|
64 |
for item in history[:-1]
|
|
|
68 |
[
|
69 |
"".join(
|
70 |
[
|
71 |
+
f"### Human: {history[-1][0]}\n",
|
72 |
+
f"### Assistant: {history[-1][1]}\n",
|
73 |
]
|
74 |
)
|
75 |
]
|