Nymbo committed on
Commit 81b2233 · verified · 1 Parent(s): 4fa442d

Update app.py

Files changed (1)
  1. app.py +280 -287
app.py CHANGED
@@ -1,237 +1,195 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
  import os
  import json
  import base64
  from PIL import Image
  import io

- # Import smolagents Tool
- from smolagents import Tool

  ACCESS_TOKEN = os.getenv("HF_TOKEN")
  print("Access token loaded.")

- # Initialize the image generation tool
- # This can be defined globally as it doesn't change per request
- try:
-     image_generation_tool = Tool.from_space(
-         "black-forest-labs/FLUX.1-schnell",
-         name="image_generator",
-         description="Generates an image from a text prompt. Use it when the user asks to 'generate an image of ...' or 'draw a picture of ...'. The input should be the descriptive prompt for the image."
-     )
-     print("Image generation tool loaded successfully.")
- except Exception as e:
-     print(f"Error loading image generation tool: {e}")
-     image_generation_tool = None
-
- # Function to encode image to base64
- def encode_image(image_path):
-     if not image_path:
-         print("No image path provided")
          return None

      try:
-         print(f"Encoding image from path: {image_path}")

-         # If it's already a PIL Image
-         if isinstance(image_path, Image.Image):
-             image = image_path
-         else:
-             # Try to open the image file
-             image = Image.open(image_path)

-         # Convert to RGB if image has an alpha channel (RGBA)
          if image.mode == 'RGBA':
              image = image.convert('RGB')

-         # Encode to base64
          buffered = io.BytesIO()
-         image.save(buffered, format="JPEG")
          img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-         print("Image encoded successfully")
          return img_str
      except Exception as e:
          print(f"Error encoding image: {e}")
          return None

  def respond(
-     message_text, # Changed from 'message' to be explicit about text part
-     image_files, # This will be a list of paths from gr.MultimodalTextbox
-     history: list[list[Any, str | None]], # History can now contain complex user messages
-     system_message,
      max_tokens,
      temperature,
      top_p,
      frequency_penalty,
      seed,
-     provider,
-     custom_api_key,
-     custom_model,
-     model_search_term,
-     selected_model
  ):
-     print(f"Received message text: {message_text}")
-     print(f"Received {len(image_files) if image_files else 0} image files: {image_files}")
-     # print(f"History: {history}") # Can be very verbose
-     print(f"System message: {system_message}")
-     print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
-     print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
-     print(f"Selected provider: {provider}")
-     print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
-     print(f"Selected model (custom_model): {custom_model}")
-     print(f"Model search term: {model_search_term}")
-     print(f"Selected model from radio: {selected_model}")
-
-     # Determine which token to use
-     token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
-
-     if custom_api_key.strip() != "":
-         print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
-     else:
-         print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")
-
-     user_text_message_lower = message_text.lower() if message_text else ""
-
-     image_keywords = ["generate image", "draw a picture of", "create an image of", "make an image of"]
-     is_image_generation_request = any(keyword in user_text_message_lower for keyword in image_keywords)
-
-     if is_image_generation_request and image_generation_tool:
-         print("Image generation request detected.")
-         image_prompt = message_text
-         for keyword in image_keywords:
-             if keyword in user_text_message_lower:
-                 # Find the keyword in the original case-sensitive message text to split
-                 keyword_start_index = user_text_message_lower.find(keyword)
-                 image_prompt = message_text[keyword_start_index + len(keyword):].strip()
-                 break
-
-         print(f"Extracted image prompt: {image_prompt}")
-         if not image_prompt:
-             yield {"type": "text", "content": "Please provide a description for the image you want to generate."}
-             return
-
-         try:
-             generated_image_path = image_generation_tool(prompt=image_prompt)
-             print(f"Image generated by tool, path: {generated_image_path}")
-             yield {"type": "image", "path": str(generated_image_path)} # Ensure path is string
-             return
-         except Exception as e:
-             print(f"Error during image generation tool call: {e}")
-             yield {"type": "text", "content": f"Sorry, I couldn't generate the image. Error: {str(e)}"}
-             return
-     elif is_image_generation_request and not image_generation_tool:
-         yield {"type": "text", "content": "Image generation tool is not available or failed to load."}
-         return

-     # If not an image generation request, proceed with text/multimodal LLM call
-     print("Proceeding with LLM call (text or multimodal).")
-     client = InferenceClient(token=token_to_use, provider=provider)
-     print(f"Hugging Face Inference Client initialized with {provider} provider.")

-     if seed == -1:
-         seed = None
-
-     # Prepare messages for LLM
-     llm_user_content = []
-     if message_text and message_text.strip():
-         llm_user_content.append({"type": "text", "text": message_text})
-
-     if image_files: # image_files is a list of paths from gr.MultimodalTextbox
-         for img_path in image_files:
-             if img_path:
-                 try:
-                     encoded_image = encode_image(img_path) # img_path is already a path
-                     if encoded_image:
-                         llm_user_content.append({
-                             "type": "image_url",
-                             "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}
-                         })
-                 except Exception as e:
-                     print(f"Error encoding image for LLM: {e}")

-     if not llm_user_content: # Should not happen if user() function filters empty messages
-         print("No content for LLM, aborting.")
-         yield {"type": "text", "content": "Please provide some input."}
          return

-     messages_for_llm = [{"role": "system", "content": system_message}]
-     print("Initial messages array constructed for LLM.")
-
-     for val in history: # history item is [user_content_list, assistant_response_str_or_dict]
-         user_content_list_hist = val[0]
-         assistant_response_hist = val[1]
-
-         if user_content_list_hist:
-             # user_content_list_hist is already in the correct format (list of dicts)
-             messages_for_llm.append({"role": "user", "content": user_content_list_hist})
-
-         if assistant_response_hist:
-             # Assistant response could be text or an image dict from a previous tool call
-             if isinstance(assistant_response_hist, dict) and assistant_response_hist.get("type") == "image":
-                 messages_for_llm.append({"role": "assistant", "content": [{"type": "text", "text": f"Assistant previously displayed image: {assistant_response_hist.get('path')}"}]})
-             elif isinstance(assistant_response_hist, str):
-                 messages_for_llm.append({"role": "assistant", "content": assistant_response_hist})
-             # Else, if it's a dict but not an image type we understand for history, we might skip or log an error

-     messages_for_llm.append({"role": "user", "content": llm_user_content})
-     # print(f"Full messages_for_llm: {messages_for_llm}") # Can be very verbose

-     model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
-     print(f"Model selected for LLM inference: {model_to_use}")

-     response_text = ""
-     print(f"Sending request to {provider} provider for LLM.")

-     parameters = {
-         "max_tokens": max_tokens,
-         "temperature": temperature,
-         "top_p": top_p,
-         "frequency_penalty": frequency_penalty,
-     }
-
-     if seed is not None:
-         parameters["seed"] = seed

-     try:
-         stream = client.chat_completion(
-             model=model_to_use,
-             messages=messages_for_llm,
-             stream=True,
-             **parameters
-         )
-
-         print("Received LLM tokens: ", end="", flush=True)
-
-         for chunk in stream:
-             if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
-                 if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
-                     token_text = chunk.choices[0].delta.content
-                     if token_text:
-                         print(token_text, end="", flush=True)
-                         response_text += token_text
-                         yield {"type": "text", "content": response_text}
-
-         print()
      except Exception as e:
-         print(f"Error during LLM inference: {e}")
-         response_text += f"\nError: {str(e)}"
-         yield {"type": "text", "content": response_text}

-     print("Completed LLM response generation.")


  def validate_provider(api_key, provider):
      if not api_key.strip() and provider != "hf-inference":
          return gr.update(value="hf-inference")
      return gr.update(value=provider)


  with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
      chatbot = gr.Chatbot(
          height=600,
          show_copy_button=True,
-         placeholder="Select a model and begin chatting. Now supports multiple inference providers and multimodal inputs. Try 'generate image of a cat playing chess'.",
          layout="panel",
-         bubble_full_width=False
      )
      print("Chatbot interface created.")

@@ -247,164 +205,199 @@ with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:

      with gr.Accordion("Settings", open=False):
          system_message_box = gr.Textbox(
-             value="You are a helpful AI assistant that can understand images and text. If asked to generate an image, respond by saying you will call the image_generator tool.",
-             placeholder="You are a helpful assistant.",
-             label="System Prompt"
          )

          with gr.Row():
              with gr.Column():
-                 max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max tokens")
-                 temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
-                 top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
              with gr.Column():
                  frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                  seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")

-         providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
-         provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")
-         byok_textbox = gr.Textbox(value="", label="BYOK (Bring Your Own Key)", info="Enter a custom Hugging Face API key here. When empty, only 'hf-inference' provider can be used.", placeholder="Enter your Hugging Face API token", type="password")
-         custom_model_box = gr.Textbox(value="", label="Custom Model", info="(Optional) Provide a custom Hugging Face model path. Overrides any selected featured model.", placeholder="meta-llama/Llama-3.3-70B-Instruct")
-         model_search_box = gr.Textbox(label="Filter Models", placeholder="Search for a featured model...", lines=1)

          models_list = [
-             "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct",
-             "meta-llama/Llama-3.0-70B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
              "meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
              "mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3",
-             "mistralai/Mistral-7B-Instruct-v0.2", "Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B", "Qwen/Qwen2.5-72B-Instruct",
-             "Qwen/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct", "Qwen/QwQ-32B", "Qwen/Qwen2.5-Coder-32B-Instruct",
-             "microsoft/Phi-3.5-mini-instruct", "microsoft/Phi-3-mini-128k-instruct", "microsoft/Phi-3-mini-4k-instruct",
          ]
-         featured_model_radio = gr.Radio(label="Select a model below", choices=models_list, value="meta-llama/Llama-3.2-11B-Vision-Instruct", interactive=True)

          gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")

-     chat_history = gr.State([])
-
-     def filter_models(search_term):
-         print(f"Filtering models with search term: {search_term}")
-         filtered = [m for m in models_list if search_term.lower() in m.lower()]
-         print(f"Filtered models: {filtered}")
-         return gr.update(choices=filtered)
-
-     def set_custom_model_from_radio(selected):
-         print(f"Featured model selected: {selected}")
-         return selected
-
-     def user(user_multimodal_input, history):
-         print(f"User input (raw from gr.MultimodalTextbox): {user_multimodal_input}")
-
-         text_content = user_multimodal_input.get("text", "").strip()
-         files = user_multimodal_input.get("files", []) # These are temp file paths from Gradio
-
-         if not text_content and not files:
-             print("Empty input, skipping history append.")
-             # Optionally, could raise gr.Error("Please enter a message or upload an image.")
-             # For now, let's allow the bot to respond if history is not empty,
-             # or do nothing if history is also empty.
-             return history

-         # Prepare content for history: a list of dicts for multimodal display
-         history_user_entry_content = []
-         if text_content:
-             history_user_entry_content.append({"type": "text", "text": text_content})

-         for file_path_obj in files: # file_path_obj is a FileData object from Gradio
-             if file_path_obj and hasattr(file_path_obj, 'name') and file_path_obj.name:
-                 # Gradio's Chatbot can display images directly from file paths
-                 # We store it in a format that `respond` can also understand
-                 # The path is temporary, Gradio handles making it accessible for display
-                 history_user_entry_content.append({"type": "image_url", "image_url": {"url": file_path_obj.name}})
-                 print(f"Adding image to history entry: {file_path_obj.name}")

-         if history_user_entry_content:
-             history.append([history_user_entry_content, None]) # User part, Bot part (initially None)
-
          return history
-
      def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
-         if not history or not history[-1][0]: # If no user message or empty user message content
-             print("No user message to process in bot function or user message content is empty.")
-             yield history # Return current history without processing
              return

-         user_content_list = history[-1][0] # This is now a list of content dicts

-         # Extract text and image file paths from the user_content_list for the `respond` function
-         text_for_respond = ""
-         image_files_for_respond = []

-         for item in user_content_list:
-             if item["type"] == "text":
-                 text_for_respond = item["text"]
-             elif item["type"] == "image_url":
-                 image_files_for_respond.append(item["image_url"]["url"])

-         history[-1][1] = "" # Clear placeholder for bot response / Initialize bot response

-         # Call the respond function which is now a generator
-         for response_chunk in respond(
-             text_for_respond,
-             image_files_for_respond,
-             history[:-1], # Pass previous history
-             system_msg, max_tokens, temperature, top_p, freq_penalty, seed,
-             provider, api_key, custom_model, search_term, selected_model
          ):
-             current_bot_response = history[-1][1]
-             if isinstance(response_chunk, dict):
-                 if response_chunk["type"] == "text":
-                     # If current bot response is already an image dict, we can't append text.
-                     # This indicates a new text response after an image, or just text.
-                     if isinstance(current_bot_response, dict) and current_bot_response.get("type") == "image":
-                         # This case should ideally not happen if an image is the final response from a tool.
-                         # If it does, we might need to start a new bot message in history.
-                         # For now, we'll overwrite if the new chunk is text.
-                         history[-1][1] = response_chunk["content"]
-                     elif isinstance(current_bot_response, str):
-                         history[-1][1] = response_chunk["content"] # Accumulate text
-                     else: # current_bot_response is likely "" or None
-                         history[-1][1] = response_chunk["content"]
-
-                 elif response_chunk["type"] == "image":
-                     # Image response from tool. Gradio Chatbot displays this as an image.
-                     # The path should be accessible by Gradio.
-                     # If there was prior text content for this turn, it's now overwritten by the image.
-                     # This means a tool call that produces an image is considered the primary response for that turn.
-                     history[-1][1] = {"path": response_chunk["path"], "mime_type": "image/jpeg"} # Assuming JPEG, could be PNG

              yield history
-

      msg.submit(
          user,
          [msg, chatbot],
-         [chatbot],
          queue=False
      ).then(
          bot,
          [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
           frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
           model_search_box, featured_model_radio],
-         [chatbot]
      ).then(
-         lambda: {"text": "", "files": []}, # Clear MultimodalTextbox
          None,
          [msg]
      )

      model_search_box.change(fn=filter_models, inputs=model_search_box, outputs=featured_model_radio)
-     print("Model search box change event linked.")
-
      featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
-     print("Featured model radio button change event linked.")
-
      byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
-     print("BYOK textbox change event linked.")
-
      provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
-     print("Provider radio button change event linked.")

      print("Gradio interface initialized.")

  if __name__ == "__main__":
      print("Launching the demo application.")
-     demo.launch(show_api=True)
 
  import gradio as gr
+ from huggingface_hub import InferenceClient # Keep for direct use if needed, though agent will use its own model
  import os
  import json
  import base64
  from PIL import Image
  import io

+ # Smolagents imports
+ from smolagents import CodeAgent, Tool
+ from smolagents.models import InferenceClientModel as SmolInferenceClientModel
+ # We'll use PIL.Image directly for opening, AgentImage is for agent's internal typing if needed by a tool
+ from smolagents.gradio_ui import pull_messages_from_step # For formatting agent steps
+ from smolagents.memory import ActionStep, FinalAnswerStep, PlanningStep, MemoryStep # For type checking steps
+ from smolagents.models import ChatMessageStreamDelta # For type checking stream deltas
+

  ACCESS_TOKEN = os.getenv("HF_TOKEN")
  print("Access token loaded.")

+ # Function to encode image to base64 (remains useful if we ever need to pass base64 to a non-smolagent component)
+ def encode_image(image_path_or_pil):
+     if not image_path_or_pil:
+         print("No image path or PIL Image provided")
          return None

      try:
+         # print(f"Encoding image: {type(image_path_or_pil)}") # Debug

+         if isinstance(image_path_or_pil, Image.Image):
+             image = image_path_or_pil
+         else: # Assuming it's a path
+             image = Image.open(image_path_or_pil)

          if image.mode == 'RGBA':
              image = image.convert('RGB')

          buffered = io.BytesIO()
+         image.save(buffered, format="JPEG") # JPEG is generally smaller for transfer
          img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+         # print("Image encoded successfully") # Debug
          return img_str
      except Exception as e:
          print(f"Error encoding image: {e}")
          return None

+ # This function will now set up and run the smolagent
  def respond(
+     message_text, # Text from MultimodalTextbox
+     image_file_paths, # List of file paths from MultimodalTextbox
+     gradio_history: list[tuple[str, str]], # Gradio history (for context if needed, agent is stateless per call here)
+     system_message_for_agent, # System prompt for the main LLM agent
      max_tokens,
      temperature,
      top_p,
      frequency_penalty,
      seed,
+     provider_for_agent_llm,
+     api_key_for_agent_llm,
+     model_id_for_agent_llm,
+     model_search_term, # Unused directly by agent logic
+     selected_model_for_agent_llm # Fallback model ID
  ):
+     print(f"Respond function called. Message: '{message_text}', Images: {image_file_paths}")

+     token_to_use = api_key_for_agent_llm if api_key_for_agent_llm.strip() != "" else ACCESS_TOKEN
+     model_to_use = model_id_for_agent_llm.strip() if model_id_for_agent_llm.strip() != "" else selected_model_for_agent_llm

+     # --- Initialize the LLM for the CodeAgent ---
+     agent_llm_params = {
+         "model_id": model_to_use,
+         "token": token_to_use,
+         # smolagents's InferenceClientModel uses max_tokens for max_new_tokens
+         "max_tokens": max_tokens,
+         "temperature": temperature if temperature > 0.01 else None, # Some models require temp > 0
+         "top_p": top_p if top_p < 1.0 else None, # Often 1.0 means no top_p
+         "seed": seed if seed != -1 else None,
+     }
+     if provider_for_agent_llm and provider_for_agent_llm != "hf-inference":
+         agent_llm_params["provider"] = provider_for_agent_llm

+     # HFIC specific params, add if not default and supported
+     if frequency_penalty != 0.0:
+         agent_llm_params["frequency_penalty"] = frequency_penalty
+
+     agent_llm = SmolInferenceClientModel(**agent_llm_params)
+     print(f"Smolagents LLM for agent initialized: model='{model_to_use}', provider='{provider_for_agent_llm or 'default'}'")
+
+     # --- Define Tools for the Agent ---
+     agent_tools = []
+     try:
+         image_gen_tool = Tool.from_space(
+             space_id="black-forest-labs/FLUX.1-schnell",
+             name="image_generator",
+             description="Generates an image from a textual prompt. Input is a single string argument named 'prompt'. Output is an image file path.",
+             token=token_to_use
+         )
+         agent_tools.append(image_gen_tool)
+         print("Image generation tool loaded: black-forest-labs/FLUX.1-schnell")
+     except Exception as e:
+         print(f"Error loading image generation tool: {e}")
+         yield f"Error: Could not load image generation tool. {e}"
          return

+     # --- Initialize the CodeAgent ---
+     # If system_message_for_agent is empty, CodeAgent will use its default.
+     # The default is usually good as it explains how to use tools.
+     agent = CodeAgent(
+         tools=agent_tools,
+         model=agent_llm,
+         system_prompt=system_message_for_agent if system_message_for_agent and system_message_for_agent.strip() else None,
+         # add_base_tools=True, # Consider adding Python interpreter, etc.
+         stream_outputs=True # Important for Gradio streaming
+     )
+     print("Smolagents CodeAgent initialized.")

+     # --- Prepare task and image inputs for the agent ---
+     agent_task_text = message_text
+
+     pil_images_for_agent = []
+     if image_file_paths:
+         for file_path in image_file_paths:
+             try:
+                 pil_images_for_agent.append(Image.open(file_path))
+             except Exception as e:
+                 print(f"Error opening image file {file_path} for agent: {e}")
+
+     print(f"Agent task: '{agent_task_text}'")
+     if pil_images_for_agent:
+         print(f"Passing {len(pil_images_for_agent)} image(s) to agent.")

+     # --- Run the agent and stream response ---
+     # Agent is reset each turn. For conversational memory, agent instance
+     # would need to be stored in session_state and agent.run(..., reset=False) used.
+
+     current_agent_response_text = ""
+     try:
+         # The agent.run method returns a generator when stream=True
+         for step_item in agent.run(
+             task=agent_task_text,
+             images=pil_images_for_agent,
+             stream=True,
+             reset=True # Explicitly reset for stateless operation per call
+         ):
+             if isinstance(step_item, ChatMessageStreamDelta):
+                 if step_item.content:
+                     current_agent_response_text += step_item.content
+                     yield current_agent_response_text # Yield accumulated text
+
+             elif isinstance(step_item, (ActionStep, PlanningStep, FinalAnswerStep)):
+                 # A structured step. Format it for Gradio.
+                 # pull_messages_from_step yields gr.ChatMessage objects.
+                 for gradio_chat_msg in pull_messages_from_step(step_item, skip_model_outputs=agent.stream_outputs):
+                     # The 'bot' function will handle these gr.ChatMessage objects.
+                     yield gradio_chat_msg # Yield the gr.ChatMessage object directly
+                 current_agent_response_text = "" # Reset text buffer after a structured step
+
+             # else:
+             #     print(f"Unhandled stream item type: {type(step_item)}") # Debug

+         # If there's any remaining text not part of a gr.ChatMessage, yield it.
+         # This usually shouldn't happen if stream_to_gradio logic is followed,
+         # as text deltas should be part of the last gr.ChatMessage or yielded before it.
+         # However, if the agent's final textual answer comes as pure deltas after all steps.
+         if current_agent_response_text and not isinstance(step_item, FinalAnswerStep):
+             # Check if the last yielded item already contains this text
+             if not (isinstance(step_item, gr.ChatMessage) and step_item.content == current_agent_response_text):
+                 yield current_agent_response_text

      except Exception as e:
+         error_message = f"Error during agent execution: {str(e)}"
+         print(error_message)
+         yield error_message # Yield the error message to be displayed in UI

+     print("Agent run completed.")

+
+ # Function to validate provider selection based on BYOK
  def validate_provider(api_key, provider):
      if not api_key.strip() and provider != "hf-inference":
          return gr.update(value="hf-inference")
      return gr.update(value=provider)

+ # GRADIO UI
  with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
      chatbot = gr.Chatbot(
          height=600,
          show_copy_button=True,
+         placeholder="Select a model and begin chatting. Now uses smolagents with tools!",
          layout="panel",
+         bubble_full_width=False # For better display of images/files
      )
      print("Chatbot interface created.")
 

      with gr.Accordion("Settings", open=False):
          system_message_box = gr.Textbox(
+             value="You are a helpful AI assistant. You can generate images if asked. Be precise with your prompts for image generation.",
+             placeholder="You are a helpful AI assistant.",
+             label="System Prompt for Agent"
          )

          with gr.Row():
              with gr.Column():
+                 max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max New Tokens")
+                 temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature")
+                 top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P")
              with gr.Column():
                  frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                  seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")

+         providers_list = [
+             "hf-inference", "cerebras", "together", "sambanova", "novita",
+             "cohere", "fireworks-ai", "hyperbolic", "nebius",
+         ]
+         provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider for Agent's LLM")
+         byok_textbox = gr.Textbox(value="", label="BYOK (Your HF Token or Provider API Key)", info="Enter API key for the selected provider. Uses HF_TOKEN if empty.", placeholder="Enter your API token", type="password")
+         custom_model_box = gr.Textbox(value="", label="Custom Model ID for Agent's LLM", info="(Optional) Provide a custom model ID. Overrides featured model.", placeholder="meta-llama/Llama-3.3-70B-Instruct")
+         model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search for a featured model...", lines=1)

          models_list = [
+             "meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.0-70B-Instruct",
+             "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
              "meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
              "mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3",
+             "Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B", "Qwen/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-3B-Instruct",
+             "Qwen/Qwen2.5-Coder-32B-Instruct", "microsoft/Phi-3.5-mini-instruct", "microsoft/Phi-3-mini-128k-instruct",
          ]
+         featured_model_radio = gr.Radio(label="Select a Featured Model for Agent's LLM", choices=models_list, value="meta-llama/Llama-3.3-70B-Instruct", interactive=True)

          gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")

+     # Chat history state (using gr.State to manage it properly)
+     # The chatbot's value itself will be the history display.
+     # We might need a separate gr.State if agent needs to be conversational across turns.
+     # For now, agent is stateless per turn.

+     # Function for the chat interface
+     def user(user_multimodal_input_dict, history):
+         print(f"User input: {user_multimodal_input_dict}")
+         text_content = user_multimodal_input_dict.get("text", "")
+         files = user_multimodal_input_dict.get("files", [])

+         user_display_parts = []
+         if text_content and text_content.strip():
+             user_display_parts.append(text_content)
+         for file_path_obj in files: # file_path_obj is a tempfile._TemporaryFileWrapper
+             user_display_parts.append((file_path_obj.name, os.path.basename(file_path_obj.name)))

+         if not user_display_parts:
+             return history
+
+         # Append the user's multimodal message to history for display
+         # The actual data (dict) is passed to `bot` function separately.
+         history.append([user_display_parts if len(user_display_parts) > 1 else user_display_parts[0], None])
          return history
+
      def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
+         if not history or not history[-1][0]: # If no user input
+             yield history
              return

+         # The user's input (text and list of file paths) is in history[-1][0]
+         # If `user` function stores the dict:
+         raw_user_input_dict = history[-1][0] if isinstance(history[-1][0], dict) else {"text": str(history[-1][0]), "files": []}

+         # If `user` function stores formatted display parts:
+         # We need to reconstruct or rely on msg input to bot.
+         # For now, assuming msg.submit passes the raw dict.
+         # Let's adjust the Gradio flow to pass `msg` directly to `bot` as well.
+
+         # The `msg` variable in `msg.submit` holds the raw MultimodalTextbox output.
+         # We need to pass this raw dict to `respond`.
+         # The `history` is for display.
+
+         # This part is tricky as `bot` gets `history` which is already formatted for display.
+         # A common pattern is to pass `msg` (raw input) also to `bot`.
+         # Let's assume `history[-1][0]` contains enough info or we adjust `user` fn.
+         # For simplicity, let's assume `user` stores the raw dict if needed,
+         # or `bot` can parse `history[-1][0]` if it's a string/list of tuples.
+
+         # Let's assume `history[-1][0]` is the raw `user_multimodal_input_dict`
+         # This means the `user` function must append it like: `history.append([user_multimodal_input_dict, None])`
+         # And the chatbot will display `str(user_multimodal_input_dict)`.
+         # This is what the current `user` function does.
+
+         user_input_data = history[-1][0] # This should be the dict from MultimodalTextbox
+         text_input_for_agent = user_input_data.get("text", "")
+         # Files from MultimodalTextbox are temp file paths
+         image_file_paths_for_agent = [f.name for f in user_input_data.get("files", []) if hasattr(f, 'name')]

+         history[-1][1] = "" # Initialize assistant's part for streaming

+         # Buffer for current text stream from agent
+         # Handles both pure text deltas and text content from gr.ChatMessage
+         current_text_for_turn = ""
+
+         for item in respond(
+             message_text=text_input_for_agent,
+             image_file_paths=image_file_paths_for_agent,
+             gradio_history=history[:-1], # Pass previous turns for context if agent uses it
+             system_message_for_agent=system_msg,
+             max_tokens=max_tokens, temperature=temperature, top_p=top_p,
+             frequency_penalty=freq_penalty, seed=seed,
+             provider_for_agent_llm=provider, api_key_for_agent_llm=api_key,
+             model_id_for_agent_llm=custom_model,
+             model_search_term=search_term, # unused
+             selected_model_for_agent_llm=selected_model
          ):
+             if isinstance(item, str): # LLM text delta from agent's thought or textual answer
+                 current_text_for_turn = item
+                 history[-1][1] = current_text_for_turn
+             elif isinstance(item, gr.ChatMessage):
+                 # This is a structured step (thought, tool output, image, etc.)
+                 # We need to append this to the history as a new message or part of current message.
+                 # For simplicity, let's append its string content to the current turn's assistant message.
+                 # If it's an image/file, we'll represent it as a markdown link.
+                 if isinstance(item.content, str):
+                     current_text_for_turn = item.content # Replace if it's a full message
+                 elif isinstance(item.content, dict) and "path" in item.content:
+                     # This is typically an image or audio file
+                     file_path = item.content["path"]
+                     # We need to make this file accessible to Gradio if it's temporary from agent
+                     # For now, just put a placeholder.
+                     # If it's an output from a tool, the path might be relative to where smolagents saves it.
+                     # Gradio needs an absolute path or a URL.
+                     # A common pattern is to copy temp files to a static dir served by Gradio or use gr.File.
+                     # For now, let's assume Gradio can handle local paths if they are in a folder it knows.
+                     # We'll display it as a tuple for Gradio Chatbot.
+                     # This means history[-1][1] needs to become a list.
+
+                     # If current_text_for_turn is not empty, make history[-1][1] a list
+                     if current_text_for_turn and not isinstance(history[-1][1], list):
+                         history[-1][1] = [current_text_for_turn]
+                     elif not current_text_for_turn and not isinstance(history[-1][1], list):
+                         history[-1][1] = []
+
+
+                     alt_text = item.metadata.get("title", os.path.basename(file_path)) if item.metadata else os.path.basename(file_path)
+
+                     # Add as new component to the list for current assistant message
+                     if isinstance(history[-1][1], list):
+                         history[-1][1].append((file_path, alt_text))
+                     else: # Should have been made a list above
+                         history[-1][1] = [(file_path, alt_text)]
+
+                     current_text_for_turn = "" # Reset text buffer after a file
+
+                 # If it's not a delta, but a full message, replace the current text
+                 if not isinstance(history[-1][1], list): # if it hasn't become a list due to file
+                     history[-1][1] = current_text_for_turn
+
              yield history
+
+     # Event handlers
+     # `msg.submit`'s first argument is the function to call.
+     # Its `inputs` are the Gradio components whose values are passed to the function.
+     # Its `outputs` are the Gradio components that are updated by the function's return value.
+     # The `user` function now appends the raw dict from MultimodalTextbox to history.
+     # The `bot` function takes this history.
+
+     # When msg is submitted:
+     # 1. Call `user` to update history with user's input. Output is `chatbot`.
+     # 2. Then call `bot` with the updated history. Output is `chatbot`.
+     # 3. Then clear `msg`
      msg.submit(
          user,
          [msg, chatbot],
+         [chatbot], # `user` returns the new history, updating the chatbot display
          queue=False
      ).then(
          bot,
          [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
           frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
           model_search_box, featured_model_radio],
+         [chatbot] # `bot` yields history updates, streaming to chatbot
      ).then(
+         lambda: {"text": "", "files": []}, # Clear MultimodalTextbox
          None,
          [msg]
      )

      model_search_box.change(fn=filter_models, inputs=model_search_box, outputs=featured_model_radio)
      featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
      byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
      provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)

      print("Gradio interface initialized.")

  if __name__ == "__main__":
      print("Launching the demo application.")
+     demo.launch(show_api=False) # show_api=False for cleaner launch, True for API docs
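
A minimal, self-contained sketch of the smolagents pattern this commit builds on: an InferenceClientModel-backed CodeAgent with a Hugging Face Space wrapped as a tool, run with streaming. It mirrors the calls used in the new app.py; the model ID and the example task are illustrative assumptions, not values required by the Space.

import os

from smolagents import CodeAgent, Tool
from smolagents.models import InferenceClientModel

token = os.getenv("HF_TOKEN")

# LLM that drives the agent (the model ID here is just an example).
agent_llm = InferenceClientModel(model_id="meta-llama/Llama-3.3-70B-Instruct", token=token)

# Wrap a Hugging Face Space as a callable tool, as app.py does for FLUX.1-schnell.
image_generator = Tool.from_space(
    space_id="black-forest-labs/FLUX.1-schnell",
    name="image_generator",
    description="Generates an image from a textual prompt.",
    token=token,
)

agent = CodeAgent(tools=[image_generator], model=agent_llm, stream_outputs=True)

# With stream=True, agent.run yields intermediate steps and text deltas
# instead of a single final answer, which is what app.py forwards to Gradio.
for step_item in agent.run(task="Generate an image of a cat playing chess", stream=True):
    print(type(step_item).__name__)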