Manasa1 commited on
Commit
514dd74
·
verified ·
1 Parent(s): 6738411

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -78
app.py CHANGED
@@ -8,87 +8,162 @@ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
8
  from torchvision.io import write_video
9
  import os
10
  import groq
 
 
11
 
12
- # Initialize Groq client
13
- groq_client = groq.Groq()
14
- API_KEY = os.getenv("GROQ_API_KEY")
15
- groq_client.api_key = API_KEY
16
-
17
- # Initialize TTS model
18
- tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
19
-
20
- # Initialize Stable Diffusion pipeline for CPU
21
- pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
22
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
23
- pipe = pipe.to("cpu")
24
-
25
- def generate_text_with_groq(prompt, max_tokens=200):
26
- chat_completion = groq_client.chat.completions.create(
27
- messages=[
28
- {
29
- "role": "system",
30
- "content": "You are a professional comedy writer skilled in creating short, witty scripts."
31
- },
32
- {
33
- "role": "user",
34
- "content": prompt
35
- }
36
- ],
37
- model="mixtral-8x7b-32768",
38
- max_tokens=max_tokens,
39
- temperature=0.7,
40
- )
41
- return chat_completion.choices[0].message.content
42
-
43
- def generate_speech(text):
44
- output_path = "generated_speech.wav"
45
- tts.tts_to_file(text=text, file_path=output_path)
46
- return output_path
47
-
48
- def generate_video_frames(prompt, num_frames=10):
49
- frames = []
50
- for i in range(num_frames):
51
- frame_prompt = f"{prompt}, frame {i+1} of {num_frames}"
52
- with torch.no_grad():
53
- image = pipe(frame_prompt, num_inference_steps=20).images[0]
54
- frames.append(np.array(image))
55
- return frames
56
-
57
- def create_video_from_frames(frames, output_path="output_video.mp4", fps=5):
58
- frames_tensor = torch.from_numpy(np.array(frames)).permute(0, 3, 1, 2)
59
- write_video(output_path, frames_tensor, fps=fps)
60
- return output_path
61
-
62
- def generate_comedy_animation(prompt):
63
- script_prompt = f"""Write a short, witty comedy script with two characters about {prompt}.
64
- Use the format 'Character: Dialogue or Action' for each line.
65
- Include clever wordplay, unexpected twists, and snappy dialogue.
66
- Keep it concise, around 5-8 exchanges. Make it genuinely funny!"""
67
 
68
- script = generate_text_with_groq(script_prompt)
69
- video_prompt = f"A comedic scene with two characters: {prompt}"
70
- frames = generate_video_frames(video_prompt)
71
- video_path = create_video_from_frames(frames)
72
- speech_path = generate_speech(script)
73
- return script, video_path, speech_path
74
-
75
- def generate_kids_music_animation(theme):
76
- lyrics_prompt = f"""Write short, catchy, and simple lyrics for a children's song about {theme}.
77
- Each line should be on a new line. Don't include 'Verse' or 'Chorus' labels.
78
- Make it educational, fun, and easy to remember. Include a repeating chorus."""
79
 
80
- lyrics = generate_text_with_groq(lyrics_prompt)
81
- video_prompt = f"A colorful, animated music video for children about {theme}"
82
- frames = generate_video_frames(video_prompt)
83
- video_path = create_video_from_frames(frames)
84
- speech_path = generate_speech(lyrics)
85
- return lyrics, video_path, speech_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- my_theme='ysharma/steampunk'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  # Gradio Interface
90
- with gr.Blocks(theme=my_theme) as app:
91
- gr.Markdown("## AI-Generated Video and Audio Content (Optimized CPU Version with Groq API)")
 
 
 
92
 
93
  with gr.Tab("Comedy Animation"):
94
  comedy_prompt = gr.Textbox(label="Enter comedy prompt")
@@ -97,8 +172,16 @@ with gr.Blocks(theme=my_theme) as app:
97
  comedy_animation = gr.Video(label="Comedy Animation")
98
  comedy_audio = gr.Audio(label="Comedy Speech")
99
 
 
 
 
 
 
 
 
 
100
  comedy_generate_btn.click(
101
- generate_comedy_animation,
102
  inputs=comedy_prompt,
103
  outputs=[comedy_script, comedy_animation, comedy_audio]
104
  )
