from transformers import AutoModel, AutoTokenizer from copy import deepcopy import os import ipdb import gradio as gr import mdtex2html from model.anyToImageVideoAudio import NextGPTModel import torch import json import tempfile from PIL import Image import scipy from config import * import imageio import argparse import re # init the model parser = argparse.ArgumentParser(description='train parameters') parser.add_argument('--model', type=str, default='nextgpt') parser.add_argument('--nextgpt_ckpt_path', type=str) # the delta parameters trained in each stages parser.add_argument('--stage', type=int, default=3) args = parser.parse_args() args = vars(args) args.update(load_config(args)) model = NextGPTModel(**args) delta_ckpt = torch.load(os.path.join(args['nextgpt_ckpt_path'], f'pytorch_model.pt'), map_location=torch.device('cpu')) model.load_state_dict(delta_ckpt, strict=False) model = model.eval().half().cuda() print(f'[!] init the 7b model over ...') g_cuda = torch.Generator(device='cuda').manual_seed(13) filter_value = -float('Inf') min_word_tokens = 10 gen_scale_factor = 4.0 stops_id = [[835]] ENCOUNTERS = 1 load_sd = True generator = g_cuda max_num_imgs = 1 max_num_vids = 1 height = 320 width = 576 max_num_auds = 1 max_length = 246 """Override Chatbot.postprocess""" def postprocess(self, y): if y is None: return [] for i, (message, response) in enumerate(y): y[i] = ( None if message is None else mdtex2html.convert((message)), None if response is None else mdtex2html.convert(response), ) return y gr.Chatbot.postprocess = postprocess def parse_text(text, image_path, video_path, audio_path): """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" outputs = text lines = text.split("\n") lines = [line for line in lines if line != ""] count = 0 for i, line in enumerate(lines): if "```" in line: count += 1 items = line.split('`') if count % 2 == 1: lines[i] = f'
'
            else:
                lines[i] = f'
' else: if i > 0: if count % 2 == 1: line = line.replace("`", "\`") line = line.replace("<", "<") line = line.replace(">", ">") line = line.replace(" ", " ") line = line.replace("*", "*") line = line.replace("_", "_") line = line.replace("-", "-") line = line.replace(".", ".") line = line.replace("!", "!") line = line.replace("(", "(") line = line.replace(")", ")") line = line.replace("$", "$") lines[i] = "
" + line text = "".join(lines) + "
" res_text = '' split_text = re.split(r' <|> ', text) image_path_list, video_path_list, audio_path_list = [], [], [] for st in split_text: if st.startswith(''): pattern = r'Image>(.*?)<\/Image' matches = re.findall(pattern, text) for m in matches: image_path_list.append(m) elif st.startswith('