|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
from diffusers import StableDiffusionPipeline |
|
|
|
|
|
# Text model: produces the written description for each comic panel.
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Weights are loaded in half precision and placed via device_map="auto".
# NOTE(review): float16 on a CPU-only host may be slow or unsupported
# depending on the installed torch version -- confirm the target hardware.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Reuse the already-loaded model and tokenizer rather than loading by name.
comic_pipeline = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
|
|
|
|
|
# Image model: renders one picture per panel description, running on the CPU.
model_id = "stabilityai/sd-turbo"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cpu")
|
|
|
|
|
def generate_comic(user_prompt, num_panels, art_choice):
    """Generate a comic strip image for a topic.

    The text model is asked for a multi-panel description, each non-empty
    line of its answer is rendered as one panel by the diffusion pipeline,
    and the panels are pasted onto a single canvas in rows of up to three.

    Args:
        user_prompt: Topic of the comic, as typed by the user.
        num_panels: Requested number of panels. Coerced to ``int`` because a
            UI slider may deliver a float.
        art_choice: Art style added to every image prompt. The style clause
            is left out when this is empty or ``None``.

    Returns:
        A ``PIL.Image.Image`` containing all rendered panels.

    Raises:
        gr.Error: If the text model returns no usable panel description.
    """
    # A float such as 3.0 would break the slice below and read as
    # "3.0-panel" in the instruction.
    num_panels = int(num_panels)

    instruction = f"Generate a {num_panels}-panel comic strip description for the topic: {user_prompt}"

    # return_full_text=False keeps the instruction from being echoed back and
    # mistaken for the first panel; temperature only applies when sampling
    # is enabled, hence do_sample=True.
    response = comic_pipeline(
        instruction,
        max_new_tokens=400,
        do_sample=True,
        temperature=0.7,
        return_full_text=False,
    )[0]["generated_text"]

    # One panel per non-empty line, capped at the requested count.
    comic_panels = [line.strip() for line in response.split("\n") if line.strip()][:num_panels]
    if not comic_panels:
        raise gr.Error(
            "The text model did not return any panel descriptions. "
            "Please try again or use a different topic."
        )

    style_clause = f"{art_choice} style, " if art_choice else ""

    # do_sample/temperature are text-generation options, not Stable Diffusion
    # call parameters, so they are no longer passed here. SD-Turbo is a
    # distilled few-step model meant to run with guidance disabled.
    comic_images = [
        pipe(
            f"{panel}, {style_clause}bold outlines, vibrant colors",
            num_inference_steps=4,
            guidance_scale=0.0,
        ).images[0]
        for panel in comic_panels
    ]

    # Grid layout: a single row for up to three panels, rows of three beyond
    # that. Unused cells in the last row stay blank.
    panel_width, panel_height = comic_images[0].size
    cols = min(len(comic_images), 3)
    rows = -(-len(comic_images) // cols)  # ceiling division
    comic_strip = Image.new("RGB", (panel_width * cols, panel_height * rows))

    for index, img in enumerate(comic_images):
        row, col = divmod(index, cols)
        comic_strip.paste(img, (col * panel_width, row * panel_height))

    return comic_strip
|
|
|
|
|
# Art styles offered in the dropdown; the selected one is passed to
# generate_comic as its third argument.
art_styles = ["Classic Comic", "Anime", "Cartoon", "Noir", "Cyberpunk", "Watercolor"]

interface = gr.Interface(
    fn=generate_comic,
    inputs=[
        gr.Textbox(label="Enter Comic Topic", placeholder="e.g., Iron Man vs Hulk"),
        gr.Slider(minimum=3, maximum=6, step=1, value=3, label="Number of Panels"),
        # Explicit default so a style is selected even if the user never
        # opens the dropdown.
        gr.Dropdown(choices=art_styles, value=art_styles[0], label="Choose Art Style"),
    ],
    outputs="image",
    title="Comic Strip Generator",
    description="Generate your own comic strip by entering a topic, choosing the number of panels, and selecting an art style.",
)

# Start the web server only when run as a script, so importing this module
# does not launch it as a side effect.
if __name__ == "__main__":
    interface.launch()
|
|