@@ -110,11 +193,20 @@ with gr.Blocks(theme=my_theme) as app:
110
  music_animation = gr.Video(label="Music Animation")
111
  music_audio = gr.Audio(label="Music Audio")
112
 
 
 
 
 
 
 
 
 
113
  music_generate_btn.click(
114
- generate_kids_music_animation,
115
  inputs=music_theme,
116
  outputs=[music_lyrics, music_animation, music_audio]
117
  )
118
 
119
- app.launch()
 
120
 
 
8
  from torchvision.io import write_video
9
  import os
10
  import groq
11
+ import logging
12
+ from pathlib import Path
13
 
14
+ # Set up logging
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Environment setup and validation
19
+ if not (API_KEY := os.getenv("GROQ_API_KEY")):
20
+ raise ValueError("GROQ_API_KEY not found in environment variables")
21
+
22
+ # Initialize clients and models with error handling
23
+ try:
24
+ groq_client = groq.Groq(api_key=API_KEY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ # Initialize TTS model
27
+ tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
 
 
 
 
 
 
 
 
 
28
 
29
+ # Initialize Stable Diffusion with optimizations
30
+ pipe = StableDiffusionPipeline.from_pretrained(
31
+ "CompVis/stable-diffusion-v1-4",
32
+ torch_dtype=torch.float32
33
+ )
34
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
35
+ pipe = pipe.to("cpu")
36
+ pipe.enable_attention_slicing() # Memory optimization
37
+ except Exception as e:
38
+ logger.error(f"Error initializing models: {str(e)}")
39
+ raise
40
+
41
+ class ContentGenerator:
42
+ def __init__(self):
43
+ self.output_dir = Path("generated_content")
44
+ self.output_dir.mkdir(exist_ok=True)
45
+
46
+ def cleanup_old_files(self):
47
+ """Clean up previously generated files"""
48
+ for file in self.output_dir.glob("*"):
49
+ try:
50
+ file.unlink()
51
+ except Exception as e:
52
+ logger.warning(f"Could not delete {file}: {e}")
53
+
54
+ def generate_text_with_groq(self, prompt, max_tokens=200):
55
+ """Generate text with error handling"""
56
+ try:
57
+ chat_completion = groq_client.chat.completions.create(
58
+ messages=[
59
+ {
60
+ "role": "system",
61
+ "content": "You are a professional comedy writer skilled in creating short, witty scripts."
62
+ },
63
+ {
64
+ "role": "user",
65
+ "content": prompt
66
+ }
67
+ ],
68
+ model="mixtral-8x7b-32768",
69
+ max_tokens=max_tokens,
70
+ temperature=0.7,
71
+ )
72
+ return chat_completion.choices[0].message.content
73
+ except Exception as e:
74
+ logger.error(f"Error generating text: {str(e)}")
75
+ raise
76
+
77
+ def generate_speech(self, text):
78
+ """Generate speech with unique filenames"""
79
+ try:
80
+ output_path = self.output_dir / f"speech_{hash(text)}.wav"
81
+ tts.tts_to_file(text=text, file_path=str(output_path))
82
+ return str(output_path)
83
+ except Exception as e:
84
+ logger.error(f"Error generating speech: {str(e)}")
85
+ raise
86
 
87
+ def generate_video_frames(self, prompt, num_frames=15):
88
+ """Generate video frames with progress tracking"""
89
+ frames = []
90
+ try:
91
+ for i in range(num_frames):
92
+ frame_prompt = f"{prompt}, frame {i+1} of {num_frames}"
93
+ with torch.no_grad():
94
+ image = pipe(
95
+ frame_prompt,
96
+ num_inference_steps=20,
97
+ guidance_scale=7.5
98
+ ).images[0]
99
+ frames.append(np.array(image))
100
+ logger.info(f"Generated frame {i+1}/{num_frames}")
101
+ except Exception as e:
102
+ logger.error(f"Error generating frames: {str(e)}")
103
+ raise
104
+ return frames
105
+
106
+ def create_video_from_frames(self, frames, prompt):
107
+ """Create video with unique filenames"""
108
+ try:
109
+ output_path = self.output_dir / f"video_{hash(prompt)}.mp4"
110
+ frames_tensor = torch.from_numpy(np.array(frames)).permute(0, 3, 1, 2)
111
+ write_video(str(output_path), frames_tensor, fps=8)
112
+ return str(output_path)
113
+ except Exception as e:
114
+ logger.error(f"Error creating video: {str(e)}")
115
+ raise
116
+
117
+ def generate_comedy_animation(self, prompt):
118
+ """Generate comedy animation with error handling"""
119
+ try:
120
+ self.cleanup_old_files()
121
+
122
+ script_prompt = f"""Write a short, witty comedy script with two characters about {prompt}.
123
+ Use the format 'Character: Dialogue or Action' for each line.
124
+ Include clever wordplay, unexpected twists, and snappy dialogue.
125
+ Keep it concise, around 5-8 exchanges. Make it genuinely funny!"""
126
+
127
+ script = self.generate_text_with_groq(script_prompt)
128
+ video_prompt = f"A comedic scene with two characters: {prompt}"
129
+ frames = self.generate_video_frames(video_prompt)
130
+ video_path = self.create_video_from_frames(frames, video_prompt)
131
+ speech_path = self.generate_speech(script)
132
+
133
+ return script, video_path, speech_path
134
+ except Exception as e:
135
+ logger.error(f"Error in comedy animation generation: {str(e)}")
136
+ return "Error generating content", None, None
137
+
138
+ def generate_kids_music_animation(self, theme):
139
+ """Generate kids music animation with error handling"""
140
+ try:
141
+ self.cleanup_old_files()
142
+
143
+ lyrics_prompt = f"""Write short, catchy, and simple lyrics for a children's song about {theme}.
144
+ Each line should be on a new line. Don't include 'Verse' or 'Chorus' labels.
145
+ Make it educational, fun, and easy to remember. Include a repeating chorus."""
146
+
147
+ lyrics = self.generate_text_with_groq(lyrics_prompt)
148
+ video_prompt = f"A colorful, animated music video for children about {theme}"
149
+ frames = self.generate_video_frames(video_prompt)
150
+ video_path = self.create_video_from_frames(frames, video_prompt)
151
+ speech_path = self.generate_speech(lyrics)
152
+
153
+ return lyrics, video_path, speech_path
154
+ except Exception as e:
155
+ logger.error(f"Error in kids music animation generation: {str(e)}")
156
+ return "Error generating content", None, None
157
+
158
+ # Initialize content generator
159
+ generator = ContentGenerator()
160
 
161
  # Gradio Interface
162
+ with gr.Blocks(theme='ysharma/steampunk') as app:
163
+ gr.Markdown("## AI-Generated Video and Audio Content")
164
+
165
+ # Status message for errors
166
+ status_msg = gr.Textbox(label="Status", visible=False)
167
 
168
  with gr.Tab("Comedy Animation"):
169
  comedy_prompt = gr.Textbox(label="Enter comedy prompt")
 
172
  comedy_animation = gr.Video(label="Comedy Animation")
173
  comedy_audio = gr.Audio(label="Comedy Speech")
174
 
175
+ def comedy_wrapper(prompt):
176
+ status_msg.visible = True
177
+ try:
178
+ return generator.generate_comedy_animation(prompt)
179
+ except Exception as e:
180
+ status_msg.value = f"Error: {str(e)}"
181
+ return None, None, None
182
+
183
  comedy_generate_btn.click(
184
+ comedy_wrapper,
185
  inputs=comedy_prompt,
186
  outputs=[comedy_script, comedy_animation, comedy_audio]
187
  )
 
193
  music_animation = gr.Video(label="Music Animation")
194
  music_audio = gr.Audio(label="Music Audio")
195
 
196
+ def music_wrapper(theme):
197
+ status_msg.visible = True
198
+ try:
199
+ return generator.generate_kids_music_animation(theme)
200
+ except Exception as e:
201
+ status_msg.value = f"Error: {str(e)}"
202
+ return None, None, None
203
+
204
  music_generate_btn.click(
205
+ music_wrapper,
206
  inputs=music_theme,
207
  outputs=[music_lyrics, music_animation, music_audio]
208
  )
209
 
210
+ if __name__ == "__main__":
211
+ app.launch()
212