not-lain committed on
Commit
0c3b8fb
•
1 Parent(s): 1f8acf3

Update app.py

Files changed (1)
  1. app.py +4 -4
app.py CHANGED
@@ -10,7 +10,7 @@ model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it",
     # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
     torch_dtype=torch.float16,
     token=token)
-tok = AutoTokenizer.from_pretrained("google/gemma-7b-it",token=token)
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it",token=token)
 # using CUDA for an optimal experience
 # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 device = torch.device('cuda')
@@ -25,11 +25,11 @@ def chat(message, history):
         if item[1] is not None:
             chat.append({"role": "assistant", "content": item[1]})
     chat.append({"role": "user", "content": message})
-    messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+    messages = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
     # Tokenize the messages string
-    model_inputs = tok([messages], return_tensors="pt").to(device)
+    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
     streamer = TextIteratorStreamer(
-        tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
+        tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         model_inputs,
         streamer=streamer,
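
The diff cuts off right after generate_kwargs, so for context here is a minimal sketch of how a TextIteratorStreamer configured like the one above is usually driven: generate() runs in a background thread while a generator yields the decoded text as it arrives. The stream_reply helper, the max_new_tokens value, and the yield loop are illustrative assumptions, not part of this commit.

# Sketch of the usual TextIteratorStreamer consumption pattern (assumed,
# not taken from app.py beyond what the diff shows).
from threading import Thread

from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, model_inputs):
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    # model_inputs is the BatchEncoding returned by tokenizer(...), so it can
    # seed the generate() kwargs directly, mirroring dict(model_inputs, ...) above.
    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=1024)

    # generate() blocks until decoding finishes, so it runs in a background
    # thread while this generator yields the accumulated text chunk by chunk.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial

Yielding the growing string rather than individual chunks matches what a streaming Gradio chat callback expects, which appears to be how the chat() function in this Space is used.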