Manasa1 committed on
Commit
c13a478
1 Parent(s): 514dd74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -153
app.py CHANGED
@@ -1,212 +1,246 @@
1
  import gradio as gr
2
  import torch
 
 
3
  from transformers import AutoTokenizer
4
  from TTS.api import TTS
5
- import numpy as np
6
- from PIL import Image
7
  from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
8
  from torchvision.io import write_video
9
  import os
10
  import groq
11
  import logging
12
  from pathlib import Path
 
 
13
 
14
  # Set up logging
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
# Environment setup and validation: the Groq key must be present before
# anything else is attempted.
API_KEY = os.getenv("GROQ_API_KEY")
if not API_KEY:
    raise ValueError("GROQ_API_KEY not found in environment variables")

# Build the shared clients/models once at import time. Any failure is logged
# and then propagated unchanged so the app does not start half-initialised.
try:
    # Groq chat client used for script/lyrics generation
    groq_client = groq.Groq(api_key=API_KEY)

    # Text-to-speech model
    tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")

    # Stable Diffusion pipeline, swapped to the DPM-Solver multistep scheduler
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float32,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config
    )
    pipe = pipe.to("cpu")
    pipe.enable_attention_slicing()  # Memory optimization
except Exception as init_error:
    logger.error(f"Error initializing models: {str(init_error)}")
    raise
