# Source: Hugging Face Space file app.py, commit c0f4098 ("Update app.py", verified) by bielas194.
import os
import torch
import gradio as gr
from PIL import Image
import tempfile
import shutil
from functools import partial
import traceback # <--- ADDED THIS LINE: Import the traceback module
from diffusers import StableDiffusionPipeline
from huggingface_hub import InferenceClient
# LangChain imports
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.llms import HuggingFaceHub
from langchain.agents import AgentExecutor, create_react_agent
from langchain.schema import HumanMessage, AIMessage
# --- 1. Load Stable Diffusion Pipeline (happens once at startup) ---
# The same token is reused further down for the hosted LLM.
HF_TOKEN = os.getenv("HF_TOKEN")

# Model used for image generation (a small Stable Diffusion variant).
IMAGE_GEN_MODEL_ID = "segmind/tiny-sd"

print(f"Loading Stable Diffusion Pipeline directly on GPU: {IMAGE_GEN_MODEL_ID}...")
try:
    pipe = StableDiffusionPipeline.from_pretrained(
        IMAGE_GEN_MODEL_ID,
        torch_dtype=torch.float16,  # half precision to reduce VRAM usage
        use_safetensors=False,      # this checkpoint is loaded without safetensors weights
        token=HF_TOKEN,             # authenticated download when a token is configured
    )
    pipe.to("cuda")  # move the weights onto the GPU
except Exception:
    # Any failure (download, missing GPU, out of memory, ...) is logged and
    # recorded as `pipe = None` so later code can detect it.
    print("❌ Error loading Stable Diffusion Pipeline:")
    traceback.print_exc()
    pipe = None
else:
    print(f"Stable Diffusion Pipeline ({IMAGE_GEN_MODEL_ID}) loaded successfully on GPU.")
# --- 2. Define Custom Image Generation Tool for LangChain ---
# Use @tool decorator to make a function a LangChain tool
@tool
def image_generator(prompt: str) -> str:
    """
    Generates an image from a detailed text prompt using a Stable Diffusion pipeline.
    The input MUST be a detailed text description for the image to generate.
    """
    # NOTE(review): the docstring above is presumably what `@tool` exposes to
    # the agent as the tool description, so its wording is left untouched.

    # Guard clause: nothing can be generated without a loaded pipeline.
    if pipe is None:
        return "Error: Image generation pipeline failed to load. Please check Space logs during startup."

    print(f"\n--- Agent is calling image_generator with prompt: '{prompt}' ---")
    try:
        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            result = pipe(prompt, guidance_scale=7.5, height=512, width=512)
        generated = result.images[0]

        # Tools return strings, so the image is written to a fresh temporary
        # directory and only its path is handed back to the caller.
        output_path = os.path.join(tempfile.mkdtemp(), "generated_image.png")
        generated.save(output_path)
        print(f"Image saved to temporary path: {output_path}")
    except Exception as exc:
        print("Error in image_generator tool execution:")
        traceback.print_exc()
        return f"Error generating image: {str(exc)}"

    # The special prefix lets the UI layer recognise the result as an image path.
    return f"__IMAGE_PATH__:{output_path}"
# --- 3. Define other Tools for LangChain ---
# Web search tool. The agent prompt tells the model to call a tool named
# `search`, but DuckDuckGoSearchRun registers itself as "duckduckgo_search" by
# default, so an "Action: search" step would be rejected as an unknown tool.
# Naming it explicitly keeps the prompt and the tool registry in agreement.
search = DuckDuckGoSearchRun(name="search")
# --- 4. Define the LangChain Agent ---
# Fail fast at startup: the agent is not built when the image pipeline is missing.
if pipe is None:
    raise RuntimeError("Cannot start agent as image generation pipeline failed to load. Check logs.")

# Hosted LLM that drives the agent's reasoning.
llm = HuggingFaceHub(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    huggingfacehub_api_token=HF_TOKEN,
    model_kwargs={"temperature": 0.5, "max_new_tokens": 512},
)

# Tools the agent is allowed to call.
tools = [image_generator, search]

# ReAct prompt.
# * `{tools}` and `{tool_names}` are filled in by `create_react_agent`.
# * `{agent_scratchpad}` is supplied by the ReAct agent as plain TEXT (the
#   running Thought/Action/Observation log), so it must be a string variable;
#   declaring it as a MessagesPlaceholder fails at format time because the
#   placeholder only accepts a list of messages.
# * The Thought/Action/Action Input/Final Answer layout is what the ReAct
#   output parser looks for, so it has to be spelled out for the model.
# * Developer notes stay in Python comments: anything inside the string below
#   is sent to the LLM verbatim.
prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are a powerful AI assistant that can generate images and search the web.

