import os
import gc
import torch
import torch.nn as nn
import argparse
import gradio as gr
from transformers import AutoTokenizer, LlamaForCausalLM
from utils import SteamGenerationMixin

# Hugging Face access token passed to from_pretrained() for model/tokenizer download.
auth_token = os.getenv("Zimix")
# Never echo secret values into logs; only report whether they are configured.
print('^_^ auth_token set:', auth_token is not None)
print('^_^ secret_token set:', os.getenv("SECRET_TOKEN") is not None)


class MindBot(object):
    """Chat wrapper around a Ziya causal LM, plus a Gradio chat UI.

    ``common_generate`` is a blocking, API-style call that keeps its own
    ``self.history``; ``interaction`` is a streaming generator wired to a
    Gradio ``Chatbot`` component whose history is supplied by the UI.
    """

    # Sampling settings shared by the blocking and the streaming paths.
    # NOTE(review): pad/bos/eos ids 0/1/2 follow the LLaMA tokenizer
    # convention — confirm they match this checkpoint's tokenizer.
    _GEN_KWARGS = dict(
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.85,
        temperature=0.8,
        repetition_penalty=1.,
        eos_token_id=2,
        bos_token_id=1,
        pad_token_id=0,
    )

    def __init__(self, model_path, tokenizer_path=None, if_int8=False):
        """Load the model and tokenizer.

        Args:
            model_path: local path or hub id of the model weights.
            tokenizer_path: local path or hub id of the tokenizer; defaults
                to ``model_path`` when omitted.
            if_int8: if True load weights with ``load_in_8bit``, otherwise
                cast the model to fp16 with ``.half()``.
        """
        if tokenizer_path is None:
            tokenizer_path = model_path
        if if_int8:
            self.model = SteamGenerationMixin.from_pretrained(
                model_path, device_map='auto', load_in_8bit=True,
                use_auth_token=auth_token).eval()
        else:
            self.model = SteamGenerationMixin.from_pretrained(
                model_path, device_map='auto',
                use_auth_token=auth_token).half().eval()
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, use_auth_token=auth_token)
        # History used by common_generate(): list of (user_text, bot_text) pairs.
        self.history = []

    def build_prompt(self, instruction, history, human='<human>', bot='<bot>'):
        """Render a multi-turn prompt as ``<human>: ...\\n<bot>: ...`` lines.

        Args:
            instruction: the new user message (whitespace-stripped).
            history: iterable of (user_text, bot_text) pairs, or None.
            human: role tag prefixed to user turns.
            bot: role tag prefixed to bot turns.

        Returns:
            The prompt string, ending with an open bot turn.
        """
        pmt = ''
        for line in history or ():
            pmt += f'{human}: {line[0].strip()}\n{bot}: {line[1]}\n'
        pmt += f'{human}: {instruction.strip()}\n{bot}: \n'
        return pmt

    def _encode_prompt(self, prompt, max_memory):
        """Tokenize ``prompt`` and keep only its last ``max_memory`` tokens."""
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        # Left-truncate so the most recent turns survive when the prompt is too long.
        if input_ids.shape[1] > max_memory:
            input_ids = input_ids[:, -max_memory:]
        return input_ids

    def common_generate(self, instruction, clear_history=False, max_memory=1024):
        """Generate a complete (non-streaming) reply and record it in ``self.history``.

        Args:
            instruction: the user message.
            clear_history: reset ``self.history`` before building the prompt.
            max_memory: maximum number of prompt tokens kept (oldest dropped).

        Returns:
            The decoded reply text.
        """
        if clear_history:
            self.history = []
        prompt = self.build_prompt(instruction, self.history)
        input_ids = self._encode_prompt(prompt, max_memory)
        prompt_len = input_ids.shape[1]
        generation_output = self.model.generate(input_ids.cuda(), **self._GEN_KWARGS)
        # generate() returns prompt + continuation; keep only the new tokens.
        s = generation_output[0][prompt_len:]
        output = self.tokenizer.decode(s, skip_special_tokens=True)
        output = output.replace("Belle", "IDEA")
        self.history.append((instruction, output))
        print('api history: ======> \n', self.history)
        return output

    def interaction(self, instruction, history, max_memory=1024):
        """Stream a reply into the Gradio chatbot.

        Args:
            instruction: text from the input box.
            history: chatbot value, a list of (user_text, bot_text) pairs
                (may be None after the UI is cleared).
            max_memory: maximum number of prompt tokens kept (oldest dropped).

        Yields:
            ``('', chat)`` tuples: an empty string to clear the input box and
            the chat list including the partial reply. A final
            ``('', history)`` is always yielded, including after a CUDA OOM.
        """
        if history is None:
            history = []
        prompt = self.build_prompt(instruction, history)
        input_ids = self._encode_prompt(prompt, max_memory)
        prompt_len = input_ids.shape[1]
        try:
            tmp = history.copy()
            output = ''
            with torch.no_grad():
                for generation_output in self.model.stream_generate(
                        input_ids.cuda(), **self._GEN_KWARGS):
                    s = generation_output[0][prompt_len:]
                    output = self.tokenizer.decode(s, skip_special_tokens=True)
                    # Render newlines as HTML line breaks in the chatbot.
                    output = output.replace('\n', '<br>')
                    # Show the partial reply, then remove it so the next
                    # step appends a fresh, longer version.
                    tmp.append((instruction, output))
                    yield '', tmp
                    tmp.pop()
            history.append((instruction, output))
            print('input -----> \n', prompt)
            print('output -------> \n', output)
            print('history: ======> \n', history)
        except torch.cuda.OutOfMemoryError:
            gc.collect()
            torch.cuda.empty_cache()
            # NOTE(review): empty_cache is not a standard transformers model
            # method; call it only if SteamGenerationMixin provides one.
            model_empty_cache = getattr(self.model, 'empty_cache', None)
            if callable(model_empty_cache):
                model_empty_cache()
        # A generator's return value is discarded by Gradio, so yield the
        # final state explicitly (this also refreshes the UI after an OOM).
        yield '', history

    def new_chat_bot(self):
        """Build the Gradio Blocks chat UI and return it with queuing enabled."""
        with gr.Blocks(title='IDEA Ziya', css=".gradio-container {max-width: 50% !important;} .bgcolor {color: white !important; background: #FFA500 !important;}") as demo:
            gr.Markdown("<center><h1>IDEA Ziya</h1></center>")
            gr.Markdown("<center>本页面基于hugging face支持的设备搭建</center>")
            with gr.Row():
                chatbot = gr.Chatbot(label='Ziya').style(height=500)
            with gr.Row():
                msg = gr.Textbox(label="Input")
            with gr.Row():
                with gr.Column(scale=0.5):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.5):
                    submit = gr.Button("Submit", elem_classes='bgcolor')
            # Pressing Enter and clicking Submit run the same streaming handler.
            msg.submit(self.interaction, [msg, chatbot], [msg, chatbot])
            clear.click(lambda: None, None, chatbot, queue=False)
            submit.click(self.interaction, [msg, chatbot], [msg, chatbot])
        # Queuing is needed for the generator (streaming) handlers above.
        return demo.queue(concurrency_count=5)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="/cognitive_comp/songchao/checkpoints/global_step3200-hf"
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="tokenizer path or hub id; defaults to --model_path"
    )
    parser.add_argument(
        "--if_int8",
        action="store_true",
        help="load the model weights in 8-bit"
    )
    args = parser.parse_args()

    mind_bot = MindBot(args.model_path, args.tokenizer_path, if_int8=args.if_int8)
    demo = mind_bot.new_chat_bot()
    # NOTE(review): nothing here calls demo.launch(); confirm the hosting
    # environment launches `demo`.