hysts HF staff committed on
Commit
0737a9d
1 Parent(s): 8824f88
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -39,10 +39,11 @@ def generate(
39
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
40
  conversation.append({"role": "user", "content": message})
41
 
42
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to("cuda")
43
- if len(input_ids) > MAX_INPUT_TOKEN_LENGTH:
44
- input_ids = input_ids[-MAX_INPUT_TOKEN_LENGTH:]
45
  gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
46
 
47
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
48
  generate_kwargs = dict(
 
39
  conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
40
  conversation.append({"role": "user", "content": message})
41
 
42
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
43
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
44
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
45
  gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
46
+ input_ids = input_ids.to(model.device)
47
 
48
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
49
  generate_kwargs = dict(