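# NOTE: the lines below are file-viewer page residue captured together with
# the source; they are kept as comments so the module stays valid Python.
# MSP / app.py
# Ken Lin
# Update
# cf8e8a6
# raw
# history blame contribute delete
# No virus
# 1.89 kB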
import gradio as gr
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import numpy as np
import torch
from ram import get_transform, inference_tag2text
from ram.models import tag2text
from PIL import Image
# Text shown at the top of the Gradio page.
title = "Musicalization System of Painting Demo"
description = "Pui Ching Middle School: Musicalization System of Painting Demo"

# Image side length handed to both the preprocessing transform and the tagger.
image_size = 384
device = "cuda" if torch.cuda.is_available() else "cpu"

# This app only runs inference. A bare `torch.no_grad()` call merely builds a
# context manager and throws it away, leaving autograd enabled; switching
# gradient tracking off explicitly is what actually takes effect.
torch.set_grad_enabled(False)

# Preprocessing pipeline and Tag2Text model, built once at import time.
transform = get_transform(image_size=image_size)
tag2text_model = tag2text(
    pretrained="tag2text_swin_14m.pth",
    image_size=image_size,
    vit='swin_b',
).eval().to(device)
# Loaded MusicGen (processor, model) pairs, keyed by checkpoint name.
_musicgen_cache = {}


def _load_musicgen(checkpoint="facebook/musicgen-small"):
    """Return the (processor, model) pair for *checkpoint*, loading it once."""
    if checkpoint not in _musicgen_cache:
        processor = AutoProcessor.from_pretrained(checkpoint)
        model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)
        _musicgen_cache[checkpoint] = (processor, model.eval().to(device))
    return _musicgen_cache[checkpoint]


def _to_int16_pcm(waveform):
    """Convert a float waveform nominally in [-1, 1] to int16 samples."""
    max_range = np.iinfo(np.int16).max
    # Clip before scaling: a sample outside [-1, 1] would exceed the int16
    # range and the float -> int16 cast would overflow instead of saturating.
    return (np.clip(waveform, -1.0, 1.0) * max_range).astype(np.int16)


def generate_music(raw_image, audio_length):
    """Caption a painting with Tag2Text and turn the caption into music.

    Args:
        raw_image: Image array as delivered by the ``gr.Image`` input; it is
            passed to ``Image.fromarray``. ``None`` (nothing uploaded) is
            rejected with a ``gr.Error``.
        audio_length: Requested clip length in seconds; multiplied by the
            audio encoder's frame rate to get the number of tokens to generate.

    Returns:
        A ``(sampling_rate, samples)`` tuple where ``samples`` is an int16
        NumPy array taken from the first batch item and first channel of the
        generated audio.

    Raises:
        gr.Error: If no image was provided.
    """
    if raw_image is None:
        raise gr.Error("Please upload a painting first.")

    pil_image = Image.fromarray(raw_image)

    # Inference only: make sure no autograd graph is built, independent of any
    # global gradient setting.
    with torch.no_grad():
        image = transform(pil_image).unsqueeze(0).to(device)
        res = inference_tag2text(image, tag2text_model)
        # The third element of the Tag2Text result is used as the text prompt.
        caption = res[2]
        print(caption)

        # Reuse the cached MusicGen instead of reloading it on every request.
        processor, model = _load_musicgen()
        inputs = processor(
            text=[caption],
            padding=True,
            return_tensors="pt",
        ).to(device)

        sampling_rate = model.audio_encoder.config.sampling_rate
        frame_rate = model.audio_encoder.config.frame_rate
        max_new_tokens = int(frame_rate * audio_length)
        audio_values = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Bring the waveform back to host memory before handing it to NumPy.
    waveform = audio_values[0, 0].cpu().numpy()
    return sampling_rate, _to_int16_pcm(waveform)
# Input widgets, listed in the order generate_music takes its arguments.
painting_input = gr.Image(label="Painting")
length_input = gr.Slider(
    minimum=5,
    maximum=30,
    value=15,
    step=1,
    label="Audio length(sec)",
)

# Output widget that receives the (sampling_rate, samples) result.
music_output = gr.Audio(label='Generated Music')

# Wire the widgets to generate_music and start the web app.
iface = gr.Interface(
    fn=generate_music,
    inputs=[painting_input, length_input],
    outputs=music_output,
    title=title,
    description=description,
)
iface.launch()