import base64
import os
import re
from io import BytesIO
from typing import Any, Dict, Generator, List, Optional, Union

import asyncio
import requests
from dotenv import load_dotenv
from google import genai
from google.genai import types
from PIL import Image

from .category_instructions import get_instruction_for_category
from .category_config import CATEGORY_CONFIGS

load_dotenv()

client = genai.Client(
    api_key=os.getenv("API_KEY")
)

# Fallback MIME type used when the caller (or the model response) does not
# supply one; this was the only type the original helper ever emitted.
_DEFAULT_IMAGE_MIME = "image/png"

# Matches numbered section headers "1." through "5." at the start of a
# stripped line; group 1 is the digit, group 2 everything after the dot.
# "10." does not match because the character after "1" must be ".".
_SECTION_HEADER_RE = re.compile(r"([1-5])\.(.*)")


def bytes_to_base64(
    data: bytes,
    with_prefix: bool = True,
    mime_type: Optional[str] = _DEFAULT_IMAGE_MIME,
) -> str:
    """Base64-encode raw bytes, optionally as a ``data:`` URI.

    Args:
        data: Raw bytes to encode.
        with_prefix: When True, return ``data:<mime_type>;base64,<payload>``;
            otherwise return only the base64 payload.
        mime_type: MIME type written into the URI prefix. ``None`` or an
            empty string falls back to ``image/png`` (the previous fixed value).

    Returns:
        The encoded string.
    """
    encoded = base64.b64encode(data).decode("utf-8")
    if not with_prefix:
        return encoded
    return f"data:{mime_type or _DEFAULT_IMAGE_MIME};base64,{encoded}"


def decode_base64_image(base64_str: str) -> Image.Image:
    """Decode a base64 string (optionally a ``data:image/...`` URI) into a PIL image.

    Raises:
        binascii.Error: if the payload is not valid base64.
        PIL.UnidentifiedImageError: if the bytes are not a readable image.
    """
    # Strip the "data:image/...;base64," prefix when present.
    if base64_str.startswith("data:image"):
        base64_str = base64_str.partition(",")[2]
    image_data = base64.b64decode(base64_str)
    return Image.open(BytesIO(image_data))


def _generation_config() -> types.GenerateContentConfig:
    """Build the request config asking the model for both text and image output."""
    return types.GenerateContentConfig(
        response_modalities=['TEXT', 'IMAGE']
    )


def _build_enhanced_prompt(instruction: str, config: Dict[str, Any], prompt: Any) -> str:
    """Combine the category instruction, optional category guidance and the user prompt.

    Sections whose config value is missing or empty are omitted so the model
    never sees an empty "- " bullet. When every section is present, the output
    matches the previous format exactly.
    """
    guidance: List[str] = []
    style_guide = config.get("style_guide", "") if config else ""
    if style_guide:
        guidance.append(f"Style Guide: {style_guide}")
    conventions = config.get("conventions", []) if config else []
    if conventions:
        guidance.append(
            "Drawing Conventions to Follow:\n" + "\n".join(f"- {c}" for c in conventions)
        )
    common_elements = config.get("common_elements", []) if config else []
    if common_elements:
        guidance.append(
            "Consider Including These Elements:\n" + "\n".join(f"- {e}" for e in common_elements)
        )

    if not guidance:
        return f"{instruction}\n\nUser Request: {prompt}"
    return f"{instruction}\n\n" + "\n".join(guidance) + f"\n\nUser Request: {prompt}"


def _parse_text_sections(text: str) -> Dict[str, str]:
    """Split model text into sections keyed by numbered headers ("1. Title: ...").

    Lines before the first header go under ``"overview"``. A header's key is
    the text between the "N." and the first ":", lower-cased, with spaces
    replaced by underscores. A header with no title text becomes
    ``"section_<N>"``. A header that repeats appends to the existing section
    instead of discarding what was collected for it. Blank lines are dropped.
    """
    sections: Dict[str, List[str]] = {}
    current = "overview"
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        header = _SECTION_HEADER_RE.match(line)
        if header:
            number, rest = header.group(1), header.group(2)
            name = rest.split(":", 1)[0].strip().lower().replace(" ", "_")
            current = name or f"section_{number}"
            # Register the section even if no body lines follow it.
            sections.setdefault(current, [])
        else:
            sections.setdefault(current, []).append(line)
    return {key: "\n".join(body).strip() for key, body in sections.items()}


def _response_chunks(response: Any) -> Generator[Dict[str, Any], None, None]:
    """Convert the first candidate of a generate_content response into stream chunks.

    Yields ``{'type': 'text', 'data': {section: text}}`` for text parts and
    ``{'type': 'image', 'data': <data URI>}`` for inline image parts. It yields
    nothing when the response has no candidates, content or parts. For
    example, a blocked response may carry ``content=None``; previously this
    raised an opaque TypeError/IndexError.
    """
    candidates = getattr(response, "candidates", None) or []
    if not candidates:
        return
    content = getattr(candidates[0], "content", None)
    parts = getattr(content, "parts", None) or []
    for part in parts:
        text = getattr(part, "text", None)
        if text is not None:
            yield {'type': 'text', 'data': _parse_text_sections(text)}
            continue
        inline_data = getattr(part, "inline_data", None)
        if inline_data is not None:
            yield {
                'type': 'image',
                'data': bytes_to_base64(
                    inline_data.data,
                    mime_type=getattr(inline_data, "mime_type", None),
                ),
            }


async def async_generate_text_and_image(prompt, category: Optional[str] = None):
    """Generate text and images for a text prompt, streaming result chunks.

    The prompt is prefixed with the category instruction and, when
    ``CATEGORY_CONFIGS`` has an entry for the lower-cased category, its style
    guide, conventions and common elements.

    Yields:
        Dicts of the form ``{'type': 'text', 'data': {section: text}}`` or
        ``{'type': 'image', 'data': 'data:<mime>;base64,...'}``.
    """
    instruction = get_instruction_for_category(category)
    config = CATEGORY_CONFIGS.get(category.lower() if category else "", {})
    enhanced_prompt = _build_enhanced_prompt(instruction, config, prompt)

    response = await client.aio.models.generate_content(
        model=os.getenv("MODEL"),
        contents=enhanced_prompt,
        config=_generation_config(),
    )
    for chunk in _response_chunks(response):
        yield chunk


async def async_generate_with_image_input(text: Optional[str], image_path: str, category: Optional[str] = None):
    """Generate text and images from an input image plus optional text.

    Args:
        text: Optional user request; when falsy, only the category
            instruction is sent alongside the image.
        image_path: Despite the name, a base64 data URI (``data:image/...``),
            not a file path.
        category: Category used to select the instruction.

    Yields:
        Same chunk dicts as ``async_generate_text_and_image``.

    Raises:
        ValueError: if ``image_path`` is not a ``data:image/`` URI (raised on
            the first iteration, since this is an async generator).
    """
    if not isinstance(image_path, str) or not image_path.startswith("data:image/"):
        raise ValueError("Invalid image input: expected a base64 Data URI starting with 'data:image/'")

    image = decode_base64_image(image_path)

    # NOTE(review): unlike the text-only path, this does not add the
    # CATEGORY_CONFIGS style guidance to the prompt. Confirm that is intended.
    instruction = get_instruction_for_category(category)
    prompt_text = f"{instruction}\n\nUser Request: {text}" if text else instruction

    response = await client.aio.models.generate_content(
        model=os.getenv("MODEL"),
        contents=[prompt_text, image],
        config=_generation_config(),
    )
    for chunk in _response_chunks(response):
        yield chunk