import os
import torch
import gradio as gr
import mdtex2html
import tempfile
from PIL import Image
import scipy
from llama.m2ugen import M2UGen
import llama
import numpy as np
import os
import torch
import torchaudio
import torchvision.transforms as transforms
import av
import subprocess
import librosa
import uuid

# Demo configuration: checkpoint/LLaMA locations plus the encoder and music
# decoder identifiers that are handed to M2UGen below.
args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
        "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224",
        "vivit_path": "google/vivit-b-16x2-kinetics400", "music_decoder": "musicgen",
        "music_decoder_path": "facebook/musicgen-medium"}


class dotdict(dict):
    """dot.notation access to dictionary attributes

    Attribute reads are routed through ``dict.get``, so a missing key yields
    ``None`` instead of raising ``AttributeError``. Attribute writes and
    deletes map to item assignment and item deletion.
    """
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


# Re-bind the plain dict as a dotdict so the settings can be read as attributes.
args = dotdict(args)

generated_audio_files = {}

llama_type = args.llama_type
llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
llama_tokenzier_path = args.llama_dir
model = M2UGen(llama_ckpt_dir, llama_tokenzier_path, args, knn=False, stage=None, load_llama=False)

print("Loading Model Checkpoint")
# NOTE(review): torch.load unpickles the file; only point args.model at a
# checkpoint from a trusted source.
checkpoint = torch.load(args.model, map_location='cpu')

# Drop every entry whose name mentions "generation_model" and strip the
# "module." fragment from the remaining parameter names before loading.
new_ckpt = {}
for key, value in checkpoint['model'].items():
    if "generation_model" in key:
        continue
    key = key.replace("module.", "")
    new_ckpt[key] = value

load_result = model.load_state_dict(new_ckpt, strict=False)
# Explicit raise instead of `assert`: asserts are removed under `python -O`,
# which would let a mismatched checkpoint load without any complaint.
if load_result.unexpected_keys:
    raise RuntimeError(f"Unexpected keys: {load_result.unexpected_keys}")
model.eval()

# Convert the input to a tensor; when its first dimension has size 1, repeat
# it three times along that dimension, otherwise leave it unchanged.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x)])


def postprocess(self, y):
    """Run every (message, response) pair of ``y`` through ``mdtex2html.convert``.

    Installed as ``gr.Chatbot.postprocess`` right after this definition.

    Args:
        y: List of ``(message, response)`` pairs, or ``None``.

    Returns:
        An empty list when ``y`` is ``None``; otherwise ``y`` itself, with
        each pair replaced in place by its converted form. ``None`` members
        of a pair stay ``None``.
    """
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text, image_path, video_path, audio_path):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    outputs = text
lines = text.split("\n") lines = [line for line in lines if line != ""] count = 0 for i, line in enumerate(lines): if "```" in line: count += 1 items = line.split('`') if count % 2 == 1: lines[i] = f'
'
            else:
                lines[i] = f'
' else: if i > 0: if count % 2 == 1: line = line.replace("`", "\`") line = line.replace("<", "<") line = line.replace(">", ">") line = line.replace(" ", " ") line = line.replace("*", "*") line = line.replace("_", "_") line = line.replace("-", "-") line = line.replace(".", ".") line = line.replace("!", "!") line = line.replace("(", "(") line = line.replace(")", ")") line = line.replace("$", "$") lines[i] = "
" + line text = "".join(lines) + "
" if image_path is not None: text += f'
' outputs = f'{image_path} ' + outputs if video_path is not None: text += f'