phamngoctukts committed on
Commit ff4edf1
1 Parent(s): deacc32

Update app.py

Files changed (1)
  1. app.py +53 -8
app.py CHANGED
@@ -18,12 +18,10 @@ import os
 tk = token = os.environ.get("HF_TOKEN")
 login(tk)
 model_id = "meta-llama/Llama-3.2-1B"
-text2text = pipeline(
-    "text2text",
-    model=model_id,
-    torch_dtype=torch.bfloat16,
-    device_map="auto"
-)
+ckpt = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+model = MllamaForConditionalGeneration.from_pretrained(ckpt,
+    torch_dtype=torch.bfloat16).to("cpu")
+processor = AutoProcessor.from_pretrained(ckpt)
 r = sr.Recognizer()
 
 @dataclass
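
Note on the hunk above: the commit replaces the `text2text` pipeline built on `meta-llama/Llama-3.2-1B` with direct loading of the Llama 3.2 11B Vision Instruct checkpoint. A minimal loading sketch, assuming a transformers release that ships the Mllama classes (4.45 or later) and that the matching imports are added near the top of app.py:

    # Loading sketch (assumption: transformers >= 4.45, which introduced MllamaForConditionalGeneration).
    import torch
    from transformers import MllamaForConditionalGeneration, AutoProcessor

    ckpt = "meta-llama/Llama-3.2-11B-Vision-Instruct"
    model = MllamaForConditionalGeneration.from_pretrained(
        ckpt,
        torch_dtype=torch.bfloat16,   # bf16 weights, as in the diff
    ).to("cpu")                       # the diff keeps the model on CPU
    processor = AutoProcessor.from_pretrained(ckpt)

Running the 11B checkpoint on CPU in bfloat16 works but is slow; the sketch mirrors the diff rather than recommending that placement.
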
@@ -85,7 +83,7 @@ def process_audio(audio:tuple, state:AppState):
         return gr.Audio(recording=False), state
     return None, state
 
-def response(state:AppState):
+def response(state:AppState, message, history, max_new_tokens=250):
     if not state.pause_detected and not state.started_talking:
         return None, AppState()
     audio_buffer = BytesIO()
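
The widened `response` signature feeds the history handling added further down, which indexes `msg[0]`/`msg[1]` and `message["files"]`. That matches the tuple-style chat history and multimodal textbox value used by Gradio chat UIs, but the UI wiring is not visible in this diff, so the shapes below are an illustrative assumption:

    # Assumed input shapes (hypothetical example values, not taken from the diff):
    message = {"text": "What is in this picture?", "files": ["/tmp/example.png"]}
    history = [
        (("/tmp/example.png",), None),           # image turn: the user part is a tuple of file paths
        ("What is in this picture?", "A cat."),  # text turn: (user text, assistant text)
    ]
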
@@ -107,7 +105,54 @@ def response(state:AppState)
     if textin != "":
         print("Đang nghĩ...")
         textout=str(text2text(textin))
-        textout = textout.replace('*','')
+
+
+        for i, msg in enumerate(history):
+            if isinstance(msg[0], tuple):
+                messages.append({"role": "user", "content": [{"type": "text", "text": history[i+1][0]}, {"type": "image"}]})
+                messages.append({"role": "assistant", "content": [{"type": "text", "text": history[i+1][1]}]})
+                images.append(Image.open(msg[0][0]).convert("RGB"))
+            elif isinstance(history[i-1], tuple) and isinstance(msg[0], str):
+                # messages are already handled
+                pass
+            elif isinstance(history[i-1][0], str) and isinstance(msg[0], str):  # text only turn
+                messages.append({"role": "user", "content": [{"type": "text", "text": msg[0]}]})
+                messages.append({"role": "assistant", "content": [{"type": "text", "text": msg[1]}]})
+
+        # add current message
+        if len(message["files"]) == 1:
+            if isinstance(message["files"][0], str):  # examples
+                image = Image.open(message["files"][0]).convert("RGB")
+            else:  # regular input
+                image = Image.open(message["files"][0]["path"]).convert("RGB")
+            images.append(image)
+            messages.append({"role": "user", "content": [{"type": "text", "text": txt}, {"type": "image"}]})
+        else:
+            messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})
+
+
+        texts = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+        if images == []:
+            inputs = processor(text=texts, return_tensors="pt").to("cpu")
+        else:
+            inputs = processor(text=texts, images=images, return_tensors="pt").to("cpu")
+        streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
+        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
+        generated_text = streamer
+
+        thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+
+        for new_text in streamer:
+            buffer += new_text
+            generated_text_without_prompt = buffer
+            time.sleep(0.01)
+            yield buffer
+
+
+        textout = generated_text.replace('*','')
     state.conversation.append({"role": "user", "content": "Trợ lý: " + textout})
     if textout != "":
         print("Đang đọc...")
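
The added block turns the chat history into Mllama-style `messages` plus a list of PIL `images`, templates them with `processor.apply_chat_template`, and streams `model.generate` output from a worker thread (the printed strings "Đang nghĩ..." and "Đang đọc..." are Vietnamese for "Thinking..." and "Reading..."; "Trợ lý" means "Assistant"). Note that `messages`, `images`, and `txt` are initialised outside the visible hunk, and that `generated_text = streamer` binds the streamer object itself, so the final `.replace('*','')` needs the accumulated string instead. A condensed sketch of the streaming pattern, assuming the `model` and `processor` loaded above:

    # Streaming sketch: chat template + TextIteratorStreamer + a background generate() thread.
    # Assumptions: `model`/`processor` are the Mllama objects above; `messages` is the chat
    # list built as in the diff; `images` is a (possibly empty) list of PIL images.
    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_reply(messages, images, max_new_tokens=250):
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        if images:
            inputs = processor(text=prompt, images=images, return_tensors="pt").to(model.device)
        else:
            inputs = processor(text=prompt, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
        Thread(target=model.generate,
               kwargs=dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)).start()
        buffer = ""
        for new_text in streamer:   # text chunks arrive while generate() runs in the thread
            buffer += new_text
            yield buffer            # partial reply, e.g. for live UI updates
        # after the loop, `buffer` holds the complete reply for any post-processing

A caller can iterate `for partial in stream_reply(messages, images)` and keep the last value as the final reply string.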
 