vilarin committed
Commit c34cc0a
Parent: 29eb5bb

Update app.py

Files changed (1): app.py (+3 -3)
app.py CHANGED
@@ -32,7 +32,7 @@ model = AutoModelForCausalLM.from_pretrained(
     trust_remote_code=True,
 ).eval()
 
-tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-glm4-9b",trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-glm4-9b",trust_remote_code=True, use_fast=False)
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
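
A note on this hunk: use_fast=False forces the slow, pure-Python tokenizer class. A hedged sketch of why that matters, assuming the LongWriter tokenizer behaves like other GLM remote-code tokenizers: the helpers this app calls later (build_chat_input, get_command) are defined on the slow remote-code class and are not part of the generic fast-tokenizer API.

from transformers import AutoTokenizer

# Sketch under the assumption above; build_chat_input comes from the
# repo's remote code, not from the generic transformers tokenizer API.
tokenizer = AutoTokenizer.from_pretrained(
    "THUDM/LongWriter-glm4-9b",
    trust_remote_code=True,
    use_fast=False,  # select the Python class that defines build_chat_input
)
batch = tokenizer.build_chat_input("Hello", history=[], role="user")
print(batch.input_ids.shape)
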
@@ -56,7 +56,7 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
     print(f"Conversation is -\n{conversation}")
     stop = StopOnTokens()
 
-    input_ids = tokenizer.build_chat_input(message, history=conversation, role='user').input_ids.to(model.device)
+    input_ids = tokenizer.build_chat_input(message, history=conversation, role='user').input_ids.to(next(model.parameters()).device)
     #input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
     eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
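
A note on this hunk: model.device is a convenience property that transformers adds on PreTrainedModel; a plain nn.Module has no .device attribute, so reading the placement directly from the first parameter is the more defensive spelling, and it works for any module as long as the weights all live on one device. A minimal sketch with a hypothetical toy module:

import torch
import torch.nn as nn

toy = nn.Linear(4, 4)  # hypothetical stand-in; nn.Module defines no .device
device = next(toy.parameters()).device  # device of the first weight tensor
input_ids = torch.randint(0, 100, (1, 8)).to(device)
print(device, input_ids.device)
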
@@ -64,8 +64,8 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
 
     generate_kwargs = dict(
         input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
         streamer=streamer,
+        max_new_tokens=max_new_tokens,
         do_sample=True,
         top_k=1,
         temperature=temperature,
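
A note on this hunk: only the position of max_new_tokens inside the dict literal changes (it moves below streamer); keyword order is irrelevant to model.generate, so the edit is cosmetic. For context, a sketch of the streaming pattern this dict feeds, assuming (as is usual with TextIteratorStreamer, though not shown in these hunks) that generate() runs on a background thread while the caller iterates the streamer:

from threading import Thread
from transformers import TextIteratorStreamer

# tokenizer, model, input_ids, max_new_tokens and temperature are the
# objects built in the surrounding app.py code shown in the hunks above.
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
    input_ids=input_ids,
    streamer=streamer,
    max_new_tokens=max_new_tokens,  # position in the dict is cosmetic
    do_sample=True,
    top_k=1,
    temperature=temperature,
)
Thread(target=model.generate, kwargs=generate_kwargs).start()
for new_text in streamer:  # yields decoded text chunks until generation ends
    print(new_text, end="", flush=True)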
 