audio_palette / app.py
manasch's picture
add ngrok endpoint as input
8865845 verified
raw
history blame
No virus
3.35 kB
from pathlib import Path

import numpy as np
import gradio as gr
import PIL
import PIL.Image

from lib.audio_generation import AudioGeneration
from lib.image_captioning import ImageCaptioning
from lib.pace_model import PaceModel
# Model weights live in a "models" directory under the process's working
# directory at import time. Both weight paths are resolved identically so
# they have the same (symlink-resolved) form when handed to PaceModel.
_MODELS_DIR = Path.cwd() / "models"
pace_model_weights_path = (_MODELS_DIR / "pace_model_weights.h5").resolve()
resnet50_tf_model_weights_path = (_MODELS_DIR / "resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5").resolve()
# Input dimensions passed to PaceModel as (height, width, channels).
height, width, channels = (224, 224, 3)
class AudioPalette:
    """Generate a soundtrack for an image.

    The pipeline predicts the image's pace, captions the image, folds both
    into a text prompt, and hands that prompt to the audio generator together
    with a caller-supplied endpoint.
    """

    def __init__(self):
        # The weight paths and the (height, width, channels) input shape come
        # from the module-level constants of this file.
        self.pace_model = PaceModel(height, width, channels, resnet50_tf_model_weights_path, pace_model_weights_path)
        self.image_captioning = ImageCaptioning()
        self.audio_generation = AudioGeneration()

    @staticmethod
    def _caption_from(captioning_result) -> str:
        """Extract the caption text from a captioning response.

        The expected response is an indexable whose first element has a
        ``get`` method (e.g. ``[{"generated_text": "..."}]``).

        Returns:
            The ``"generated_text"`` value. Returns ``""`` when that value is
            missing or ``None``, or when the response does not have the
            expected shape (empty list, error dict, ``None``, ...). In the
            last case a warning is printed instead of raising.
        """
        try:
            text = captioning_result[0].get("generated_text")
        except (IndexError, KeyError, TypeError, AttributeError):
            # e.g. [] -> IndexError, {"error": ...} -> KeyError,
            # None -> TypeError, ["text"] -> AttributeError.
            print(f"Unexpected captioning response, using empty caption: {captioning_result!r}")
            return ""
        return text if text is not None else ""

    def generate(self, input_image: "PIL.Image.Image", ngrok_endpoint: str):
        """Run the full image-to-audio pipeline for one request.

        Args:
            input_image: Image to analyse. The annotation is quoted so the
                class body does not require ``PIL.Image`` to already be loaded
                as a submodule when the class is defined.
            ngrok_endpoint: Endpoint string passed unchanged to
                ``self.audio_generation.generate``.

        Returns:
            ``[prompt, pace, caption, audio_file]``. This order matches the
            four output components of the Gradio interface in ``main``.
        """
        pace = self.pace_model.predict(input_image)
        print("Pace Prediction Done")

        generated_text = self._caption_from(self.image_captioning.query(input_image))
        print("Captioning Done")

        prompt = f"Generate a soundtrack for {generated_text} with {pace} beats and the instrument of choice is the guitar, High quality"
        audio_file = self.audio_generation.generate(prompt, ngrok_endpoint)
        print("Audio Generation Done")

        return [prompt, pace, generated_text, audio_file]
def main():
    """Build the Audio Palette Gradio interface and launch it with queuing.

    Inputs are an uploaded image (as a PIL image) and an endpoint string
    labelled "colab endpoint". Outputs line up with the list returned by
    ``AudioPalette.generate``: the prompt, the pace and the caption (the last
    two hidden), and the generated audio in WAV format.
    """
    palette = AudioPalette()

    def single_line_text(placeholder, label, visible=True):
        # Every text component in this UI shares the same single-line setup;
        # only the placeholder, label and visibility differ.
        return gr.Textbox(
            lines=1,
            placeholder=placeholder,
            label=label,
            show_label=True,
            container=True,
            type="text",
            visible=visible,
        )

    # Components are created in the same order as they appear in the UI.
    image_input = gr.Image(
        type="pil",
        label="Upload an image",
        show_label=True,
        container=True,
    )
    endpoint_input = single_line_text("ngrok endpoint", "colab endpoint")

    prompt_output = single_line_text("Prompt", "Generated Prompt")
    pace_output = single_line_text("Pace of the image", "Pace", visible=False)
    caption_output = single_line_text("Caption for the image", "Caption", visible=False)
    audio_output = gr.Audio(
        label="Generated Audio",
        show_label=True,
        container=True,
        visible=True,
        format="wav",
        autoplay=False,
        show_download_button=True,
    )

    interface = gr.Interface(
        fn=palette.generate,
        inputs=[image_input, endpoint_input],
        outputs=[prompt_output, pace_output, caption_output, audio_output],
        cache_examples=False,
        live=False,
        title="Audio Palette",
        description="Provide an image to generate an appropriate background soundtrack",
    )
    interface.queue().launch()


if __name__ == "__main__":
    main()