import os
import tempfile
import traceback

import torch
import gradio as gr

from diffusers import StableDiffusionPipeline

# LangChain imports
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import tool
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.llms import HuggingFaceHub
from langchain.agents import AgentExecutor, create_react_agent
# --- 1. Load Stable Diffusion Pipeline (happens once at startup) ---
HF_TOKEN = os.environ.get("HF_TOKEN")  # The same token is reused by the HuggingFaceHub LLM below

# Model ID for image generation (a small model that fits comfortably on a T4)
IMAGE_GEN_MODEL_ID = "segmind/tiny-sd"

print(f"Loading Stable Diffusion Pipeline: {IMAGE_GEN_MODEL_ID}...")
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    pipe = StableDiffusionPipeline.from_pretrained(
        IMAGE_GEN_MODEL_ID,
        # float16 halves VRAM usage but is only reliable on GPU; fall back to float32 on CPU
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        use_safetensors=False,  # tiny-sd is distributed without safetensors weights
        token=HF_TOKEN,  # authenticated downloads avoid anonymous rate limits
    )
    pipe.to(device)
    print(f"Stable Diffusion Pipeline ({IMAGE_GEN_MODEL_ID}) loaded successfully on {device}.")
except Exception:
    print("Error loading Stable Diffusion Pipeline:")
    traceback.print_exc()
    pipe = None  # Signal failure so the tool and the startup check below can react
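
# Optional tweak (not in the original app, a hedged suggestion): on small GPUs
# such as a T4, diffusers' attention slicing lowers peak VRAM at a modest speed cost.
if pipe is not None and device == "cuda":
    pipe.enable_attention_slicing()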
# --- 2. Define Custom Image Generation Tool for LangChain ---
# The @tool decorator turns a plain function into a LangChain tool;
# the docstring becomes the tool description the agent sees.
@tool
def image_generator(prompt: str) -> str:
    """
    Generates an image from a detailed text prompt using a Stable Diffusion pipeline.
    The input MUST be a detailed text description of the image to generate.
    """
    if pipe is None:
        return "Error: Image generation pipeline failed to load. Please check the Space logs from startup."
print(f"\n--- Agent is calling image_generator with prompt: '{prompt}' ---") | |
try: | |
with torch.no_grad(): | |
pil_image = pipe(prompt, guidance_scale=7.5, height=512, width=512).images[0] | |
# Save the PIL image to a temporary file, Gradio will handle displaying this path | |
# NOTE: LangChain tools typically return strings. For image display, we'll return | |
# the path, and handle its display in the Gradio UI directly based on content. | |
temp_dir = tempfile.mkdtemp() | |
image_path = os.path.join(temp_dir, "generated_image.png") | |
pil_image.save(image_path) | |
print(f"Image saved to temporary path: {image_path}") | |
# Return a special string prefix so Gradio knows it's an image path | |
return f"__IMAGE_PATH__:{image_path}" | |
except Exception as e: | |
print("Error in image_generator tool execution:") | |
traceback.print_exc() | |
return f"Error generating image: {str(e)}" | |
# --- 3. Define other Tools for LangChain ---
search = DuckDuckGoSearchRun()
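
# Hedged example (run manually): the search tool takes a plain query string.
# print(search.run("latest stable Gradio release"))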
# --- 4. Define the LangChain Agent ---
# Fail fast if the image pipeline never loaded
if pipe is None:
    raise RuntimeError("Cannot start agent: image generation pipeline failed to load. Check the logs.")

# Instantiate the LLM that drives the agent
llm = HuggingFaceHub(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    huggingfacehub_api_token=HF_TOKEN,
    model_kwargs={"temperature": 0.5, "max_new_tokens": 512},
)

# The tools the agent may call
tools = [image_generator, search]
# Define the agent prompt.
# create_react_agent expects a plain-text prompt containing {tools}, {tool_names},
# {input} and {agent_scratchpad}, plus the ReAct Thought/Action format instructions.
# A chat-style prompt with MessagesPlaceholder("agent_scratchpad") fails at runtime,
# because the executor fills the scratchpad with a string, not a list of messages.
prompt_template = PromptTemplate.from_template(
    """You are a powerful AI assistant that can generate images and search the web.

You have access to the following tools:

{tools}

When you need to generate an image, use the `image_generator` tool; its input must be a very detailed, descriptive text string. When you need factual information or context first, use the web search tool.

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original question. If you generated an image, include the tool's `__IMAGE_PATH__:` output verbatim.

Previous conversation:
{chat_history}

Question: {input}
Thought:{agent_scratchpad}"""
)
# Create the agent and its executor
agent = create_react_agent(llm, tools, prompt_template)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
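
# Hedged smoke test (assumption: run manually; each call costs an LLM request
# plus any tool usage). chat_history is a plain string in this prompt.
# result = agent_executor.invoke({"input": "Draw a red bicycle.", "chat_history": ""})
# print(result["output"])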
# --- 5. Gradio UI Integration ---
# Run the agent on one user message and return something the chatbot can display
def run_agent_in_gradio(message, history):
    # Flatten Gradio's (user, assistant) tuples into the plain-text
    # {chat_history} slot of the ReAct prompt
    chat_lines = []
    for human_msg, ai_msg in history:
        chat_lines.append(f"Human: {human_msg}")
        chat_lines.append(f"AI: {ai_msg}")
    chat_history = "\n".join(chat_lines)
    try:
        # The AgentExecutor manages agent_scratchpad itself; do not pass it in
        response = agent_executor.invoke({"input": message, "chat_history": chat_history})
        agent_output = response["output"]
        # Detect the image-path marker anywhere in the output: the LLM often
        # wraps the tool result in extra prose, so startswith() is too fragile
        if "__IMAGE_PATH__:" in agent_output:
            image_path = agent_output.split("__IMAGE_PATH__:", 1)[1].strip()
            # Return a Gradio Image component; ChatInterface renders it in the chat
            return gr.Image(value=image_path, label="Generated Image")
        # Otherwise return plain text
        return agent_output
    except Exception as e:
        print(f"Error running agent: {e}")
        traceback.print_exc()
        return f"Agent encountered an error: {str(e)}"
# Gradio ChatInterface setup
demo = gr.ChatInterface(
    fn=run_agent_in_gradio,
    chatbot=gr.Chatbot(label="AI Agent"),
    textbox=gr.Textbox(placeholder="Ask me to generate an image or search the web...", container=False, scale=7),
    title="Intelligent Image Generator & Web Search Agent (LangChain)",
    description="This agent can generate images from prompts, optionally searching the web for context first.",
)

if __name__ == "__main__":
    demo.launch()