import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from transformers import AutoModel, AutoProcessor, BlipForConditionalGeneration

# Define constants
TITLE = "<h1><center>Enhanced Image Captioning Studio</center></h1>"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pre-defined caption types with templates
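# Each caption type maps to three templates, indexed by how length is given:
#   [0] no length constraint, [1] a numeric {word_count}, [2] a descriptive {length}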
CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    "Descriptive (Informal)": [
        "Write a descriptive caption for this image in a casual tone.",
        "Write a descriptive caption for this image in a casual tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a casual tone.",
    ],
    "AI Generation Prompt": [
        "Write a detailed prompt for AI image generation based on this image.",
        "Write a detailed prompt for AI image generation based on this image within {word_count} words.",
        "Write a {length} prompt for AI image generation based on this image.",
    ],
    "MidJourney": [
        "Write a MidJourney prompt for this image.",
        "Write a MidJourney prompt for this image within {word_count} words.",
        "Write a {length} MidJourney prompt for this image.",
    ],
    "Stable Diffusion": [
        "Write a Stable Diffusion prompt for this image.",
        "Write a Stable Diffusion prompt for this image within {word_count} words.",
        "Write a {length} Stable Diffusion prompt for this image.",
    ],
    "Art Critic": [
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
    ],
    "Product Listing": [
        "Write a caption for this image as though it were a product listing.",
        "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
        "Write a {length} caption for this image as though it were a product listing.",
    ],
    "Social Media Post": [
        "Write a caption for this image as if it were being used for a social media post.",
        "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
        "Write a {length} caption for this image as if it were being used for a social media post.",
    ],
    "Tag List": [
        "Write a list of tags for this image.",
        "Write a list of tags for this image within {word_count} words.",
        "Write a {length} list of tags for this image.",
    ],
    "Technical Analysis": [
        "Provide a technical analysis of this image including camera details, lighting, composition, and quality.",
        "Provide a technical analysis of this image including camera details, lighting, composition, and quality within {word_count} words.",
        "Provide a {length} technical analysis of this image including camera details, lighting, composition, and quality.",
    ],
}
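
# Example of how a template is filled in:
#   CAPTION_TYPE_MAP["Descriptive"][1].format(word_count=40)
#   -> "Write a descriptive caption for this image in a formal tone within 40 words."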


class ImageAdapter(nn.Module):
    def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
        super().__init__()
        self.deep_extract = deep_extract
        if self.deep_extract:
            input_features = input_features * 5
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
        # Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs: torch.Tensor):
        if self.deep_extract:
            x = torch.cat((
                vision_outputs[-2],
                vision_outputs[3],
                vision_outputs[7],
                vision_outputs[13],
                vision_outputs[20],
            ), dim=-1)
            assert len(x.shape) == 3, f"Expected 3 dimensions, got {len(x.shape)}"
            assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
        else:
            x = vision_outputs[-2]
        x = self.ln1(x)
        if self.pos_emb is not None:
            assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
            x = x + self.pos_emb
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        # <|image_start|>, IMAGE, <|image_end|>
        other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
        assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
        return x

    def get_eot_embedding(self):
        return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
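
# A minimal shape walk-through for ImageAdapter (illustrative sketch; not called
# by the app). The sizes are assumptions: SigLIP so400m-patch14-384 produces 729
# patch tokens with hidden size 1152, and 4096 stands in for an LLM hidden size.
def _image_adapter_shape_demo():
    adapter = ImageAdapter(1152, 4096, ln1=False, pos_emb=False,
                           num_image_tokens=729, deep_extract=False)
    # Stand-in for a vision tower's per-layer hidden states; the adapter reads
    # the second-to-last entry when deep_extract is off.
    hidden_states = tuple(torch.randn(1, 729, 1152) for _ in range(27))
    tokens = adapter(hidden_states)
    # <|image_start|> + 729 projected image tokens + <|image_end|>
    assert tokens.shape == (1, 731, 4096)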


# Model loading functions
def load_siglip_model():
    print("Loading SigLIP model...")
    model_path = "google/siglip-so400m-patch14-384"
    processor = AutoProcessor.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)
    model = model.vision_model
    model.eval()
    model.requires_grad_(False)
    model.to(DEVICE)
    return model, processor


def load_blip_model():
    print("Loading BLIP model...")
    processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    model.to(DEVICE)
    model.eval()
    return model, processor


# Initialize models (with optional lazy loading)
class ModelManager:
    def __init__(self):
        self.blip_model = None
        self.blip_processor = None
        self.siglip_model = None
        self.siglip_processor = None
        self.image_adapter = None
        self.llm_model = None
        self.tokenizer = None
        self.models_loaded = False

    def load_models(self):
        if not self.models_loaded:
            # Load BLIP model for basic captioning
            self.blip_model, self.blip_processor = load_blip_model()
            # For more advanced captioning, set up paths to load custom models.
            # In a real implementation, you would load the full pipeline with
            # proper paths; for now, BLIP handles both simple and advanced operations.
            self.models_loaded = True
        return self
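
    def load_advanced_models(self, checkpoint_dir):
        # Sketch of what loading the full SigLIP + adapter + LLM pipeline might
        # look like. `checkpoint_dir`, the adapter filename, and the feature
        # sizes below are hypothetical; adjust them to your actual checkpoints.
        from transformers import AutoModelForCausalLM, AutoTokenizer
        self.siglip_model, self.siglip_processor = load_siglip_model()
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
        self.llm_model = AutoModelForCausalLM.from_pretrained(checkpoint_dir).to(DEVICE).eval()
        # Assumes the adapter weights were saved via torch.save(adapter.state_dict(), ...)
        # under a hypothetical "image_adapter.pt", with 1152 matching SigLIP so400m's
        # hidden size and the output size taken from the LLM's config.
        self.image_adapter = ImageAdapter(
            1152, self.llm_model.config.hidden_size,
            ln1=False, pos_emb=False, num_image_tokens=729, deep_extract=False,
        )
        self.image_adapter.load_state_dict(
            torch.load(f"{checkpoint_dir}/image_adapter.pt", map_location=DEVICE)
        )
        self.image_adapter.to(DEVICE).eval()
        return self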


model_manager = ModelManager()


def generate_basic_caption(image, prompt="a detailed caption of this image:"):
    """Generate a basic caption using the BLIP model"""
    model_manager.load_models()
    inputs = model_manager.blip_processor(image, prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model_manager.blip_model.generate(**inputs, max_new_tokens=100)
    return model_manager.blip_processor.decode(outputs[0], skip_special_tokens=True)
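
# Note: BLIP's conditional captioning treats `prompt` as a text prefix that the
# decoder continues, not as an instruction it follows, so the instruction-style
# prompts used below are approximations. Typical call (hypothetical file path):
#   caption = generate_basic_caption(Image.open("photo.jpg"))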


def generate_advanced_description(image, caption_type, caption_length, detail_level, emotion_focus, style_focus, extra_options, custom_prompt):
    """Generate an advanced description using multiple targeted prompts"""
    if image is None:
        return "Please upload an image to generate a description."
    try:
        # Load models if not already loaded
        model_manager.load_models()

        # Process caption parameters
        length = None if caption_length == "any" else caption_length
        if isinstance(length, str):
            try:
                length = int(length)
            except ValueError:
                pass

        # Build prompt based on caption type and parameters
        if length is None:
            map_idx = 0
        elif isinstance(length, int):
            map_idx = 1
        else:
            map_idx = 2
        prompt_str = CAPTION_TYPE_MAP.get(caption_type, CAPTION_TYPE_MAP["Descriptive"])[map_idx]

        # Add extra options
        if extra_options:
            prompt_str += " " + " ".join(extra_options)

        # Replace placeholders in the prompt
        prompt_str = prompt_str.format(length=caption_length, word_count=caption_length)

        # Override with custom prompt if provided
        if custom_prompt and custom_prompt.strip():
            prompt_str = custom_prompt.strip()

        print(f"Using prompt: {prompt_str}")

        # Generate captions with different aspects based on detail level
        with torch.no_grad():
            # 1. Basic caption
            basic_caption = generate_basic_caption(image, prompt_str)
            descriptions = []
            descriptions.append(("Basic Caption", basic_caption))

            # 2. Subject description (if detail level is high enough)
            if detail_level >= 2:
                subject_prompt = "Describe the main subjects in this image with details about their appearance:"
                subject_desc = generate_basic_caption(image, subject_prompt)
                descriptions.append(("Main Subject(s)", subject_desc))

            # 3. Setting/background
            if detail_level >= 3:
                setting_prompt = "Describe the setting, location, and background of this image:"
                setting_desc = generate_basic_caption(image, setting_prompt)
                descriptions.append(("Setting/Background", setting_desc))

            # 4. Colors and visual elements
            if style_focus >= 3:
                color_prompt = "Describe the color scheme, visual composition, and artistic style of this image:"
                color_desc = generate_basic_caption(image, color_prompt)
                descriptions.append(("Visual Style/Colors", color_desc))

            # 5. Emotion and mood
            if emotion_focus >= 3:
                emotion_prompt = "Describe the mood, emotional tone, and atmosphere conveyed in this image:"
                emotion_desc = generate_basic_caption(image, emotion_prompt)
                descriptions.append(("Mood/Emotional Tone", emotion_desc))

            # 6. Lighting and time
            if detail_level >= 4 or style_focus >= 4:
                lighting_prompt = "Describe the lighting conditions and time of day in this image:"
                lighting_desc = generate_basic_caption(image, lighting_prompt)
                descriptions.append(("Lighting/Atmosphere", lighting_desc))

            # 7. Details and textures (only for high detail levels)
            if detail_level >= 5:
                detail_prompt = "Describe the fine details, textures, and small elements visible in this image:"
                detail_desc = generate_basic_caption(image, detail_prompt)
                descriptions.append(("Fine Details/Textures", detail_desc))

        # Format results
        formatted_result = ""

        # Add basic subject identification
        formatted_result += f"## Basic Caption:\n{basic_caption}\n\n"

        # Add comprehensive description section if more detailed
        if detail_level >= 2:
            formatted_result += "## Detailed Description:\n\n"
            for title, desc in descriptions[1:]:  # Skip the basic caption
                formatted_result += f"**{title}:** {desc}\n\n"

        # Additional section for AI generation prompts if requested
        if caption_type in ["AI Generation Prompt", "MidJourney", "Stable Diffusion"]:
            # Create a condensed version for AI generation
            ai_descriptions = [basic_caption.strip(".")]
            for _, desc in descriptions[1:]:
                if len(desc) > 10:
                    ai_descriptions.append(desc.split(".")[0])

            # Create specific prompt for AI image generation
            formatted_result += "## Suggested AI Image Generation Prompt:\n\n"
            ai_prompt = ", ".join(ai_descriptions)

            # Add qualifiers based on settings
            qualifiers = []
            if detail_level >= 4:
                qualifiers.append("highly detailed")
                qualifiers.append("intricate")
            if emotion_focus >= 4:
                qualifiers.append("emotional")
                qualifiers.append("evocative")
            if style_focus >= 4:
                qualifiers.append("artistic composition")
                qualifiers.append("professional photography")
            if qualifiers:
                ai_prompt += ", " + ", ".join(qualifiers)
            formatted_result += ai_prompt

        return formatted_result
    except Exception as e:
        return f"Error generating description: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="Enhanced Image Captioning Studio") as demo:
    gr.HTML(TITLE)
    gr.Markdown("Upload an image to generate detailed captions and descriptions tailored to your needs.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Image", type="pil")
            caption_type = gr.Dropdown(
                choices=list(CAPTION_TYPE_MAP.keys()),
                label="Caption Type",
                value="Descriptive",
            )
            caption_length = gr.Dropdown(
                choices=["any", "very short", "short", "medium-length", "long", "very long"]
                + [str(i) for i in range(20, 301, 20)],
                label="Caption Length",
                value="medium-length",
            )
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    detail_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Detail Level")
                    emotion_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Emotion Focus")
                    style_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Style/Artistic Focus")
                extra_options = gr.CheckboxGroup(
                    choices=[
                        "Include information about lighting.",
                        "Include information about camera angle.",
                        "Include information about whether there is a watermark or not.",
                        "Include information about any artifacts or quality issues.",
                        "If it is a photo, include likely camera details such as aperture, shutter speed, ISO, etc.",
                        "Do NOT include anything sexual; keep it PG.",
                        "Do NOT mention the image's resolution.",
                        "Include information about the subjective aesthetic quality of the image.",
                        "Include information on the image's composition style.",
                        "Do NOT mention any text that is in the image.",
                        "Specify the depth of field and focus.",
                        "Mention the likely use of artificial or natural lighting sources.",
                        "ONLY describe the most important elements of the image.",
                    ],
                    label="Additional Options",
                )
                custom_prompt = gr.Textbox(label="Custom Prompt (optional, will override other settings)")
                gr.Markdown("**Note:** Custom prompts may not work with all models and settings.")
            generate_btn = gr.Button("Generate Description", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Generated Description", lines=25)

    # Set up event handlers
    generate_btn.click(
        fn=generate_advanced_description,
        inputs=[
            input_image,
            caption_type,
            caption_length,
            detail_slider,
            emotion_slider,
            style_slider,
            extra_options,
            custom_prompt,
        ],
        outputs=output_text,
    )
| gr.Markdown(""" | |
| ## How to Use | |
| 1. Upload an image | |
| 2. Select the type of caption you want | |
| 3. Choose a length preference | |
| 4. Adjust advanced settings if needed: | |
| - Detail Level: Controls the comprehensiveness of the description | |
| - Emotion Focus: Emphasizes mood and feelings in the output | |
| - Style Focus: Emphasizes artistic elements in the output | |
| 5. Select any additional options you'd like included | |
| 6. Click "Generate Description" | |
| ## About | |
| This application combines multiple image analysis techniques to generate rich, | |
| detailed descriptions of images. It's especially useful for creating prompts | |
| for AI image generators like Stable Diffusion, Midjourney, or DALL-E. | |
| """) | |

# Launch the app
if __name__ == "__main__":
    demo.launch()
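
# Tip: when running locally, demo.launch(share=True) additionally serves a
# temporary public URL; on Hugging Face Spaces, launching the Blocks app as
# above is sufficient.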