Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import numpy as np | |
import gradio as gr | |
import PIL | |
from lib.audio_generation import AudioGeneration | |
from lib.image_captioning import ImageCaptioning | |
from lib.pace_model import PaceModel | |
# Weight files are looked up under "models/" relative to the process's current
# working directory (Path.cwd()), so the app must be started from the
# directory that contains "models/".
_models_dir = Path.cwd() / "models"

# Weights for the pace model; passed to PaceModel by AudioPalette.
pace_model_weights_path = (_models_dir / "pace_model_weights.h5").resolve()

# ResNet50 backbone weights, also passed to PaceModel. Judging by the file
# name this is the standard Keras "notop" (headless) download — confirm.
# Resolved like the path above so both paths are absolute and symlink-free.
resnet50_tf_model_weights_path = (
    _models_dir / "resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5"
).resolve()

# Input dimensions (height, width, channels) handed to PaceModel.
height, width, channels = (224, 224, 3)
class AudioPalette:
    """Turn an image into a music-generation prompt and generate audio from it.

    The pipeline predicts a "pace" for the image, captions the image, builds
    a text prompt from both, and hands that prompt to the audio-generation
    service at a caller-supplied endpoint.
    """

    def __init__(self):
        # All three components are built once per AudioPalette; the pace model
        # receives the weight-file paths and input size defined at module level.
        self.pace_model = PaceModel(height, width, channels, resnet50_tf_model_weights_path, pace_model_weights_path)
        self.image_captioning = ImageCaptioning()
        self.audio_generation = AudioGeneration()

    @staticmethod
    def _extract_caption(response) -> str:
        """Return the caption text from an image-captioning response.

        The expected shape is an indexable whose first item supports
        ``.get("generated_text")`` (e.g. a list of dicts). Any other shape —
        an empty list, a dict such as an error payload, ``None`` — yields an
        empty caption rather than an IndexError/KeyError/TypeError, matching
        the existing fallback when ``"generated_text"`` is absent.
        """
        try:
            caption = response[0].get("generated_text")
        except (IndexError, KeyError, TypeError, AttributeError):
            print(f"Unexpected captioning response, using empty caption: {response!r}")
            return ""
        return caption if caption is not None else ""

    def generate(self, input_image: "PIL.Image.Image", ngrok_endpoint: str):
        """Generate a soundtrack for ``input_image``.

        Args:
            input_image: Image from the ``gr.Image(type="pil")`` input; may be
                ``None`` if the user submits without uploading anything.
            ngrok_endpoint: Address of the audio-generation service, forwarded
                to ``AudioGeneration.generate``. Surrounding whitespace (e.g.
                from a pasted URL) is stripped.

        Returns:
            ``[prompt, pace, caption, audio_file]`` — the order in which
            ``main`` lists its output components.

        Raises:
            gr.Error: if no image was provided, so Gradio shows a readable
                message instead of a traceback from the pace model.
        """
        if input_image is None:
            raise gr.Error("Please upload an image.")
        endpoint = ngrok_endpoint.strip() if isinstance(ngrok_endpoint, str) else ngrok_endpoint

        pace = self.pace_model.predict(input_image)
        print("Pace Prediction Done")

        generated_text = self._extract_caption(self.image_captioning.query(input_image))
        print("Captioning Done")

        prompt = f"Generate a soundtrack for {generated_text} with {pace} beats and the instrument of choice is the guitar, High quality"
        audio_file = self.audio_generation.generate(prompt, endpoint)
        print("Audio Generation Done")

        return [prompt, pace, generated_text, audio_file]
def main():
    """Build the Audio Palette Gradio interface and launch it with queuing.

    Inputs are an uploaded PIL image and the endpoint text. Outputs follow
    the order ``AudioPalette.generate`` returns: the generated prompt, the
    pace and caption (both hidden), and the generated audio in WAV format.
    """
    palette = AudioPalette()

    def text_field(placeholder, label, visible=True):
        # Every text box in this UI shares the same single-line layout;
        # only the placeholder, label and visibility differ.
        return gr.Textbox(
            lines=1,
            placeholder=placeholder,
            label=label,
            show_label=True,
            container=True,
            type="text",
            visible=visible,
        )

    image_in = gr.Image(
        type="pil",
        label="Upload an image",
        show_label=True,
        container=True,
    )
    endpoint_in = text_field("ngrok endpoint", "colab endpoint")

    prompt_out = text_field("Prompt", "Generated Prompt")
    pace_out = text_field("Pace of the image", "Pace", visible=False)
    caption_out = text_field("Caption for the image", "Caption", visible=False)
    audio_out = gr.Audio(
        label="Generated Audio",
        show_label=True,
        container=True,
        visible=True,
        format="wav",
        autoplay=False,
        show_download_button=True,
    )

    interface = gr.Interface(
        fn=palette.generate,
        inputs=[image_in, endpoint_in],
        outputs=[prompt_out, pace_out, caption_out, audio_out],
        cache_examples=False,
        live=False,
        title="Audio Palette",
        description="Provide an image to generate an appropriate background soundtrack",
    )
    interface.queue().launch()


# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()