vilarin committed on
Commit
bf65021
1 Parent(s): 085f93a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -9,7 +9,7 @@ import torch
9
  from PIL import Image
10
  import gradio as gr
11
  import spaces
12
- from transformers import AutoModel, AutoTokenizer, TextIteratorStreamer
13
  import os
14
 
15
 
@@ -35,9 +35,8 @@ CSS = """
35
 
36
  model = AutoModel.from_pretrained(
37
  MODEL_ID,
38
- torch_dtype=torch.float16,
39
  trust_remote_code=True
40
- ).to(0)
41
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
42
  model.eval()
43
 
@@ -71,12 +70,18 @@ def stream_chat(message, history: list, temperature: float, max_new_tokens: int)
71
  temperature=temperature,
72
  sampling=True,
73
  tokenizer=tokenizer,
 
74
  )
75
  if temperature == 0:
76
  generate_kwargs["sampling"] = False
77
 
78
  response = model.chat(**generate_kwargs)
79
- return response
 
 
 
 
 
80
 
81
 
82
  chatbot = gr.Chatbot(height=450)
 
9
  from PIL import Image
10
  import gradio as gr
11
  import spaces
12
+ from transformers import AutoModel, AutoTokenizer
13
  import os
14
 
15
 
 
35
 
36
  model = AutoModel.from_pretrained(
37
  MODEL_ID,
 
38
  trust_remote_code=True
39
+ )
40
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
41
  model.eval()
42
 
 
70
  temperature=temperature,
71
  sampling=True,
72
  tokenizer=tokenizer,
73
+ stream=True
74
  )
75
  if temperature == 0:
76
  generate_kwargs["sampling"] = False
77
 
78
  response = model.chat(**generate_kwargs)
79
+
80
+ generated_text = ""
81
+ for new_text in response:
82
+ generated_text += new_text
83
+ yield generated_text
84
+
85
 
86
 
87
  chatbot = gr.Chatbot(height=450)