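# NOTE: page-capture residue, not part of the program. The Hugging Face Spaces
# status banner for this app read "Runtime error" when the file was captured.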
import hashlib
import logging
import os
from pathlib import Path

import cv2
import gradio as gr
import groq
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from moviepy.editor import (
    AudioFileClip,
    CompositeAudioClip,
    CompositeVideoClip,
    VideoFileClip,
    vfx,
)
from PIL import Image, ImageDraw
from torchvision.io import write_video
from transformers import AutoTokenizer
from TTS.api import TTS
# Set up logging: configure the root logger at INFO level and create the
# module-level logger used by the rest of this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EnhancedContentGenerator:
    """Generate short cartoon clips (script, frames, narration) from a text prompt.

    The pipeline is: ask a Groq-hosted chat model for a script or lyrics, render
    one Stable Diffusion image per scene, repeat that image to fill the scene's
    frames, synthesize narration with Coqui TTS, and write a silent MP4 plus a
    separate audio file under ``generated_content/``.
    """

    # Chat model requested from Groq for both scripts and lyrics.
    # NOTE(review): hosted model names get retired over time - confirm this one
    # is still served before deploying.
    GROQ_MODEL = "mixtral-8x7b-32768"

    # Style keyword -> text appended to every image prompt.
    STYLE_PROMPTS = {
        "cartoon": "in the style of a western cartoon, vibrant colors, simple shapes",
        "anime": "in the style of Studio Ghibli anime, detailed backgrounds",
        "kids": "in the style of a children's book illustration, cute and colorful",
    }

    def __init__(self):
        """Load the models and create the output directories.

        Raises:
            ValueError: if the ``GROQ_API_KEY`` environment variable is unset.
        """
        # Check for API key
        self.api_key = os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("GROQ_API_KEY not found in environment variables")

        # Create output directories if they don't exist
        self.output_dir = Path("generated_content")
        self.audio_dir = self.output_dir / "audio"
        self.video_dir = self.output_dir / "video"
        for directory in (self.output_dir, self.audio_dir, self.video_dir):
            directory.mkdir(parents=True, exist_ok=True)

        # Initialize TTS with a more cartoon-appropriate voice
        self.tts = TTS(model_name="tts_models/en/vctk/vits")

        # Initialize Stable Diffusion with cartoon-specific model
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "nitrosocke/Ghibli-Diffusion",  # Using anime/cartoon style model
            torch_dtype=torch.float32,
        )
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )
        self.pipe = self.pipe.to("cpu")
        self.pipe.enable_attention_slicing()

        # Initialize Groq client
        self.groq_client = groq.Groq(api_key=self.api_key)

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _split_scenes(script):
        """Return the non-blank, stripped scenes of *script* (blank-line separated)."""
        return [part.strip() for part in script.split("\n\n") if part.strip()]

    @staticmethod
    def _content_id(text):
        """Return a short hex id for *text* that is stable across processes.

        The built-in ``hash()`` of a string is salted per interpreter run and
        may be negative, which makes for unstable, odd-looking file names.
        """
        return hashlib.sha256(text.encode("utf-8")).hexdigest()[:16]

    @staticmethod
    def _enhanced_path(audio_path):
        """Return the sibling path of *audio_path* with ``_enhanced`` before the suffix."""
        audio_path = Path(audio_path)
        return audio_path.with_name(f"{audio_path.stem}_enhanced{audio_path.suffix}")

    # ------------------------------------------------------------------
    # Image generation
    # ------------------------------------------------------------------
    def generate_cartoon_frame(self, prompt, style="cartoon"):
        """Generate a single cartoon frame with specified style.

        Args:
            prompt: Scene description passed to the diffusion pipeline.
            style: Key into ``STYLE_PROMPTS``; unknown keys fall back to "cartoon".

        Returns:
            The first generated image converted with ``np.array``.
        """
        suffix = self.STYLE_PROMPTS.get(style, self.STYLE_PROMPTS["cartoon"])
        enhanced_prompt = f"{prompt}, {suffix}"
        with torch.no_grad():
            image = self.pipe(
                enhanced_prompt,
                num_inference_steps=30,
                guidance_scale=7.5,
            ).images[0]
        return np.array(image)

    def add_cartoon_effects(self, frame):
        """Add cartoon-style effects to a frame.

        Thresholded edges from a blurred grayscale copy are used as a mask over
        a bilateral-filtered copy of the frame, so pixels where the edge mask is
        zero become black.

        Args:
            frame: Image array; 2-D input is first expanded to three channels.

        Returns:
            The masked, smoothed image as produced by ``cv2.bitwise_and``.
        """
        # Convert to RGB if necessary
        if len(frame.shape) == 2:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)

        # Apply cartoon effect
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        gray = cv2.medianBlur(gray, 5)
        edges = cv2.adaptiveThreshold(
            gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 9, 9
        )
        color = cv2.bilateralFilter(frame, 9, 300, 300)

        # Combine edges with color
        return cv2.bitwise_and(color, color, mask=edges)

    def generate_video_sequence(self, script, style="cartoon", num_frames=24):
        """Generate a sequence of frames based on the script.

        Each non-blank scene gets one generated image, cartoonized once and
        repeated ``max(num_frames // scene_count, 4)`` times. Blank scenes are
        ignored and do not count toward ``scene_count``.

        Args:
            script: Text whose scenes are separated by blank lines.
            style: Style key forwarded to ``generate_cartoon_frame``.
            num_frames: Target total frame count across all scenes.

        Returns:
            A list of frame arrays; empty when the script has no scenes.
        """
        scenes = self._split_scenes(script)
        if not scenes:
            return []

        frames_per_scene = max(num_frames // len(scenes), 4)
        frames = []
        for scene in scenes:
            # Generate base frame for the scene
            base_frame = self.generate_cartoon_frame(
                f"cartoon scene showing: {scene}", style
            )
            # Every frame of a scene is identical, so run the (expensive)
            # filter once and repeat the result.
            styled = self.add_cartoon_effects(base_frame)
            frames.extend(styled.copy() for _ in range(frames_per_scene))
        return frames

    # ------------------------------------------------------------------
    # Audio
    # ------------------------------------------------------------------
    def enhance_audio(self, audio_path, style="cartoon"):
        """Add effects to the audio based on style.

        "cartoon" speeds playback up by 10%; "kids" mixes in a quieter copy
        delayed by 0.1 s; any other style re-encodes the audio unchanged.

        Args:
            audio_path: Path of the source audio file.
            style: Effect selector.

        Returns:
            Path (str) of the enhanced file, or ``str(audio_path)`` unchanged if
            anything fails - enhancement is best-effort.
        """
        source = None
        try:
            source = AudioFileClip(str(audio_path))
            enhanced = source
            if style == "cartoon":
                # Speed up slightly for cartoon effect. The editor module binds
                # `speedx` as a method on video clips only, so apply it via fx().
                enhanced = source.fx(vfx.speedx, 1.1)
            elif style == "kids":
                # Add echo effect for kids music
                echo = source.set_start(0.1).volumex(0.3)
                enhanced = CompositeAudioClip([source, echo])

            enhanced_path = self._enhanced_path(audio_path)
            # Pass fps explicitly: a composite clip may not carry its own.
            enhanced.write_audiofile(str(enhanced_path), fps=source.fps)
            return str(enhanced_path)
        except Exception as e:
            logger.error(f"Error enhancing audio: {str(e)}")
            return str(audio_path)
        finally:
            # Release the reader held by the clip.
            if source is not None:
                source.close()

    def _synthesize_speech(self, text, speech_path):
        """Write narration of *text* to *speech_path* using the loaded TTS model."""
        kwargs = {}
        # A multi-speaker model needs an explicit speaker; use its first voice.
        # NOTE(review): pick a specific speaker id if a particular voice is wanted.
        if getattr(self.tts, "is_multi_speaker", False):
            speakers = getattr(self.tts, "speakers", None)
            if speakers:
                kwargs["speaker"] = speakers[0]
        self.tts.tts_to_file(text=text, file_path=str(speech_path), **kwargs)

    # ------------------------------------------------------------------
    # Text generation and shared pipeline
    # ------------------------------------------------------------------
    def _complete(self, system_prompt, user_prompt):
        """Return the chat model's reply to *user_prompt* under *system_prompt*."""
        completion = self.groq_client.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            model=self.GROQ_MODEL,
            temperature=0.7,
        )
        return completion.choices[0].message.content

    def _write_frames(self, frames, video_path, fps):
        """Encode *frames* (each H x W x C) to *video_path* at *fps*."""
        # write_video takes frames as (T, H, W, C); stacking the per-frame
        # arrays already yields that layout, so no permute is needed.
        video = torch.from_numpy(np.stack(frames))
        write_video(str(video_path), video, fps=fps)

    def _generate_animation(
        self, *, system_prompt, user_prompt, seed_text, style, num_frames, fps,
        audio_prefix
    ):
        """Run text -> frames -> narration -> video and return the three results."""
        text = self._complete(system_prompt, user_prompt)

        frames = self.generate_video_sequence(text, style=style, num_frames=num_frames)
        if not frames:
            raise ValueError("Generated text contained no scenes to render")

        # Generate and enhance audio
        speech_path = self.audio_dir / f"{audio_prefix}_{self._content_id(text)}.wav"
        self._synthesize_speech(text, speech_path)
        enhanced_speech = self.enhance_audio(speech_path, style)

        # Create video with enhanced frames
        video_path = self.video_dir / f"video_{self._content_id(seed_text)}.mp4"
        self._write_frames(frames, video_path, fps)
        return text, str(video_path), enhanced_speech

    def generate_comedy_animation(self, prompt):
        """Generate enhanced comedy animation.

        Args:
            prompt: What the cartoon should be about.

        Returns:
            ``(script, video_path, audio_path)``; on any failure the error is
            logged and ``("Error generating content", None, None)`` is returned.
        """
        # Generate a more structured comedy script
        script_prompt = (
            f"Write a funny cartoon script about {prompt}.\n"
            "Include:\n"
            "- Two distinct character voices\n"
            "- Physical comedy moments\n"
            "- Sound effects in [brackets]\n"
            "- Scene descriptions in (parentheses)\n"
            "Keep it family-friendly and around 3-4 scenes."
        )
        try:
            return self._generate_animation(
                system_prompt="You are a professional cartoon comedy writer.",
                user_prompt=script_prompt,
                seed_text=prompt,
                style="cartoon",
                num_frames=24,
                fps=12,  # Higher FPS for smoother animation
                audio_prefix="speech",
            )
        except Exception as e:
            logger.exception(f"Error in comedy animation generation: {str(e)}")
            return "Error generating content", None, None

    def generate_kids_music_animation(self, theme):
        """Generate enhanced kids music animation.

        Args:
            theme: What the song should teach about.

        Returns:
            ``(lyrics, video_path, audio_path)``; on any failure the error is
            logged and ``("Error generating content", None, None)`` is returned.
        """
        # Generate kid-friendly lyrics with music directions
        lyrics_prompt = (
            f"Write lyrics for a children's educational song about {theme}.\n"
            "Include:\n"
            "- Simple, repetitive chorus\n"
            "- Educational facts\n"
            "- [Music notes] for melody changes\n"
            "- (Action descriptions) for animation\n"
            "Make it upbeat and memorable!"
        )
        try:
            return self._generate_animation(
                system_prompt="You are a children's music composer.",
                user_prompt=lyrics_prompt,
                seed_text=theme,
                style="kids",
                num_frames=36,
                fps=15,  # Smooth animation for kids
                audio_prefix="music",
            )
        except Exception as e:
            logger.exception(f"Error in kids music animation generation: {str(e)}")
            return "Error generating content", None, None
# Gradio Interface | |
def create_interface():
    """Build the Gradio app with one tab per generator.

    Constructs an ``EnhancedContentGenerator`` (which loads the models), lays
    out a "Cartoon Comedy" tab and a "Kids Music Video" tab, and wires each
    tab's button to the matching generator method.

    Returns:
        The assembled ``gr.Blocks`` app, not yet launched.
    """
    generator = EnhancedContentGenerator()

    with gr.Blocks(theme='ysharma/steampunk') as app:
        gr.Markdown("# AI Cartoon Generator")
        gr.Markdown("Generate cartoon comedy clips and kids music videos!")

        with gr.Tab("Cartoon Comedy"):
            comedy_prompt = gr.Textbox(
                label="What should the cartoon be about?",
                placeholder="E.g., 'a penguin learning to fly'",
            )
            comedy_generate_btn = gr.Button("Generate Cartoon Comedy", variant="primary")
            comedy_script = gr.Textbox(label="Generated Script")
            comedy_animation = gr.Video(label="Cartoon Animation")
            comedy_audio = gr.Audio(label="Cartoon Audio")

        with gr.Tab("Kids Music Video"):
            music_theme = gr.Textbox(
                label="What should the song teach about?",
                placeholder="E.g., 'the water cycle'",
            )
            music_generate_btn = gr.Button("Generate Music Video", variant="primary")
            music_lyrics = gr.Textbox(label="Song Lyrics")
            music_animation = gr.Video(label="Music Video")
            music_audio = gr.Audio(label="Song Audio")

        # Event handlers: registered inside the Blocks context so they attach
        # to this app. Each generator returns (text, video path, audio path).
        comedy_generate_btn.click(
            generator.generate_comedy_animation,
            inputs=comedy_prompt,
            outputs=[comedy_script, comedy_animation, comedy_audio],
        )
        music_generate_btn.click(
            generator.generate_kids_music_animation,
            inputs=music_theme,
            outputs=[music_lyrics, music_animation, music_audio],
        )

    return app
# Script entry point: build the UI and start the Gradio server only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    app = create_interface()
    app.launch()