40
-
41
class ContentGenerator:
    """Generate a script/lyrics, a frame-based video and narrated audio.

    Relies on the module-level ``groq_client``, ``tts``, ``pipe`` and
    ``logger`` objects. Every artefact is written into ``self.output_dir``.
    """

    def __init__(self):
        self.output_dir = Path("generated_content")
        self.output_dir.mkdir(exist_ok=True)

    def cleanup_old_files(self):
        """Delete previously generated files from the output directory.

        Sub-directories are left alone: ``Path.unlink`` cannot remove them, so
        attempting it would only produce a warning on every run.
        """
        for file in self.output_dir.glob("*"):
            if file.is_dir() and not file.is_symlink():
                continue
            try:
                file.unlink()
            except OSError as e:
                logger.warning(f"Could not delete {file}: {e}")

    def generate_text_with_groq(self, prompt, max_tokens=200):
        """Return the chat model's reply to ``prompt``.

        Args:
            prompt: User message sent to the model.
            max_tokens: Upper bound on the length of the reply.

        Raises:
            Exception: Whatever the Groq client raises; logged, then re-raised.
        """
        try:
            chat_completion = groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": "You are a professional comedy writer skilled in creating short, witty scripts."
                    },
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                model="mixtral-8x7b-32768",
                max_tokens=max_tokens,
                temperature=0.7,
            )
            return chat_completion.choices[0].message.content
        except Exception as e:
            logger.exception(f"Error generating text: {str(e)}")
            raise

    def generate_speech(self, text):
        """Synthesise ``text`` to a WAV file and return its path.

        The file name embeds ``hash(text)`` so different texts do not overwrite
        each other within one process.
        """
        try:
            output_path = self.output_dir / f"speech_{hash(text)}.wav"
            tts.tts_to_file(text=text, file_path=str(output_path))
            return str(output_path)
        except Exception as e:
            logger.exception(f"Error generating speech: {str(e)}")
            raise

    def generate_video_frames(self, prompt, num_frames=15):
        """Render ``num_frames`` images for ``prompt`` and return them as arrays.

        Each frame is an independent diffusion run whose prompt carries the
        frame index. Errors are logged and re-raised.
        """
        frames = []
        try:
            for i in range(num_frames):
                frame_prompt = f"{prompt}, frame {i+1} of {num_frames}"
                with torch.no_grad():
                    image = pipe(
                        frame_prompt,
                        num_inference_steps=20,
                        guidance_scale=7.5
                    ).images[0]
                frames.append(np.array(image))
                logger.info(f"Generated frame {i+1}/{num_frames}")
        except Exception as e:
            logger.exception(f"Error generating frames: {str(e)}")
            raise
        return frames

    def create_video_from_frames(self, frames, prompt):
        """Encode ``frames`` as an MP4 at 8 fps and return the file path.

        Args:
            frames: Sequence of equally sized image arrays (height, width,
                channels), as produced by ``generate_video_frames``.
            prompt: Text whose hash makes the file name unique.

        Raises:
            ValueError: If ``frames`` is empty.
            Exception: Any encoder error; logged, then re-raised.
        """
        try:
            if len(frames) == 0:
                raise ValueError("No frames to encode")
            output_path = self.output_dir / f"video_{hash(prompt)}.mp4"
            # torchvision.io.write_video expects a (T, H, W, C) uint8 tensor,
            # which is exactly what stacking the image arrays gives; permuting
            # to channels-first makes the encoder reject every frame.
            frames_tensor = torch.from_numpy(np.stack(frames))
            write_video(str(output_path), frames_tensor, fps=8)
            return str(output_path)
        except Exception as e:
            logger.exception(f"Error creating video: {str(e)}")
            raise

    def generate_comedy_animation(self, prompt):
        """Produce ``(script, video_path, speech_path)`` for a comedy prompt.

        On any failure the error is logged and
        ``("Error generating content", None, None)`` is returned instead.
        """
        try:
            self.cleanup_old_files()

            script_prompt = f"""Write a short, witty comedy script with two characters about {prompt}.
            Use the format 'Character: Dialogue or Action' for each line.
            Include clever wordplay, unexpected twists, and snappy dialogue.
            Keep it concise, around 5-8 exchanges. Make it genuinely funny!"""

            script = self.generate_text_with_groq(script_prompt)
            video_prompt = f"A comedic scene with two characters: {prompt}"
            frames = self.generate_video_frames(video_prompt)
            video_path = self.create_video_from_frames(frames, video_prompt)
            speech_path = self.generate_speech(script)

            return script, video_path, speech_path
        except Exception as e:
            logger.error(f"Error in comedy animation generation: {str(e)}")
            return "Error generating content", None, None

    def generate_kids_music_animation(self, theme):
        """Produce ``(lyrics, video_path, speech_path)`` for a kids' song theme.

        On any failure the error is logged and
        ``("Error generating content", None, None)`` is returned instead.
        """
        try:
            self.cleanup_old_files()

            lyrics_prompt = f"""Write short, catchy, and simple lyrics for a children's song about {theme}.
            Each line should be on a new line. Don't include 'Verse' or 'Chorus' labels.
            Make it educational, fun, and easy to remember. Include a repeating chorus."""

            lyrics = self.generate_text_with_groq(lyrics_prompt)
            video_prompt = f"A colorful, animated music video for children about {theme}"
            frames = self.generate_video_frames(video_prompt)
            video_path = self.create_video_from_frames(frames, video_prompt)
            speech_path = self.generate_speech(lyrics)

            return lyrics, video_path, speech_path
        except Exception as e:
            logger.error(f"Error in kids music animation generation: {str(e)}")
            return "Error generating content", None, None
157
 
158
# Initialize content generator
generator = ContentGenerator()

