hysts (HF staff) committed
Commit 6111f2c
1 Parent(s): f76edaf
Files changed (1)
  1. app.py +6 -6
app.py CHANGED
@@ -64,15 +64,15 @@ def generate(
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})
 
-    chat = tokenizer.apply_chat_template(conversation, tokenize=False)
-    inputs = tokenizer(chat, return_tensors="pt", add_special_tokens=False).to("cuda")
-    if len(inputs) > MAX_INPUT_TOKEN_LENGTH:
-        inputs = inputs[-MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        inputs,
+        {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
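
For context, here is a minimal standalone sketch of the new input path together with the streaming loop that typically consumes generate_kwargs (the rest of generate() is outside this hunk). The checkpoint name, the MAX_INPUT_TOKEN_LENGTH value, and the thread-plus-loop pattern are assumptions based on common transformers usage, not code taken from this commit:

    # Sketch only: checkpoint, token limit, and the streaming loop are assumed,
    # since app.py's model setup and the end of generate() are not in this hunk.
    from threading import Thread

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    MAX_INPUT_TOKEN_LENGTH = 4096  # assumed; defined elsewhere in app.py

    model_id = "meta-llama/Llama-2-7b-chat-hf"  # assumed checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

    conversation = [{"role": "user", "content": "Hello!"}]

    # apply_chat_template with return_tensors="pt" formats and tokenizes in one
    # call, returning a [1, seq_len] tensor, so the separate tokenizer(chat, ...)
    # step from the old code is gone.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]  # keep the most recent tokens
    input_ids = input_ids.to(model.device)  # follow the model instead of hard-coding "cuda"

    # Typical TextIteratorStreamer usage: run generate() on a worker thread and
    # consume decoded text on the main thread as it arrives.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
    )
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    for text in streamer:
        print(text, end="", flush=True)

The sketch also reflects the three fixes visible in the diff: truncation now slices the token dimension of the tensor instead of indexing the BatchEncoding object, the warning string gained the f-prefix it was missing (previously {MAX_INPUT_TOKEN_LENGTH} was printed literally), and model.device replaces the hard-coded "cuda".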