import os
import re

import google.generativeai as genai
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
|
|
| |
# Gemini client configuration.
# The API key is read from the environment so the secret never lives in
# source control; the app cannot produce stories without it, so fail fast
# with an actionable message instead of erroring on the first request.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
if not GOOGLE_API_KEY:
    raise RuntimeError(
        "GOOGLE_API_KEY environment variable is not set; "
        "export it before launching the app."
    )
genai.configure(api_key=GOOGLE_API_KEY)
model = genai.GenerativeModel("models/gemini-2.0-flash")
|
|
| |
# Stable Diffusion setup: use the GPU when torch reports one, otherwise the
# CPU, and load the pipeline once at import time so it is shared by every
# request.
model_id = "prompthero/openjourney-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"✅ Using device: {device}")

try:
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
    pipe.to(device)
except Exception as e:
    print(f"❌ Error loading Stable Diffusion model: {e}")
    # Nothing can be rendered without the pipeline. Exit with a non-zero
    # status so supervisors/CI see the failure; the interactive `exit()`
    # helper would report success (status 0) and is absent under `python -S`.
    raise SystemExit(1) from e
else:
    print("✅ Stable Diffusion Model Loaded Successfully!")
|
|
| |
# Prompt keywords for each selectable art style. The keys are the labels
# offered in the UI; the values are spliced into the image prompt.
art_styles = {
    "Cinematic": "cinematic, highly detailed, vivid colors, realistic style",
    "Anime": "anime style, vibrant colors, expressive characters",
    "Watercolor": "watercolor painting, soft colors, artistic style",
    "Cyberpunk": "cyberpunk, neon lights, futuristic, dystopian",
}
|
|
|
|
# Leading list markers ("1.", "2)", "-", "*", "•") that a bulleted or numbered
# reply puts in front of each point; they add nothing to an image prompt.
_LIST_MARKER = re.compile(r"^\s*(?:[-*•]+|\d+[.)])\s+")


def get_summary_or_concept(prompt, num_panels):
    """Ask Gemini to condense *prompt* into one short sentence per panel.

    Args:
        prompt: Topic or story idea to summarise.
        num_panels: Number of panels wanted. Coerced with ``int`` so a
            float coming from a UI slider still works for slicing.

    Returns:
        A list of panel sentences with list markers removed. When the reply
        has several non-empty lines, at most ``num_panels`` of them are
        returned (fewer if the reply was shorter). When the reply is a
        single line, that line is repeated ``num_panels`` times. An empty
        list is returned if the request fails, the reply is blank, or
        ``num_panels`` is not a positive number.
    """
    try:
        num_panels = int(num_panels)
    except (TypeError, ValueError) as e:
        print(f"❌ Error generating content: {e}")
        return []
    if num_panels < 1:
        return []

    content_prompt = (
        f"Summarize the key concepts or moments of '{prompt}' into {num_panels} brief points, "
        "with 1 concise sentence for each panel, equally distributing the information."
    )

    # Only the remote call and the read of `.text` sit inside the try: the
    # attribute access can itself raise (e.g. when the reply carries no usable
    # text), whereas parsing errors below should not be reported as API errors.
    try:
        response = model.generate_content(content_prompt)
        text_response = (response.text or "").strip() if response else ""
    except Exception as e:
        print(f"❌ Error generating content: {e}")
        return []

    # A blank reply must not turn into `num_panels` empty image prompts.
    if not text_response:
        return []

    cleaned_lines = (
        _LIST_MARKER.sub("", line).strip() for line in text_response.split("\n")
    )
    panels = [line for line in cleaned_lines if line]
    if not panels:
        return []

    if len(panels) > 1:
        return panels[:num_panels]
    # Single-line reply: reuse the one sentence for every panel.
    return panels * num_panels
|
|
|
|
def generate_comic_images(story_prompts, style):
    """Render one Stable Diffusion image per panel description.

    Args:
        story_prompts: Sequence of panel sentences, one per image.
        style: Key into ``art_styles`` selecting the prompt keywords.

    Returns:
        A list with one generated image per entry of ``story_prompts``, in
        order (empty when there are no prompts), or ``None`` if ``style`` is
        not a known art style or any single panel fails to render.
    """
    if not story_prompts:
        return []

    # Resolve the style once, up front. An unknown style is reported through
    # the same "return None" channel as a render failure instead of escaping
    # as an uncaught KeyError.
    style_keywords = art_styles.get(style)
    if style_keywords is None:
        print(f"❌ Error generating image: unknown art style {style!r}")
        return None

    total = len(story_prompts)
    images = []
    for number, panel in enumerate(story_prompts, start=1):
        description = panel.strip()
        print(f"🎨 Generating image for Panel {number}/{total}: {description[:50]}...")

        image_prompt = (
            f"Highly detailed, {style_keywords} scene depicting: {description}, "
            "no text, no words"
        )

        try:
            image = pipe(
                image_prompt,
                num_inference_steps=20,
                guidance_scale=7.0,
                negative_prompt="text, words, letters, nsfw, explicit, unsafe",
            ).images[0]
        except Exception as e:
            # One failed panel aborts the whole strip; partial results are
            # deliberately not returned.
            print(f"❌ Error generating image: {e}")
            return None
        images.append(image)
    return images
|
|
|
|
def generate_comic(topic, num_panels, style):
    """Produce the panel texts for *topic* and an image for each of them.

    Returns a ``(story, images)`` pair. ``story`` is whatever
    ``get_summary_or_concept`` returned and ``images`` is the result of
    ``generate_comic_images`` for it. If no story could be produced, the
    pair is an error message string and ``None``.
    """
    panel_texts = get_summary_or_concept(topic, num_panels)
    if panel_texts:
        return panel_texts, generate_comic_images(panel_texts, style)
    return "⚠️ Error generating story.", None
|
|
|
|
def comic_interface(topic, num_panels, style):
    """Gradio callback: generate the comic and shape the result for the UI.

    Args:
        topic: Subject of the comic strip.
        num_panels: Number of panels requested.
        style: Name of the selected art style.

    Returns:
        ``(story_text, images)`` on success, where ``story_text`` holds the
        panel sentences one per line, or ``(error_message, None)`` when
        either the story or the images could not be generated.
    """
    story, images = generate_comic(topic, num_panels, style)
    if not images:
        # generate_comic puts its own message in the first slot when the
        # story step failed; pass that through instead of masking it with
        # the generic image-stage message.
        if isinstance(story, str):
            return story, None
        return "⚠️ Error generating comic.", None

    # The story is a list of sentences; join it so the text output shows one
    # panel per line (a raw list would presumably be rendered as its repr).
    story_text = story if isinstance(story, str) else "\n".join(story)
    return story_text, images
|
|
|
|
| |
# Gradio front end. The three input widgets map, in order, onto the
# (topic, num_panels, style) parameters of comic_interface; its two return
# values fill the story text box and the image gallery.
_comic_inputs = [
    gr.Textbox(label="Enter a topic for the comic strip", value="Government of India"),
    gr.Slider(minimum=3, maximum=10, step=1, label="Number of Comic Panels", value=6),
    gr.Radio(list(art_styles), label="Choose an Art Style", value="Cinematic"),
]
_comic_outputs = [
    gr.Textbox(label="Generated Story Prompts"),
    gr.Gallery(label="Generated Comic Strip"),
]

interface = gr.Interface(
    fn=comic_interface,
    inputs=_comic_inputs,
    outputs=_comic_outputs,
    title="🤖 ComicWala AI Generator",
    description="Generate AI-based comic strips using Gemini and Stable Diffusion.",
)


# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    interface.launch()
|
|