# Gradio Interface
with gr.Blocks(theme='ysharma/steampunk') as app:
    gr.Markdown("## AI-Generated Video and Audio Content")

    # Status message for errors. It starts hidden and is revealed through the
    # event outputs: assigning to a component's attributes inside a handler
    # does not change what the browser shows.
    status_msg = gr.Textbox(label="Status", visible=False)

    def _run_with_status(generate, user_input):
        """Call ``generate(user_input)`` and append a status-box update.

        Returns the three generator outputs followed by an update for
        ``status_msg``. The generator methods report failure as
        ``(message, None, None)``; that message is mirrored into the status box.
        """
        try:
            text, video_path, audio_path = generate(user_input)
        except Exception as e:
            logger.exception("Unhandled error in generation handler")
            return None, None, None, gr.update(value=f"Error: {str(e)}", visible=True)
        if video_path is None and audio_path is None:
            return text, None, None, gr.update(value=text, visible=True)
        return text, video_path, audio_path, gr.update(value="", visible=False)

    with gr.Tab("Comedy Animation"):
        comedy_prompt = gr.Textbox(label="Enter comedy prompt")
        comedy_generate_btn = gr.Button("Generate Comedy Animation")
        comedy_script = gr.Textbox(label="Generated Comedy Script")
        comedy_animation = gr.Video(label="Comedy Animation")
        comedy_audio = gr.Audio(label="Comedy Speech")

        def comedy_wrapper(prompt):
            """Click handler for the comedy tab."""
            return _run_with_status(generator.generate_comedy_animation, prompt)

        comedy_generate_btn.click(
            comedy_wrapper,
            inputs=comedy_prompt,
            outputs=[comedy_script, comedy_animation, comedy_audio, status_msg]
        )

    with gr.Tab("Kids Music Animation"):
        music_theme = gr.Textbox(label="Enter music theme for kids")
        music_generate_btn = gr.Button("Generate Kids Music Animation")
        music_lyrics = gr.Textbox(label="Generated Lyrics")
        music_animation = gr.Video(label="Music Animation")
        music_audio = gr.Audio(label="Music Audio")

        def music_wrapper(theme):
            """Click handler for the kids-music tab."""
            return _run_with_status(generator.generate_kids_music_animation, theme)

        music_generate_btn.click(
            music_wrapper,
            inputs=music_theme,
            outputs=[music_lyrics, music_animation, music_audio, status_msg]
        )

if __name__ == "__main__":
    app.launch()
212
-
 
1
  import gradio as gr
2
  import torch
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw
5
  from transformers import AutoTokenizer
6
  from TTS.api import TTS
 
 
7
  from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
8
  from torchvision.io import write_video
9
  import os
10
  import groq
11
  import logging
12
  from pathlib import Path
13
+ import cv2
14
+ from moviepy.editor import VideoFileClip, AudioFileClip, CompositeVideoClip
15
 
16
  # Set up logging
17
  logging.basicConfig(level=logging.INFO)
18
  logger = logging.getLogger(__name__)
19
 
20
+ class EnhancedContentGenerator:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def __init__(self):
22
  self.output_dir = Path("generated_content")
23
  self.output_dir.mkdir(exist_ok=True)
24
 
25
+ # Initialize TTS with a more cartoon-appropriate voice
26
+ self.tts = TTS(model_name="tts_models/en/vctk/vits")
27
+
28
+ # Initialize Stable Diffusion with cartoon-specific model
29
+ self.pipe = StableDiffusionPipeline.from_pretrained(
30
+ "nitrosocke/Ghibli-Diffusion", # Using anime/cartoon style model
31
+ torch_dtype=torch.float32
32
+ )
33
+ self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
34
+ self.pipe = self.pipe.to("cpu")
35
+ self.pipe.enable_attention_slicing()
36
+
37
+ # Initialize Groq client
38
+ if not (self.api_key := os.getenv("GROQ_API_KEY")):
39
+ raise ValueError("GROQ_API_KEY not found in environment variables")
40
+ self.groq_client = groq.Groq(api_key=self.api_key)
41
+
42
+ def generate_cartoon_frame(self, prompt, style="cartoon"):
43
+ """Generate a single cartoon frame with specified style"""
44
+ style_prompts = {
45
+ "cartoon": "in the style of a western cartoon, vibrant colors, simple shapes",
46
+ "anime": "in the style of Studio Ghibli anime, detailed backgrounds",
47
+ "kids": "in the style of a children's book illustration, cute and colorful"
48
+ }
49
+
50
+ enhanced_prompt = f"{prompt}, {style_prompts.get(style, style_prompts['cartoon'])}"
51
+
52
+ with torch.no_grad():
53
+ image = self.pipe(
54
+ enhanced_prompt,
55
+ num_inference_steps=30,
56
+ guidance_scale=7.5
57
+ ).images[0]
58
+
59
+ return np.array(image)
60
 
