import spaces
import logging
from datetime import datetime
from pathlib import Path

import gradio as gr
import requests
import torch
import torchaudio
import os
import tempfile
from transformers import pipeline
# Basic setup: install mmaudio in editable mode if it is not already importable
try:
    import mmaudio
except ImportError:
    os.system("pip install -e .")
    import mmaudio
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
                                setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils
# CUDA settings: enable TF32 for faster matmuls on supported GPUs
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Logging setup
log = logging.getLogger()
# Device and data type settings
device = 'cuda'
dtype = torch.bfloat16
# ๋ชจ๋ธ ์„ค์ •
model: ModelConfig = all_model_cfg['large_44k_v2']
model.download_if_needed()
output_dir = Path('./output/gradio')
setup_eval_logging()
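# Note: output_dir is only referenced again in launch(allowed_paths=...) at the
# bottom of this file; the directory is never created here, and the generated
# files below are actually written to tempfile paths rather than this directory.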
# Translator (Korean -> English) and Pixabay API setup
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
def search_pixabay_videos(query, api_key):
    """Query the Pixabay video API and return URLs of large-format results."""
    base_url = "https://pixabay.com/api/videos/"
    params = {
        "key": api_key,
        "q": query,
        "per_page": 80
    }
    response = requests.get(base_url, params=params)
    if response.status_code == 200:
        data = response.json()
        return [video['videos']['large']['url'] for video in data.get('hits', [])]
    return []
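# A note on the response shape assumed above: the Pixabay /api/videos/ endpoint
# returns JSON of the form {"hits": [{"videos": {"large": {"url": ...}, ...}}, ...]},
# which is why each hit is indexed as hit['videos']['large']['url']. Any non-200
# status (bad key, rate limit) silently yields an empty gallery.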
# ๋ฉ”์ธ ์ฝ”๋“œ์—์„œ pixabay ๊ด€๋ จ ๋ถ€๋ถ„ ์ˆ˜์ •
PIXABAY_API_KEY = "33492762-a28a596ec4f286f84cd328b17"
def search_videos(query):
query = translate_prompt(query)
return search_pixabay_videos(query, PIXABAY_API_KEY)
# CSS style definitions
custom_css = """
.gradio-container {
    background: linear-gradient(45deg, #1a1a1a, #2a2a2a);
    border-radius: 15px;
    box-shadow: 0 8px 32px rgba(0,0,0,0.3);
}
.input-container, .output-container {
    background: rgba(255,255,255,0.1);
    backdrop-filter: blur(10px);
    border-radius: 10px;
    padding: 20px;
    transform-style: preserve-3d;
    transition: transform 0.3s ease;
}
.input-container:hover {
    transform: translateZ(20px);
}
.gallery-item {
    transition: transform 0.3s ease;
    border-radius: 8px;
    overflow: hidden;
}
.gallery-item:hover {
    transform: scale(1.05);
    box-shadow: 0 4px 15px rgba(0,0,0,0.2);
}
.tabs {
    background: rgba(255,255,255,0.05);
    border-radius: 10px;
    padding: 10px;
}
button {
    background: linear-gradient(45deg, #4a90e2, #357abd);
    border: none;
    border-radius: 5px;
    transition: all 0.3s ease;
}
button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 15px rgba(74,144,226,0.3);
}
"""
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    seq_cfg = model.seq_cfg

    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
    log.info(f'Loaded weights from {model.model_path}')

    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=False)
    feature_utils = feature_utils.to(device, dtype).eval()

    return net, feature_utils, seq_cfg
net, feature_utils, seq_cfg = get_model()
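# The network and feature extractor are built once at import time and shared by
# every request. Since device is 'cuda' but this script runs under ZeroGPU
# (@spaces.GPU), the working assumption here is that the spaces runtime defers
# or proxies the CUDA placement until a GPU is actually attached to a call.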
def translate_prompt(text):
    # Translate only when the text contains Hangul characters (U+3131-U+D7A3)
    if text and any(0x3131 <= ord(char) <= 0xD7A3 for char in text):
        translation = translator(text)[0]['translation_text']
        return translation
    return text
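# Usage sketch (illustrative outputs, not guaranteed verbatim from the model):
#   translate_prompt("파도 소리")   -> e.g. "The sound of waves"
#   translate_prompt("ocean waves") -> "ocean waves" (no Hangul, returned as-is)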
@spaces.GPU
@torch.inference_mode()
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
                   cfg_strength: float, duration: float):
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    clip_frames, sync_frames, duration = load_video(video, duration)
    clip_frames = clip_frames.unsqueeze(0)
    sync_frames = sync_frames.unsqueeze(0)
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
    make_video(video,
               video_save_path,
               audio,
               sampling_rate=seq_cfg.sampling_rate,
               duration_sec=seq_cfg.duration)
    return video_save_path
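# The .mp4 is created with delete=False so Gradio can still serve it after this
# function returns; nothing in this script removes these files later, so a
# long-running deployment would presumably need its own temp-file cleanup policy.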
@spaces.GPU
@torch.inference_mode()
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
                  duration: float):
    prompt = translate_prompt(prompt)
    negative_prompt = translate_prompt(negative_prompt)

    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    # No visual conditioning: both frame streams stay None for pure text-to-audio
    clip_frames = sync_frames = None
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
    return audio_save_path
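# torchaudio.save expects a 2-D (channels, frames) tensor; the assumption here is
# that `generate` returns batched audio as (batch, channels, frames), so indexing
# [0] leaves a tensor in the layout torchaudio requires.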
# Interface definitions
video_search_tab = gr.Interface(
    fn=search_videos,
    inputs=gr.Textbox(label="Search query"),
    outputs=gr.Gallery(label="Search results", columns=4, rows=20),
    css=custom_css
)

video_to_audio_tab = gr.Interface(
    fn=video_to_audio,
    inputs=[
        gr.Video(label="Video"),
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Negative prompt", value="music"),
        gr.Number(label="Seed", value=0),
        gr.Number(label="Number of steps", value=25),
        gr.Number(label="Guidance strength", value=4.5),
        gr.Number(label="Duration (seconds)", value=8),
    ],
    outputs="playable_video",
    css=custom_css
)
text_to_audio_tab = gr.Interface(
    fn=text_to_audio,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Negative prompt"),
        gr.Number(label="Seed", value=0),
        gr.Number(label="Number of steps", value=25),
        gr.Number(label="Guidance strength", value=4.5),
        gr.Number(label="Duration (seconds)", value=8),
    ],
    outputs="audio",
    css=custom_css
)
# ๋ฉ”์ธ ์‹คํ–‰
if __name__ == "__main__":
gr.TabbedInterface(
[video_search_tab, video_to_audio_tab, text_to_audio_tab],
["๋น„๋””์˜ค ๊ฒ€์ƒ‰", "๋น„๋””์˜ค-์˜ค๋””์˜ค ๋ณ€ํ™˜", "ํ…์ŠคํŠธ-์˜ค๋””์˜ค ๋ณ€ํ™˜"],
css=custom_css
).launch(allowed_paths=[output_dir])