Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import base64 | |
| import pandas as pd | |
| from PIL import Image | |
| from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, VisitWebpageTool, OpenAIServerModel, tool | |
| from typing import Optional | |
| import requests | |
| from io import BytesIO | |
| import re | |
| from pathlib import Path | |
| import openai | |
## utility functions
def is_image_extension(filename: str) -> bool:
    """Tell whether *filename* carries a recognised image file extension.

    The comparison is case-insensitive and uses ``os.path.splitext``, so only
    the last extension counts and dot-files such as ``"..png"`` have none.
    """
    _root, extension = os.path.splitext(filename)
    return extension.lower() in {
        '.bmp', '.gif', '.jpeg', '.jpg', '.png', '.svg', '.tiff', '.webp',
    }
| def load_file(path: str) -> list | dict: | |
| """Based on the file extension, load the file into a suitable object.""" | |
| image = None | |
| excel = None | |
| csv = None | |
| text = None | |
| ext = Path(path).suffix.lower() # same as os.path.splitext(filename)[1].lower() | |
| if ext.endswith(".png") or ext.endswith(".jpg") or ext.endswith(".jpeg"): | |
| image = Image.open(path).convert("RGB") # pillow object | |
| elif ext.endswith(".xlsx") or ext.endswith(".xls"): | |
| excel = pd.read_excel(path) # DataFrame | |
| elif ext.endswith(".csv"): | |
| csv = pd.read_csv(path) # DataFrame | |
| elif ext.endswith(".py") or ext.endswith(".txt"): | |
| with open(path, 'r') as f: | |
| text = f.read() # plain text str | |
| if image is not None: | |
| return [image] | |
| else: | |
| return {"excel": excel, "csv": csv, "raw text": text, "audio path": path} | |
| ## tools definition | |
@tool  # CodeAgent only accepts Tool objects, not plain functions
def download_images(image_urls: str) -> list:
    """
    Download web images from the given comma-separated URLs and return them in a list of PIL Images.

    Args:
        image_urls: comma-separated list of URLs to download

    Returns:
        List of PIL.Image.Image objects
    """
    images = []
    for url in (part.strip() for part in image_urls.split(",")):
        if not url:
            continue  # tolerate empty entries, e.g. from a trailing comma
        try:
            resp = requests.get(url, timeout=10)
            resp.raise_for_status()
            images.append(Image.open(BytesIO(resp.content)).convert("RGB"))
        except Exception as e:
            # Best effort: one bad URL should not discard the other images.
            print(f"Failed to download from {url}: {e}")
    return images
@tool  # CodeAgent only accepts Tool objects, not plain functions
def transcribe_audio(audio_path: str) -> str:
    """
    Transcribe audio file using OpenAI Whisper API.

    Args:
        audio_path: path to the audio file to be transcribed.

    Returns:
        str : Transcription of the audio, or an "Error transcribing audio: ..." message if it failed.
    """
    # The project was given OpenAI API credits, hence Whisper via OpenAI.
    # NOTE(review): audio_path arrives from the Gradio upload; it may need adjusting.
    try:
        # Everything that can actually fail (missing key, unreadable file,
        # API error) is inside the try, unlike the original dead handler.
        client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))
        with open(audio_path, "rb") as audio:
            transcript = client.audio.transcriptions.create(
                file=audio,
                model="whisper-1",
                response_format="text",
            )
    except (openai.OpenAIError, OSError) as e:
        message = f"Error transcribing audio: {e}"
        print(message)
        # Return the error as text so the agent sees why the tool failed.
        return message
    print(transcript)
    return transcript
