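"""Gradio app for Audio Palette.

Builds soundtracks for images: a pace model classifies each image's tempo, an
image-captioning model describes it, and the resulting prompt is sent to an
audio-generation backend through a user-supplied ngrok endpoint (the "Colab
endpoint" in the UI). A single image yields an audio track; multiple images
yield a slideshow video with the generated music as its soundtrack.
"""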
import typing
from pathlib import Path

import gradio as gr
import PIL
from PIL import Image
from moviepy.editor import (
    AudioFileClip,
    CompositeAudioClip,
    ImageClip,
    concatenate_audioclips,
    concatenate_videoclips,
)

from lib.audio_generation import AudioGeneration
from lib.image_captioning import ImageCaptioning
from lib.pace_model import PaceModel
pace_model_weights_path = (Path.cwd() / "models" / "pace_model_weights.h5").resolve()
resnet50_tf_model_weights_path = (Path.cwd() / "models" / "resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5").resolve()
height, width, channels = (224, 224, 3)
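# Both weight files are expected to be present under ./models before launch.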
class AudioPalette:
    def __init__(self):
        self.pace_model = PaceModel(height, width, channels, resnet50_tf_model_weights_path, pace_model_weights_path)
        self.image_captioning = ImageCaptioning()
        self.audio_generation = AudioGeneration()
        # Maps the pace model's class labels to the wording used in prompts
        self.pace_map = {
            "Fast": "high",
            "Medium": "medium",
            "Slow": "low"
        }
    def prompt_construction(self, caption: str, pace: str, instrument: typing.Optional[str], first: bool = True) -> str:
        instrument = instrument if instrument is not None else ""

        if first:
            prompt = f"A {instrument} soundtrack for {caption} with {self.pace_map[pace]} beats per minute. High Quality"
        else:
            prompt = f"A {instrument} soundtrack for {caption} with {self.pace_map[pace]} beats per minute. High Quality. Transitions smoothly from the previous audio while sounding different."

        return prompt
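    # Illustrative example:
    #   prompt_construction("a beach at sunset", "Fast", "Guitar")
    #   -> "A Guitar soundtrack for a beach at sunset with high beats per minute. High Quality"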
    def generate_single(self, input_image: PIL.Image.Image, instrument: typing.Optional[str], ngrok_endpoint: str):
        # Classify the image's pace and caption it, then build a text prompt
        # for the audio-generation backend.
        pace = self.pace_model.predict(input_image)
        print("Pace Prediction Done")

        generated_text = self.image_captioning.query(input_image)[0].get("generated_text")
        print("Captioning Done")
        generated_text = generated_text if generated_text is not None else ""

        prompt = self.prompt_construction(generated_text, pace, instrument)
        print("Generated Prompt:", prompt)

        audio_file = self.audio_generation.generate(prompt, ngrok_endpoint)
        print("Audio Generation Done")

        outputs = [prompt, pace, generated_text, audio_file]
        return outputs
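    # The list above matches the order of the output components declared in
    # single_image_interface: prompt, pace, caption, then the audio file.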
    def stitch_images(self, file_paths: typing.List[str], audio_paths: typing.List[str]) -> str:
        # Turn each image into a 5-second clip and lay the generated audio
        # over the concatenated slideshow.
        clips = [ImageClip(m).set_duration(5) for m in file_paths]
        audio_clips = [AudioFileClip(a) for a in audio_paths]
        concat_audio = concatenate_audioclips(audio_clips)
        new_audio = CompositeAudioClip([concat_audio])

        concat_clip = concatenate_videoclips(clips, method="compose")
        concat_clip = concat_clip.set_audio(new_audio)

        file_name = "generated_video.mp4"
        concat_clip.write_videofile(file_name, fps=24)
        return file_name
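    # Note: the audio is not trimmed to the slideshow's length, so this
    # assumes the generated audio roughly matches the video duration
    # (5 seconds per image).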
    def generate_multiple(self, file_paths: typing.List[str], instrument: typing.Optional[str], ngrok_endpoint: str) -> str:
        images = [Image.open(image_path) for image_path in file_paths]
        pace = []
        generated_text = []
        prompts = []

        # Extract the pace for every image
        for image in images:
            pace_prediction = self.pace_model.predict(image)
            pace.append(pace_prediction)
        print("Pace Prediction Done")

        # Generate a caption for every image
        for image in images:
            caption = self.image_captioning.query(image)[0].get("generated_text")
            generated_text.append(caption if caption is not None else "")
        print("Captioning Done")

        # Only the first prompt omits the transition instruction
        first = True
        for generated_caption, pace_pred in zip(generated_text, pace):
            prompts.append(self.prompt_construction(generated_caption, pace_pred, instrument, first))
            first = False
        print("Generated Prompts:", prompts)

        # The backend receives the full prompt list and returns a single audio
        # file covering all segments, which is laid over the slideshow.
        audio_file = self.audio_generation.generate(prompts, ngrok_endpoint)
        print("Audio Generation Done")

        video_file = self.stitch_images(file_paths, [audio_file])
        return video_file
def single_image_interface(model: AudioPalette) -> gr.Interface:
    demo = gr.Interface(
        fn=model.generate_single,
        inputs=[
            gr.Image(
                type="pil",
                label="Upload an image",
                show_label=True,
                container=True
            ),
            gr.Radio(
                choices=["Piano", "Drums", "Guitar", "Violin", "Flute"],
                label="Instrument",
                show_label=True,
                container=True
            ),
            gr.Textbox(
                lines=1,
                placeholder="ngrok endpoint",
                label="Colab endpoint",
                show_label=True,
                container=True,
                type="text",
                visible=True
            )
        ],
        outputs=[
            gr.Textbox(
                lines=1,
                placeholder="Prompt",
                label="Generated Prompt",
                show_label=True,
                container=True,
                type="text",
                visible=False
            ),
            gr.Textbox(
                lines=1,
                placeholder="Pace of the image",
                label="Pace",
                show_label=True,
                container=True,
                type="text",
                visible=False
            ),
            gr.Textbox(
                lines=1,
                placeholder="Caption for the image",
                label="Caption",
                show_label=True,
                container=True,
                type="text",
                visible=False
            ),
            gr.Audio(
                label="Generated Audio",
                show_label=True,
                container=True,
                visible=True,
                format="wav",
                autoplay=False,
                show_download_button=True,
            )
        ],
        cache_examples=False,
        live=False,
        description="Provide an image to generate an appropriate background soundtrack",
    )

    return demo
def multi_image_interface(model: AudioPalette) -> gr.Interface:
    demo = gr.Interface(
        fn=model.generate_multiple,
        inputs=[
            gr.File(
                file_count="multiple",
                file_types=["image"],
                type="filepath",
                label="Upload images",
                show_label=True,
                container=True,
                visible=True
            ),
            gr.Radio(
                choices=["Piano", "Drums", "Guitar", "Violin", "Flute"],
                label="Instrument",
                show_label=True,
                container=True
            ),
            gr.Textbox(
                lines=1,
                placeholder="ngrok endpoint",
                label="Colab endpoint",
                show_label=True,
                container=True,
                type="text",
                visible=True
            )
        ],
        outputs=[
            gr.Video(
                format="mp4",
                label="Generated Video",
                show_label=True,
                container=True,
                visible=True,
                autoplay=False,
            )
        ],
        cache_examples=False,
        live=False,
        description="Provide images to generate a slideshow with appropriate background music",
    )

    return demo
def main():
    model = AudioPalette()
    tab_1 = single_image_interface(model)
    tab_2 = multi_image_interface(model)

    demo = gr.TabbedInterface([tab_1, tab_2], ["Single Image", "Slide Show"], "Audio Palette")
    demo.queue().launch()


if __name__ == "__main__":
    main()