"""Gradio app that turns an uploaded image into a generated background soundtrack.

Pipeline per request: a pace model predicts the image's pace, an image-captioning
client describes the image, both are folded into a text prompt, and the prompt is
sent to a remote audio-generation endpoint (entered by the user as a Colab/ngrok URL).
"""
from pathlib import Path

import numpy as np  # noqa: F401
import gradio as gr
import PIL.Image  # binds `PIL` and guarantees the `PIL.Image` submodule is loaded

from lib.audio_generation import AudioGeneration
from lib.image_captioning import ImageCaptioning
from lib.pace_model import PaceModel

# Weight files are looked up under ./models relative to the current working
# directory, so the app must be launched from the directory containing `models/`.
pace_model_weights_path = (Path.cwd() / "models" / "pace_model_weights.h5").resolve()
resnet50_tf_model_weights_path = (
    Path.cwd() / "models" / "resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5"
).resolve()

# Image dimensions passed to PaceModel as (height, width, channels).
height, width, channels = (224, 224, 3)

# Text prompt sent to the audio-generation service.
PROMPT_TEMPLATE = (
    "Generate a soundtrack for {caption} with {pace} beats and the instrument "
    "of choice is the guitar, High quality"
)


def _extract_caption(response) -> str:
    """Return the caption from an image-captioning response ("" if it has none).

    The response is presumably a Hugging Face Inference API style payload,
    i.e. ``[{"generated_text": "..."}]`` on success and a dict on failure --
    confirm against ``ImageCaptioning.query``.

    Raises:
        gr.Error: if the response is a dict (error payload), empty, or not a
            list/tuple whose first element is a dict.
    """
    if isinstance(response, dict):
        raise gr.Error(f"Image captioning failed: {response.get('error', response)}")
    if not isinstance(response, (list, tuple)) or not response or not isinstance(response[0], dict):
        raise gr.Error("Image captioning returned no usable result")
    caption = response[0].get("generated_text")
    # A missing caption is tolerated: the prompt is still built, just without a description.
    return caption if caption is not None else ""


class AudioPalette:
    """Image-to-soundtrack pipeline backing the Gradio interface."""

    def __init__(self):
        # Models/clients are built once here and reused for every request.
        self.pace_model = PaceModel(
            height, width, channels, resnet50_tf_model_weights_path, pace_model_weights_path
        )
        self.image_captioning = ImageCaptioning()
        self.audio_generation = AudioGeneration()

    def generate(self, input_image: PIL.Image.Image, ngrok_endpoint: str) -> list:
        """Run the image -> pace/caption -> prompt -> audio pipeline for one request.

        Args:
            input_image: Image from the ``gr.Image(type="pil")`` input; ``None``
                when the user submits without uploading an image.
            ngrok_endpoint: Address of the remote audio-generation service, as typed
                into the "colab endpoint" textbox; surrounding whitespace is stripped.

        Returns:
            ``[prompt, pace, caption, audio_file]``, in the same order as the four
            output components declared in ``main``.

        Raises:
            gr.Error: if no image was uploaded or captioning returned an unusable
                response; Gradio displays the message to the user.
        """
        if input_image is None:
            raise gr.Error("Please upload an image")
        # Pasted URLs often carry stray spaces/newlines.
        if isinstance(ngrok_endpoint, str):
            ngrok_endpoint = ngrok_endpoint.strip()

        pace = self.pace_model.predict(input_image)
        print("Pace Prediction Done")

        generated_text = _extract_caption(self.image_captioning.query(input_image))
        print("Captioning Done")

        prompt = PROMPT_TEMPLATE.format(caption=generated_text, pace=pace)
        audio_file = self.audio_generation.generate(prompt, ngrok_endpoint)
        print("Audio Generation Done")

        return [prompt, pace, generated_text, audio_file]


def main():
    """Build the Gradio interface around ``AudioPalette.generate`` and launch it."""
    model = AudioPalette()
    demo = gr.Interface(
        fn=model.generate,
        inputs=[
            gr.Image(
                type="pil",
                label="Upload an image",
                show_label=True,
                container=True,
            ),
            gr.Textbox(
                lines=1,
                placeholder="ngrok endpoint",
                label="colab endpoint",
                show_label=True,
                container=True,
                type="text",
                visible=True,
            ),
        ],
        # One component per element of generate()'s return list; pace and caption
        # are produced but hidden from the UI.
        outputs=[
            gr.Textbox(
                lines=1,
                placeholder="Prompt",
                label="Generated Prompt",
                show_label=True,
                container=True,
                type="text",
                visible=True,
            ),
            gr.Textbox(
                lines=1,
                placeholder="Pace of the image",
                label="Pace",
                show_label=True,
                container=True,
                type="text",
                visible=False,
            ),
            gr.Textbox(
                lines=1,
                placeholder="Caption for the image",
                label="Caption",
                show_label=True,
                container=True,
                type="text",
                visible=False,
            ),
            gr.Audio(
                label="Generated Audio",
                show_label=True,
                container=True,
                visible=True,
                format="wav",
                autoplay=False,
                show_download_button=True,
            ),
        ],
        cache_examples=False,
        live=False,
        title="Audio Palette",
        description="Provide an image to generate an appropriate background soundtrack",
    )
    demo.queue().launch()


if __name__ == "__main__":
    main()