@tool  # CodeAgent only accepts Tool objects, not plain functions
def generate_image(prompt: str, neg_prompt: str) -> Image.Image:
    """
    Generate an image based on a text prompt using Flux Dev.

    Args:
        prompt: The text prompt to generate the image from.
        neg_prompt: The negative prompt to avoid certain elements in the image.

    Returns:
        Image.Image: The generated image as a PIL Image object.
    """
    # The Nebius endpoint is reached through the openai client (OpenAI-compatible
    # API). `OpenAI` was never imported by name, so qualify it via the module.
    client = openai.OpenAI(
        base_url="https://api.studio.nebius.com/v1",
        api_key=os.environ.get("NEBIUS_API_KEY"),
    )
    completion = client.images.generate(
        model="black-forest-labs/flux-dev",
        prompt=prompt,
        response_format="b64_json",
        # Provider-specific generation options not covered by the openai signature.
        extra_body={
            "response_extension": "png",
            "width": 1024,
            "height": 1024,
            "num_inference_steps": 30,
            "seed": -1,
            "negative_prompt": neg_prompt,
        },
    )
    image_data = base64.b64decode(completion.data[0].b64_json)
    return Image.open(BytesIO(image_data))
| ## agent definition | |
class Agent:
    """Callable wrapper around a smolagents CodeAgent using google/gemma-3-27b-it via the Nebius provider."""

    def __init__(self):
        model = HfApiModel(
            "google/gemma-3-27b-it",
            provider="nebius",
            api_key=os.getenv("NEBIUS_API_KEY"),
        )
        self.agent = CodeAgent(
            model=model,
            tools=[
                DuckDuckGoSearchTool(max_results=5),
                VisitWebpageTool(max_output_length=20000),
                generate_image,
                download_images,
                transcribe_audio,
            ],
            # Modules the agent's generated code is allowed to import.
            additional_authorized_imports=["pandas", "PIL", "io"],
            planning_interval=1,
            max_steps=5,
        )

    def __call__(
        self,
        message: str,
        images: Optional[list[Image.Image]] = None,
        files: Optional[dict] = None,
    ) -> str:
        """Run the agent on one user request.

        Args:
            message: the user's text.
            images: PIL images to attach to the task, if any.
            files: the dict produced by ``load_file`` for a non-image upload,
                exposed to the agent's code as the variable ``files``.

        Returns:
            The agent's final answer. NOTE(review): this may not be a str,
            e.g. if the agent returns an image from ``generate_image``.
        """
        # Only inject `files` when there is one; otherwise every text-only
        # task would be padded with a meaningless {"files": None} argument.
        additional_args = {"files": files} if files is not None else None
        return self.agent.run(message, images=images, additional_args=additional_args)
| ## gradio functions | |
def respond(message, history):
    """Gradio ChatInterface callback: forward one user turn to the agent.

    Args:
        message: multimodal chat message dict with a ``"text"`` entry and an
            optional ``"files"`` list of uploaded file paths.
        history: previous chat turns (not used here).

    Returns:
        The agent's answer for this turn.
    """
    text = message.get("text", "")
    files = message.get("files")
    if not files:
        print("No files received.")
        return agent(text)

    print(f"files received: {files}")
    # Only the first upload is used (assumes one file per message).
    loaded = load_file(files[0])
    # Dispatch on what load_file actually produced, not on the extension:
    # load_file returns a list only for images Pillow can open, so e.g. an
    # .svg (which is_image_extension accepts) must go through ``files=``.
    if isinstance(loaded, list):
        return agent(text, images=loaded)
    return agent(text, files=loaded)
def initialize_agent():
    """Construct the application's Agent, announce it on stdout, and hand it back."""
    new_agent = Agent()
    print("Agent initialized.")
    return new_agent
| ## gradio interface | |
# Build the agent once at import time; `respond` reads this module-level name.
# (The former `global agent` statement was a no-op at module scope.)
agent = initialize_agent()

with gr.Blocks() as demo:
    gr.ChatInterface(
        fn=respond,
        type="messages",
        multimodal=True,  # messages arrive as {"text": ..., "files": [...]}
        title="MultiAgent System for Screenplay Creation and Editing",
        show_progress="full",
    )

if __name__ == "__main__":
    demo.launch()