You have access to the following tools:

{tools}

When you need to generate an image, use the `image_generator` tool. Its input must be a very detailed, descriptive text string.
When you need factual information or context, use the web search tool.

Always follow these steps:
1. Think step-by-step: Analyze the user's request and determine if you need to search or generate an image.
2. If you need to search, use the web search tool.
3. If you need to generate an image, ensure you have enough detail. If not, ask for more or use search.
4. When you have enough information, use the `image_generator` tool.
5. Provide your final answer.

To use a tool, you MUST use exactly this format:

Thought: your reasoning about what to do next
Action: the tool to use, must be one of [{tool_names}]
Action Input: the input to the tool
Observation: the result of the tool

The Thought/Action/Action Input/Observation cycle can repeat as many times as needed.

When you are ready to respond to the user, or if no tool is needed, you MUST use exactly this format:

Thought: I now know the final answer
Final Answer: your response to the user

If the `image_generator` tool returned a result starting with __IMAGE_PATH__:, your Final Answer must be exactly that result, copied verbatim, with nothing before or after it.
""",
        ),
        MessagesPlaceholder(variable_name="chat_history"),
        # The scratchpad text follows the user's request so the model simply
        # continues the reasoning log from where it left off.
        ("human", "{input}\n\n{agent_scratchpad}"),
    ]
)

# Build the ReAct agent and the executor that runs its tool-calling loop.
agent = create_react_agent(llm, tools, prompt_template)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
# --- 5. Gradio UI Integration ---
# Function to run the agent and display output
# Prefix that `image_generator` puts in front of the path of the saved image.
_IMAGE_PATH_MARKER = "__IMAGE_PATH__:"


def _history_to_messages(history):
    """Convert Gradio chat history into a list of LangChain messages.

    Accepts both history layouts Gradio can hand to a chat function:
    ``(user, bot)`` pairs and ``{"role": ..., "content": ...}`` dicts.
    Non-text content (for example an image from an earlier turn, or ``None``)
    is skipped because the message classes are given text content only.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            entries = [(turn.get("role"), turn.get("content"))]
        else:
            user_part, bot_part = turn
            entries = [("user", user_part), ("assistant", bot_part)]
        for role, content in entries:
            if not isinstance(content, str):
                continue
            message_cls = HumanMessage if role == "user" else AIMessage
            messages.append(message_cls(content=content))
    return messages


def _extract_image_path(text):
    """Return the existing image file referenced after the marker in *text*, else None."""
    _, marker, tail = text.partition(_IMAGE_PATH_MARKER)
    if not marker:
        return None
    # The tool always writes a ".png" file, so the path ends at that suffix;
    # anything the LLM appended after it is ignored.
    stem, suffix, _ = tail.partition(".png")
    if not suffix:
        return None
    candidate = (stem + suffix).strip()
    # The text comes from an LLM, so only trust paths that really exist.
    return candidate if os.path.isfile(candidate) else None


def run_agent_in_gradio(message, history):
    """Run the agent for one chat turn and return what Gradio should display.

    Args:
        message: The user's new chat message.
        history: Previous turns as provided by ``gr.ChatInterface``.

    Returns:
        A ``gr.Image`` when the agent's answer references a generated image
        file that exists on disk, otherwise the agent's text answer. If
        anything fails, a human-readable error string is returned instead of
        raising, so the chat UI keeps working.
    """
    try:
        # `agent_scratchpad` is intentionally not passed: the agent computes it
        # itself from its intermediate steps, so a caller-supplied value is
        # simply overwritten.
        response = agent_executor.invoke(
            {"input": message, "chat_history": _history_to_messages(history)}
        )
        agent_output = response["output"]

        # The marker may be surrounded by extra text written by the LLM, so
        # look for it anywhere in the answer instead of only at the start.
        image_path = _extract_image_path(agent_output)
        if image_path is not None:
            return gr.Image(value=image_path, label="Generated Image")
        return agent_output
    except Exception as e:
        print(f"Error running agent: {e}")
        traceback.print_exc()
        return f"❌ Agent encountered an error: {str(e)}"
# Chat UI: every submitted message is routed through `run_agent_in_gradio`.
demo = gr.ChatInterface(
    run_agent_in_gradio,
    chatbot=gr.Chatbot(label="AI Agent"),
    textbox=gr.Textbox(
        placeholder="Ask me to generate an image or search the web...",
        container=False,
        scale=7,
    ),
    title="Intelligent Image Generator & Web Search Agent (LangChain)",
    description="This agent can generate images based on prompts or search the web for information first.",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()