61
+ def add_cartoon_effects(self, frame):
62
+ """Add cartoon-style effects to a frame"""
63
+ # Convert to RGB if necessary
64
+ if len(frame.shape) == 2:
65
+ frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
66
+
67
+ # Apply cartoon effect
68
+ gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
69
+ gray = cv2.medianBlur(gray, 5)
70
+ edges = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 9, 9)
71
+ color = cv2.bilateralFilter(frame, 9, 300, 300)
72
+
73
+ # Combine edges with color
74
+ cartoon = cv2.bitwise_and(color, color, mask=edges)
75
+
76
+ return cartoon
77
 
78
+ def generate_video_sequence(self, script, style="cartoon", num_frames=24):
79
+ """Generate a sequence of frames based on the script"""
80
  frames = []
81
+ scenes = script.split('\n\n') # Split script into scenes
82
+
83
+ frames_per_scene = max(num_frames // len(scenes), 4)
84
+
85
+ for scene in scenes:
86
+ if not scene.strip():
87
+ continue
88
+
89
+ # Generate base frame for the scene
90
+ scene_prompt = f"cartoon scene showing: {scene}"
91
+ base_frame = self.generate_cartoon_frame(scene_prompt, style)
92
+
93
+ # Generate slight variations for animation
94
+ for i in range(frames_per_scene):
95
+ frame = base_frame.copy()
96
+ frame = self.add_cartoon_effects(frame)
97
+ frames.append(frame)
98
+
99
  return frames
100
 
101
+ def enhance_audio(self, audio_path, style="cartoon"):
102
+ """Add effects to the audio based on style"""
103
  try:
104
+ audio = AudioFileClip(audio_path)
105
+
106
+ if style == "cartoon":
107
+ # Speed up slightly for cartoon effect
108
+ audio = audio.speedx(1.1)
109
+ elif style == "kids":
110
+ # Add echo effect for kids music
111
+ echo = audio.set_start(0.1)
112
+ audio = CompositeVideoClip([audio, echo.volumex(0.3)])
113
+
114
+ enhanced_path = audio_path.replace('.wav', '_enhanced.wav')
115
+ audio.write_audiofile(enhanced_path)
116
+ return enhanced_path
117
  except Exception as e:
118
+ logger.error(f"Error enhancing audio: {str(e)}")
119
+ return audio_path
120
 
121
  def generate_comedy_animation(self, prompt):
122
+ """Generate enhanced comedy animation"""
123
  try:
124
+ # Generate a more structured comedy script
125
+ script_prompt = f"""Write a funny cartoon script about {prompt}.
126
+ Include:
127
+ - Two distinct character voices
128
+ - Physical comedy moments
129
+ - Sound effects in [brackets]
130
+ - Scene descriptions in (parentheses)
131
+ Keep it family-friendly and around 3-4 scenes."""
132
+
133
+ script = self.groq_client.chat.completions.create(
134
+ messages=[
135
+ {"role": "system", "content": "You are a professional cartoon comedy writer."},
136
+ {"role": "user", "content": script_prompt}
137
+ ],
138
+ model="mixtral-8x7b-32768",
139
+ temperature=0.7
140
+ ).choices[0].message.content
141
 
142
+ # Generate frames with cartoon style
143
+ frames = self.generate_video_sequence(script, style="cartoon")
 
 
144
 
145
+ # Generate and enhance audio
146
+ speech_path = str(self.output_dir / f"speech_{hash(script)}.wav")
147
+ self.tts.tts_to_file(text=script, file_path=speech_path)
148
+ enhanced_speech = self.enhance_audio(speech_path, "cartoon")
149
+
150
+ # Create video with enhanced frames
151
+ video_path = str(self.output_dir / f"video_{hash(prompt)}.mp4")
152
+ frames_tensor = torch.from_numpy(np.array(frames)).permute(0, 3, 1, 2)
153
+ write_video(video_path, frames_tensor, fps=12) # Higher FPS for smoother animation
154
+
155
+ return script, video_path, enhanced_speech
156
 
 
157
  except Exception as e:
158
  logger.error(f"Error in comedy animation generation: {str(e)}")
159
  return "Error generating content", None, None
160
 
161
  def generate_kids_music_animation(self, theme):
162
+ """Generate enhanced kids music animation"""
163
  try:
164
+ # Generate kid-friendly lyrics with music directions
165
+ lyrics_prompt = f"""Write lyrics for a children's educational song about {theme}.
166
+ Include:
167
+ - Simple, repetitive chorus
168
+ - Educational facts
169
+ - [Music notes] for melody changes
170
+ - (Action descriptions) for animation
171
+ Make it upbeat and memorable!"""
172
 
173
+ lyrics = self.groq_client.chat.completions.create(
174
+ messages=[
175
+ {"role": "system", "content": "You are a children's music composer."},
176
+ {"role": "user", "content": lyrics_prompt}
177
+ ],
178
+ model="mixtral-8x7b-32768",
179
+ temperature=0.7
180
+ ).choices[0].message.content
181
+
182
+ # Generate frames with kids' style
183
+ frames = self.generate_video_sequence(lyrics, style="kids", num_frames=36)
184
 
185
+ # Generate and enhance audio
186
+ speech_path = str(self.output_dir / f"music_{hash(lyrics)}.wav")
187
+ self.tts.tts_to_file(text=lyrics, file_path=speech_path)
188
+ enhanced_speech = self.enhance_audio(speech_path, "kids")
189
+
190
+ # Create video with enhanced frames
191
+ video_path = str(self.output_dir / f"video_{hash(theme)}.mp4")
192
+ frames_tensor = torch.from_numpy(np.array(frames)).permute(0, 3, 1, 2)
193
+ write_video(video_path, frames_tensor, fps=15) # Smooth animation for kids
194
+
195
+ return lyrics, video_path, enhanced_speech
196
 
 
197
  except Exception as e:
198
  logger.error(f"Error in kids music animation generation: {str(e)}")
199
  return "Error generating content", None, None
200
 
 
 
 
201
  # Gradio Interface
202
def create_interface():
    """Build and return the Gradio Blocks app for the cartoon generator.

    One ``EnhancedContentGenerator`` is created and its two generate methods
    are wired directly to the buttons of the "Cartoon Comedy" and
    "Kids Music Video" tabs.
    """
    generator = EnhancedContentGenerator()

    with gr.Blocks(theme='ysharma/steampunk') as demo:
        gr.Markdown("# AI Cartoon Generator")
        gr.Markdown("Generate cartoon comedy clips and kids music videos!")

        # --- Tab 1: comedy ---
        with gr.Tab("Cartoon Comedy"):
            comedy_prompt = gr.Textbox(
                label="What should the cartoon be about?",
                placeholder="E.g., 'a penguin learning to fly'",
            )
            comedy_generate_btn = gr.Button("Generate Cartoon Comedy", variant="primary")
            comedy_script = gr.Textbox(label="Generated Script")
            comedy_animation = gr.Video(label="Cartoon Animation")
            comedy_audio = gr.Audio(label="Cartoon Audio")

        # --- Tab 2: kids music ---
        with gr.Tab("Kids Music Video"):
            music_theme = gr.Textbox(
                label="What should the song teach about?",
                placeholder="E.g., 'the water cycle'",
            )
            music_generate_btn = gr.Button("Generate Music Video", variant="primary")
            music_lyrics = gr.Textbox(label="Song Lyrics")
            music_animation = gr.Video(label="Music Video")
            music_audio = gr.Audio(label="Song Audio")

        # Event handlers: each button feeds its textbox to the generator and
        # spreads the returned (text, video, audio) triple over three outputs.
        comedy_outputs = [comedy_script, comedy_animation, comedy_audio]
        comedy_generate_btn.click(
            generator.generate_comedy_animation,
            inputs=comedy_prompt,
            outputs=comedy_outputs,
        )

        music_outputs = [music_lyrics, music_animation, music_audio]
        music_generate_btn.click(
            generator.generate_kids_music_animation,
            inputs=music_theme,
            outputs=music_outputs,
        )

    return demo


if __name__ == "__main__":
    app = create_interface()
    app.launch()