ybelkada committed on
Commit
34efb62
1 Parent(s): d69ff54
Files changed (1) hide show
  1. app.py +11 -14
app.py CHANGED
@@ -11,40 +11,37 @@ import torch
11
  from transformers import (
12
  AutoModelForCausalLM,
13
  AutoTokenizer,
 
14
  StoppingCriteria,
15
  StoppingCriteriaList,
16
  TextIteratorStreamer,
17
  )
18
 
19
 
 
20
  model_name = "timdettmers/guanaco-33b-merged"
21
  max_new_tokens = 1536
22
 
23
- # # small testing model:
24
- model_name = "gpt2"
25
- max_new_tokens = 128
26
-
27
  auth_token = os.getenv("HF_TOKEN", None)
28
 
29
  print(f"Starting to load the model {model_name} into memory")
30
 
31
  m = AutoModelForCausalLM.from_pretrained(
32
  model_name,
33
- # load_in_8bit=True,
34
  torch_dtype=torch.bfloat16,
35
- # device_map="auto"
36
  )
37
- # tok = AutoTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
38
- tok = AutoTokenizer.from_pretrained(model_name)
39
  tok.bos_token_id = 1
40
 
41
- # stop_token_ids = tok.convert_tokens_to_ids(["<|im_end|>", "<|endoftext|>"])
42
 
43
  print(f"Successfully loaded the model {model_name} into memory")
44
 
45
 
46
  start_message = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
47
- prompt = f"{start_message} ### Human: {user_query} ### Assistant:"
48
 
49
 
50
  class StopOnTokens(StoppingCriteria):
@@ -60,8 +57,8 @@ def convert_history_to_text(history):
60
  [
61
  "".join(
62
  [
63
- f"<|im_start|>user\n{item[0]}<|im_end|>",
64
- f"<|im_start|>assistant\n{item[1]}<|im_end|>",
65
  ]
66
  )
67
  for item in history[:-1]
@@ -71,8 +68,8 @@ def convert_history_to_text(history):
71
  [
72
  "".join(
73
  [
74
- f"<|im_start|>user\n{history[-1][0]}<|im_end|>",
75
- f"<|im_start|>assistant\n{history[-1][1]}",
76
  ]
77
  )
78
  ]
 
11
  from transformers import (
12
  AutoModelForCausalLM,
13
  AutoTokenizer,
14
+ LlamaTokenizer,
15
  StoppingCriteria,
16
  StoppingCriteriaList,
17
  TextIteratorStreamer,
18
  )
19
 
20
 
21
+ # model_name = "lmsys/vicuna-7b-delta-v1.1"
22
  model_name = "timdettmers/guanaco-33b-merged"
23
  max_new_tokens = 1536
24
 
 
 
 
 
25
  auth_token = os.getenv("HF_TOKEN", None)
26
 
27
  print(f"Starting to load the model {model_name} into memory")
28
 
29
  m = AutoModelForCausalLM.from_pretrained(
30
  model_name,
31
+ load_in_8bit=True,
32
  torch_dtype=torch.bfloat16,
33
+ device_map={"": 0}
34
  )
35
+ tok = LlamaTokenizer.from_pretrained(model_name)
 
36
  tok.bos_token_id = 1
37
 
38
+ stop_token_ids = [0]
39
 
40
  print(f"Successfully loaded the model {model_name} into memory")
41
 
42
 
43
  start_message = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
44
+
45
 
46
 
47
  class StopOnTokens(StoppingCriteria):
 
57
  [
58
  "".join(
59
  [
60
+ f"### Human: {item[0]}\n",
61
+ f"### Assistant: {item[1]}\n",
62
  ]
63
  )
64
  for item in history[:-1]
 
68
  [
69
  "".join(
70
  [
71
+ f"### Human: {history[-1][0]}\n",
72
+ f"### Assistant: {history[-1][1]}\n",
73
  ]
74
  )
75
  ]