# audio_palette/app.py
import typing
from pathlib import Path
import numpy as np
import gradio as gr
import PIL
from PIL import Image
from moviepy.editor import (
    AudioFileClip,
    CompositeAudioClip,
    ImageClip,
    concatenate_audioclips,
    concatenate_videoclips,
)
from lib.audio_generation import AudioGeneration
from lib.image_captioning import ImageCaptioning
from lib.pace_model import PaceModel
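# Local weights for the pace model and its ResNet50 feature extractor,
# plus the input image shape the pace model expects.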
pace_model_weights_path = (Path.cwd() / "models" / "pace_model_weights.h5").resolve()
resnet50_tf_model_weights_path = (Path.cwd() / "models" / "resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5")
height, width, channels = (224, 224, 3)
class AudioPalette:
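    """Pipeline that pairs images with generated audio: predicts the pace of an
    image, captions it, builds a text prompt, and sends the prompt to an audio
    generation backend."""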
def __init__(self):
self.pace_model = PaceModel(height, width, channels, resnet50_tf_model_weights_path, pace_model_weights_path)
self.image_captioning = ImageCaptioning()
self.audio_generation = AudioGeneration()
self.pace_map = {
"Fast": "high",
"Medium": "medium",
"Slow": "low"
}
def prompt_construction(self, caption: str, pace: str, instrument: typing.Union[str, None], first: bool = True):
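        """Build the text-to-audio prompt from the caption, the mapped pace, and
        an optional instrument. Prompts after the first ask for a smooth
        transition from the previous clip."""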
instrument = instrument if instrument is not None else ""
if first:
prompt = f"A {instrument} soundtrack for {caption} with {self.pace_map[pace]} beats per minute. High Quality"
else:
            prompt = f"A {instrument} soundtrack for {caption} with {self.pace_map[pace]} beats per minute. High Quality. Transitions smoothly from the previous audio while sounding different."
return prompt
def generate_single(self, input_image: PIL.Image.Image, instrument: typing.Union[str, None], ngrok_endpoint: str):
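        """Generate a single audio clip for one image: predict pace, caption the
        image, construct a prompt, and request audio from the ngrok endpoint."""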
pace = self.pace_model.predict(input_image)
print("Pace Prediction Done")
generated_text = self.image_captioning.query(input_image)[0].get("generated_text")
print("Captioning Done")
generated_text = generated_text if generated_text is not None else ""
prompt = self.prompt_construction(generated_text, pace, instrument)
print("Generated Prompt:", prompt)
audio_file = self.audio_generation.generate(prompt, ngrok_endpoint)
print("Audio Generation Done")
outputs = [prompt, pace, generated_text, audio_file]
return outputs
def stitch_images(self, file_paths: typing.List[str], audio_paths: typing.List[str]):
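        """Stitch the images into a slideshow video (5 seconds per image) with
        the concatenated audio clips as its soundtrack."""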
clips = [ImageClip(m).set_duration(5) for m in file_paths]
audio_clips = [AudioFileClip(a) for a in audio_paths]
concat_audio = concatenate_audioclips(audio_clips)
new_audio = CompositeAudioClip([concat_audio])
concat_clip = concatenate_videoclips(clips, method="compose")
concat_clip.audio = new_audio
file_name = "generated_video.mp4"
concat_clip.write_videofile(file_name, fps=24)
return file_name
def generate_multiple(self, file_paths: typing.List[str], instrument: typing.Union[str, None], ngrok_endpoint: str):
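        """Generate audio for a sequence of images and return a slideshow video
        that uses the generated audio as background music."""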
images = [Image.open(image_path) for image_path in file_paths]
pace = []
generated_text = []
prompts = []
# Extracting the pace for all the images
for image in images:
pace_prediction = self.pace_model.predict(image)
pace.append(pace_prediction)
print("Pace Prediction Done")
# Generating the caption for all the images
for image in images:
caption = self.image_captioning.query(image)[0].get("generated_text")
generated_text.append(caption)
print("Captioning Done")
first = True
for generated_caption, pace_pred in zip(generated_text, pace):
prompts.append(self.prompt_construction(generated_caption, pace_pred, instrument, first))
first = False
print("Generated Prompts: ", prompts)
audio_file = self.audio_generation.generate(prompts, ngrok_endpoint)
print("Audio Generation Done")
video_file = self.stitch_images(file_paths, [audio_file])
return video_file
def single_image_interface(model: AudioPalette):
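    """Gradio interface for generating a soundtrack from a single image."""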
demo = gr.Interface(
fn=model.generate_single,
inputs=[
gr.Image(
type="pil",
label="Upload an image",
show_label=True,
container=True
),
gr.Radio(
choices=["Piano", "Drums", "Guitar", "Violin", "Flute"],
label="Instrument",
show_label=True,
container=True
),
gr.Textbox(
lines=1,
placeholder="ngrok endpoint",
label="colab endpoint",
show_label=True,
container=True,
type="text",
visible=True
)
],
outputs=[
gr.Textbox(
lines=1,
placeholder="Prompt",
label="Generated Prompt",
show_label=True,
container=True,
type="text",
visible=False
),
gr.Textbox(
lines=1,
placeholder="Pace of the image",
label="Pace",
show_label=True,
container=True,
type="text",
visible=False
),
gr.Textbox(
lines=1,
placeholder="Caption for the image",
label="Caption",
show_label=True,
container=True,
type="text",
visible=False
),
gr.Audio(
label="Generated Audio",
show_label=True,
container=True,
visible=True,
format="wav",
autoplay=False,
show_download_button=True,
)
],
cache_examples=False,
live=False,
description="Provide an image to generate an appropriate background soundtrack",
)
return demo
def multi_image_interface(model: AudioPalette):
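    """Gradio interface for generating a slideshow video from multiple images."""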
demo = gr.Interface(
fn=model.generate_multiple,
inputs=[
gr.File(
file_count="multiple",
file_types=["image"],
type="filepath",
label="Upload images",
show_label=True,
container=True,
visible=True
),
gr.Radio(
choices=["Piano", "Drums", "Guitar", "Violin", "Flute"],
label="Instrument",
show_label=True,
container=True
),
gr.Textbox(
lines=1,
placeholder="ngrok endpoint",
label="colab endpoint",
show_label=True,
container=True,
type="text",
visible=True
)
],
outputs=[
gr.Video(
format="mp4",
label="Generated Video",
show_label=True,
container=True,
visible=True,
autoplay=False,
)
],
cache_examples=False,
live=False,
description="Provide images to generate an a slideshow of the images with appropriate music as background",
)
return demo
def main():
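    """Wire both interfaces into a tabbed Gradio app and launch it."""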
model = AudioPalette()
tab_1 = single_image_interface(model)
tab_2 = multi_image_interface(model)
demo = gr.TabbedInterface([tab_1, tab_2], ["Single Image", "Slide Show"], "Audio Palette")
demo.queue().launch()
if __name__ == "__main__":
main()