vilarin committed on
Commit
f663115
·
verified ·
1 Parent(s): 36e78de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -35,6 +35,7 @@ model = AutoModelForCausalLM.from_pretrained(
35
 
36
  tokenizer = AutoTokenizer.from_pretrained(MODELS,trust_remote_code=True)
37
 
 
38
  @spaces.GPU
39
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
40
  print(f'message is - {message}')
@@ -48,17 +49,18 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
48
 
49
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
50
 
51
- streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
52
 
53
  generate_kwargs = dict(
54
  input_ids=input_ids,
55
- streamer=streamer,
56
- max_length=max_new_tokens,
57
- do_sample=True,
 
58
  temperature=temperature,
59
  repetition_penalty=1.2,
60
  )
61
-
62
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
63
  thread.start()
64
 
@@ -66,6 +68,13 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
66
  for new_text in streamer:
67
  buffer[-1][1] += new_text
68
  yield buffer
 
 
 
 
 
 
 
69
 
70
 
71
 
 
35
 
36
  tokenizer = AutoTokenizer.from_pretrained(MODELS,trust_remote_code=True)
37
 
38
+
39
  @spaces.GPU
40
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
41
  print(f'message is - {message}')
 
49
 
50
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
51
 
52
+ # streamer = TextIteratorStreamer(tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True)
53
 
54
  generate_kwargs = dict(
55
  input_ids=input_ids,
56
+ max_length=2500,
57
+ max_new_tokens=max_new_tokens,
58
+ do_sample=True,
59
+ top_k=1,
60
  temperature=temperature,
61
  repetition_penalty=1.2,
62
  )
63
+ '''
64
  thread = Thread(target=model.generate, kwargs=generate_kwargs)
65
  thread.start()
66
 
 
68
  for new_text in streamer:
69
  buffer[-1][1] += new_text
70
  yield buffer
71
+ '''
72
+ with torch.no_grad():
73
+ outputs = model.generate(**inputs, **gen_kwargs)
74
+ outputs = outputs[:, inputs['input_ids'].shape[1]:]
75
+ results = tokenizer.decode(outputs[0], skip_special_tokens=True)
76
+ return results
77
+
78
 
79
 
80