monurcan committed
Commit 02b26ea · 1 Parent(s): 680c2ae
Files changed (1):
  1. app.py +101 -69
app.py CHANGED
@@ -1,15 +1,8 @@
 import gradio as gr
-import torch
-from transformers import (
-    AutoModelForImageTextToText,
-    AutoProcessor,
-    TextIteratorStreamer,
-)
-from peft import PeftModel
-from transformers.image_utils import load_image
-from threading import Thread
+import base64
 import time
 import html
+from huggingface_hub import InferenceClient


 def progress_bar_html(label: str) -> str:
@@ -35,63 +28,93 @@ def progress_bar_html(label: str) -> str:

 model_name = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"

-model = AutoModelForImageTextToText.from_pretrained(
-    model_name, dtype=torch.bfloat16, device_map="auto"
-).eval()
-
-processor = AutoProcessor.from_pretrained(model_name)
-
-print(f"Successfully load the model: {model}")
-
-
-def model_inference(input_dict, history):
-    text = input_dict["text"]
-    files = input_dict["files"]
-
-    if len(files) > 1:
-        images = [load_image(image) for image in files]
-    elif len(files) == 1:
-        images = [load_image(files[0])]
-    else:
-        images = []
+
+def model_inference(input_dict, history, hf_token: gr.OAuthToken):
+    """
+    Use Hugging Face InferenceClient (streaming) to perform the multimodal chat completion.
+    Signature matches ChatInterface call pattern: (input_dict, history, *additional_inputs)
+    The OAuth token (from gr.LoginButton) is passed as `hf_token`.
+    """
+    text = input_dict.get("text", "")
+    files = input_dict.get("files", []) or []

-    if text == "" and not images:
+    if text == "" and not files:
         gr.Error("Please input a query and optionally image(s).")
         return
-    if text == "" and images:
+    if text == "" and files:
         gr.Error("Please input a text query along with the image(s).")
         return

-    messages = [
-        {
-            "role": "user",
-            "content": [
-                *[{"type": "image", "image": image} for image in images],
-                {"type": "text", "text": text},
-            ],
-        }
-    ]
-    inputs = processor.apply_chat_template(
-        messages,
-        add_generation_prompt=True,
-        tokenize=True,
-        return_dict=True,
-        return_tensors="pt",
-    ).to(model.device, dtype=model.dtype)
-    streamer = TextIteratorStreamer(
-        processor, skip_prompt=True, skip_special_tokens=True
-    )
-    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-    buffer = ""
+    # Build the content list: images (as URLs or data URLs) followed by the text
+    content_list = []
+    for f in files:
+        try:
+            # If file looks like a URL, send as image_url
+            if isinstance(f, str) and f.startswith("http"):
+                content_list.append({"type": "image_url", "image_url": {"url": f}})
+            else:
+                # f is a local path-like object; read and convert to base64 data url
+                with open(f, "rb") as fh:
+                    b = fh.read()
+                b64 = base64.b64encode(b).decode("utf-8")
+                # naive mime type: jpeg; this should work for most common images
+                data_url = f"data:image/jpeg;base64,{b64}"
+                content_list.append(
+                    {"type": "image_url", "image_url": {"url": data_url}}
+                )
+        except Exception:
+            # if anything goes wrong reading the file, skip embedding that file
+            continue
+
+    content_list.append({"type": "text", "text": text})
+
+    messages = [{"role": "user", "content": content_list}]
+
+    if hf_token is None or not getattr(hf_token, "token", None):
+        gr.Error(
+            "Please login with a Hugging Face account (use the Login button in the sidebar)."
+        )
+        return
+
+    client = InferenceClient(token=hf_token.token, model=model_name)
+
+    response = ""
     yield progress_bar_html("Processing...")
-    for new_text in streamer:
-        escaped_new_text = html.escape(new_text)
-        buffer += escaped_new_text
-
-        time.sleep(0.001)
-        yield buffer
+
+    # The API may stream tokens. Try to iterate the streaming generator and extract token deltas.
+    try:
+        stream = client.chat.completions.create(messages=messages, stream=True)
+    except TypeError:
+        # older/newer client variants: try the alternative method name
+        stream = client.chat_completion(messages=messages, stream=True)
+
+    for chunk in stream:
+        # chunk can be an object with attributes or a dict depending on client version
+        token = ""
+        try:
+            # attempt dict-style
+            if isinstance(chunk, dict):
+                choices = chunk.get("choices")
+                if choices and len(choices) > 0:
+                    delta = choices[0].get("delta", {})
+                    token = delta.get("content") or ""
+            else:
+                # attribute-style
+                choices = getattr(chunk, "choices", None)
+                if choices and len(choices) > 0:
+                    delta = getattr(choices[0], "delta", None)
+                    if isinstance(delta, dict):
+                        token = delta.get("content") or ""
+                    else:
+                        token = getattr(delta, "content", "")
+        except Exception:
+            token = ""
+
+        if token:
+            # escape incremental token to avoid raw HTML breaking the chat box
+            response += html.escape(token)
+            time.sleep(0.001)
+            yield response


 examples = [
@@ -109,15 +132,24 @@ examples = [
     ],
 ]

-demo = gr.ChatInterface(
-    fn=model_inference,
-    description="# **Smolvlm2-500M-illustration-description** \n (running on CPU) The model only sees the last input, it ignores the previous conversation history.",
-    examples=examples,
-    fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"]),
-    stop_btn="Stop Generation",
-    multimodal=True,
-    cache_examples=False,
-)
-
-demo.launch(debug=True)
+with gr.Blocks() as demo:
+    with gr.Sidebar():
+        login_btn = gr.LoginButton(label="Login with Hugging Face")
+
+    chatbot = gr.ChatInterface(
+        fn=model_inference,
+        description="# **Smolvlm2-500M-illustration-description** \n (running on CPU) The model only sees the last input, it ignores the previous conversation history.",
+        examples=examples,
+        fill_height=True,
+        textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"]),
+        stop_btn="Stop Generation",
+        multimodal=True,
+        cache_examples=False,
+        additional_inputs=[login_btn],
+    )
+
+    chatbot.render()
+
+
+if __name__ == "__main__":
+    demo.launch(debug=True)
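A few review notes on the new code follow. First, gr.Error is an exception type in Gradio: constructing it without raise, as the validation branches do in both the old and the new version, displays nothing in the UI and the handler simply returns. A minimal sketch of what was presumably intended:

    # Inside model_inference: gr.Error only shows its error modal when raised,
    # so the bare gr.Error(...) call followed by `return` is a silent no-op.
    if text == "" and not files:
        raise gr.Error("Please input a query and optionally image(s).")
    if text == "" and files:
        raise gr.Error("Please input a text query along with the image(s).")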
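Second, the data-URL branch hardcodes image/jpeg even when the upload is a PNG or WebP. A small alternative using the standard library's mimetypes module; the helper name to_data_url is illustrative, not part of the commit:

    import base64
    import mimetypes

    def to_data_url(path: str) -> str:
        # Guess the MIME type from the file extension and fall back to JPEG,
        # matching the committed behaviour for unknown extensions.
        mime, _ = mimetypes.guess_type(path)
        with open(path, "rb") as fh:
            b64 = base64.b64encode(fh.read()).decode("utf-8")
        return f"data:{mime or 'image/jpeg'};base64,{b64}"

The hardcoded jpeg label often works anyway because servers typically decode the actual bytes, so this is a robustness tweak rather than a fix.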
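Third, the defensive dict-versus-attribute chunk handling can usually be collapsed: recent huggingface_hub releases yield typed ChatCompletionStreamOutput objects from InferenceClient.chat_completion(..., stream=True). A minimal sketch, assuming a recent client and that the model is reachable through the Inference API (the token string, image URL, and max_tokens value are placeholders):

    from huggingface_hub import InferenceClient

    client = InferenceClient(
        token="hf_xxx",  # placeholder; the Space uses the OAuth token instead
        model="HuggingFaceTB/SmolVLM2-256M-Video-Instruct",
    )
    messages = [{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            {"type": "text", "text": "Describe this illustration."},
        ],
    }]
    response = ""
    for chunk in client.chat_completion(messages=messages, stream=True, max_tokens=1024):
        # delta.content can be None on the final chunk
        response += chunk.choices[0].delta.content or ""

The try/except TypeError fallback in the commit keeps compatibility with client versions that expose only one of the two method names, which is a defensible belt-and-braces choice.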
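Finally, on the login wiring: Gradio's documented OAuth pattern injects the token based on the gr.OAuthToken type annotation alone (on Spaces with hf_oauth: true in the README metadata), so passing login_btn through additional_inputs may be redundant. A hedged sketch of that pattern, assuming a recent Gradio on a Space with OAuth enabled:

    import gradio as gr

    def respond(message, history, oauth_token: gr.OAuthToken | None = None):
        # Gradio fills oauth_token from the type hint; it is None when the
        # visitor has not logged in through the LoginButton.
        yield "logged in" if oauth_token else "please log in first"

    with gr.Blocks() as demo:
        gr.LoginButton()
        gr.ChatInterface(fn=respond)

    demo.launch()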