import torch.cuda
import gradio as gr
import mdtex2html
import tempfile
from PIL import Image
import scipy
import argparse
from llama.m2ugen import M2UGen
import llama
import numpy as np
import os
import torch
import torchaudio
import torchvision.transforms as transforms
import av
import subprocess
import librosa
# Command-line options for the M2UGen demo. Each entry pairs the flag with
# its add_argument keyword arguments; all flags take string values.
_ARGUMENT_SPECS = [
    ("--model", {"default": "./ckpts/checkpoint.pth", "type": str,
                 "help": "Name of or path to M2UGen pretrained checkpoint"}),
    ("--llama_type", {"default": "7B", "type": str,
                      "help": "Type of llama original weight"}),
    ("--llama_dir", {"default": "/path/to/llama", "type": str,
                     "help": "Path to LLaMA pretrained checkpoint"}),
    ("--mert_path", {"default": "m-a-p/MERT-v1-330M", "type": str,
                     "help": "Path to MERT pretrained checkpoint"}),
    ("--vit_path", {"default": "m-a-p/MERT-v1-330M", "type": str,
                    "help": "Path to ViT pretrained checkpoint"}),
    ("--vivit_path", {"default": "m-a-p/MERT-v1-330M", "type": str,
                      "help": "Path to ViViT pretrained checkpoint"}),
    ("--knn_dir", {"default": "./ckpts", "type": str,
                   "help": "Path to directory with KNN Index"}),
    ("--music_decoder", {"default": "musicgen", "type": str,
                         "help": "Decoder to use musicgen/audioldm2"}),
    ("--music_decoder_path", {"default": "facebook/musicgen-medium", "type": str,
                              "help": "Path to decoder to use musicgen/audioldm2"}),
]

parser = argparse.ArgumentParser()
for _flag, _spec in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_spec)

args = parser.parse_args()
# Audio files produced during the chat session (consumed by later handlers).
generated_audio_files = []

llama_type = args.llama_type
llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
llama_tokenizer_path = args.llama_dir
model = M2UGen(llama_ckpt_dir, llama_tokenizer_path, args, knn=False, stage=None, load_llama=False)

print("Loading Model Checkpoint")
checkpoint = torch.load(args.model, map_location='cpu')

# Strip the DDP "module." prefix from every key and drop generation-model
# weights (the music decoder is loaded separately from --music_decoder_path).
new_ckpt = {
    key.replace("module.", ""): value
    for key, value in checkpoint['model'].items()
    if "generation_model" not in key
}

load_result = model.load_state_dict(new_ckpt, strict=False)
# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# which would silently allow a mismatched checkpoint to load.
if load_result.unexpected_keys:
    raise RuntimeError(f"Unexpected keys: {load_result.unexpected_keys}")

model.eval()
# Fall back to CPU when CUDA is unavailable instead of crashing on .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Convert PIL images to tensors; single-channel (grayscale) images are
# expanded to 3 channels so downstream vision encoders always see RGB input.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x)])
def postprocess(self, y):
    """Render each chat message/response pair from markdown to HTML.

    Mutates *y* in place, replacing each element with a tuple of converted
    strings (``None`` entries are passed through unchanged), and returns it.
    An empty list is returned when *y* is ``None``.
    """
    if y is None:
        return []
    for idx, (message, response) in enumerate(y):
        rendered_message = mdtex2html.convert(message) if message is not None else None
        rendered_response = mdtex2html.convert(response) if response is not None else None
        y[idx] = (rendered_message, rendered_response)
    return y


# Monkeypatch Gradio's Chatbot so history entries are rendered through
# mdtex2html (markdown + TeX) instead of the default postprocessing.
gr.Chatbot.postprocess = postprocess
def parse_text(text, image_path, video_path, audio_path):
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
outputs = text
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split('`')
if count % 2 == 1:
lines[i] = f'
'
else:
lines[i] = f'
'
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", "\`")
line = line.replace("<", "<")
line = line.replace(">", ">")
line = line.replace(" ", " ")
line = line.replace("*", "*")
line = line.replace("_", "_")
line = line.replace("-", "-")
line = line.replace(".", ".")
line = line.replace("!", "!")
line = line.replace("(", "(")
line = line.replace(")", ")")
line = line.replace("$", "$")
lines[i] = " " + line
text = "".join(lines) + " "
if image_path is not None:
text += f' '
outputs = f'{image_path} ' + outputs
if video_path is not None:
text += f'