| """A simple web interactive chat demo based on gradio.""" | |
| from argparse import ArgumentParser | |
| from threading import Thread | |
| import gradio as gr | |
| import torch | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| StoppingCriteria, | |
| StoppingCriteriaList, | |
| TextIteratorStreamer, | |
| ) | |


class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last generated token is a special token."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Special-token ids: "<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>".
        # A flat list is used here; wrapping it in another sequence would compare a
        # token against a whole list and fail at runtime.
        stop_ids = [2, 6, 7, 8]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
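
# Note: the hard-coded ids above were inherited from the Yi demo this script is
# based on and are tokenizer-specific. A more portable sketch (assuming the
# loaded tokenizer actually defines these special tokens) would be:
#   stop_ids = tokenizer.convert_tokens_to_ids(
#       ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"]
#   )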


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation when the last token decodes to one of the given stop words."""

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # Keep the stop-word ids where they are; decoding below works on CPU
        # tensors, and moving them to "cuda" unconditionally would break
        # --cpu-only mode.
        self.stops = list(stops) if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
        last_token = input_ids[0][-1]
        for stop in self.stops:
            if tokenizer.decode(stop) == tokenizer.decode(last_token):
                return True
        return False
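
# Caveat: each entry in `stops` is decoded in full and compared against the single
# most recent token, so a stop word that the tokenizer splits into several tokens
# can never match. Stop words that map to a single id (e.g. "</s>" in Mistral-style
# tokenizers, as used in predict() below) work as intended.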


def parse_text(text):
    """Escape model output into HTML for display inside the Gradio chatbot."""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                # Opening fence: start a highlighted code block.
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                # Closing fence.
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    # Inside a code block: escape characters as HTML entities.
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text
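
# Illustrative example of the transformation above:
#   parse_text("```python\nprint(1)\n```")
#   -> '<pre><code class="language-python"><br>print&#40;1&#41;<br></code></pre>'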


def predict(history, max_length, top_p, temperature):
    stop = StopOnTokens()  # unused; stopping is handled by StoppingCriteriaSub below
    # messages = [{"role": "system", "content": "You are a helpful assistant"}]
    messages = [{"role": "system", "content": ""}]
    # messages = []
    for idx, (user_msg, model_msg) in enumerate(history):
        if idx == len(history) - 1 and not model_msg:
            messages.append({"role": "user", "content": user_msg})
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    print("\n\n====conversation====\n", messages)
    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60, skip_prompt=True, skip_special_tokens=True
    )
    # stop_words = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|im_sep|>"]
    stop_words = ["</s>"]
    stop_words_ids = [
        tokenizer(stop_word, return_tensors="pt", add_special_tokens=False)[
            "input_ids"
        ].squeeze()
        for stop_word in stop_words
    ]
    stopping_criteria = StoppingCriteriaList(
        [StoppingCriteriaSub(stops=stop_words_ids)]
    )
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": stopping_criteria,
        "repetition_penalty": 1.1,
    }
    # Run generation on a background thread; stream tokens back to the UI.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        if new_token != "":
            history[-1][1] += new_token
            yield history
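
# The streamer is consumed here while model.generate runs on the worker thread,
# so partial replies reach the browser as they are produced. If the model emits
# nothing for longer than the 60 s streamer timeout, iterating the streamer
# raises queue.Empty and the stream ends with an error.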


def main(args):
    with gr.Blocks() as demo:
        # gr.Markdown(
        #     """\
        # <p align="center"><img src="https://raw.githubusercontent.com/01-ai/Yi/main/assets/img/Yi_logo_icon_light.svg" style="height: 80px"/><p>"""
        # )
        # gr.Markdown("""<center><font size=8>Yi-Chat Bot</center>""")
        gr.Markdown("""<center><font size=8>🦣MAmmoTH2</center>""")
        # gr.Markdown(
        #     """\
        # <center><font size=3>This WebUI is based on Yi-Chat, developed by 01-AI.</center>"""
        # )
        gr.Markdown(
            """\
<center><font size=4>
MAmmoTH2-8x7B-Plus <a style="text-decoration: none" href="https://huggingface.co/TIGER-Lab/MAmmoTH2-8x7B-Plus/">🤗</a> """
            # <a style="text-decoration: none" href="https://www.modelscope.cn/models/01ai/Yi-34B-Chat/summary">🤖</a>
            # <a style="text-decoration: none" href="https://github.com/01-ai/Yi">Yi GitHub</a></center>
        )

        chatbot = gr.Chatbot()

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False,
                        placeholder="Input...",
                        lines=10,
                        container=False,
                    )
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("🚀 Submit")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("🧹 Clear History")
                max_length = gr.Slider(
                    0,
                    32768,
                    value=4096,
                    step=1.0,
                    label="Maximum length",
                    interactive=True,
                )
                top_p = gr.Slider(
                    0, 1, value=1.0, step=0.01, label="Top P", interactive=True
                )
                temperature = gr.Slider(
                    0.01, 1, value=0.7, step=0.01, label="Temperature", interactive=True
                )

        def user(query, history):
            # Append the raw query; switch to parse_text(query) to HTML-escape it.
            # return "", history + [[parse_text(query), ""]]
            return "", history + [[query, ""]]

        submitBtn.click(
            user, [user_input, chatbot], [user_input, chatbot], queue=False
        ).then(predict, [chatbot, max_length, top_p, temperature], chatbot)
        user_input.submit(
            user, [user_input, chatbot], [user_input, chatbot], queue=False
        ).then(predict, [chatbot, max_length, top_p, temperature], chatbot)
        emptyBtn.click(lambda: None, None, chatbot, queue=False)
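
        # Wiring above: `user` first echoes the query into the history
        # (queue=False for immediate feedback), then the `predict` generator
        # streams the model reply into the same chatbot component token by token.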

    demo.queue()
    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        inbrowser=args.inbrowser,
        share=args.share,
    )


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "-c",
        "--checkpoint-path",
        type=str,
        default="TIGER-Lab/MAmmoTH2-8B-Plus",
        help="Checkpoint name or path, default to %(default)r",
    )
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run demo with CPU only"
    )
    parser.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="Create a publicly shareable link for the interface.",
    )
    parser.add_argument(
        "--inbrowser",
        action="store_true",
        default=True,  # note: with store_true and default=True this is always True
        help="Automatically launch the interface in a new tab on the default browser.",
    )
    parser.add_argument(
        "--server-port", type=int, default=8110, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="127.0.0.1", help="Demo server name."
    )
    args = parser.parse_args()

    # Load the tokenizer and model once at startup; both are used as globals above.
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        torch_dtype="auto",
        trust_remote_code=True,
    ).eval()

    main(args)
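
# Example invocation (illustrative; "app.py" stands for whatever this file is named):
#   python app.py -c TIGER-Lab/MAmmoTH2-8B-Plus --server-name 0.0.0.0 --server-port 8110
#   python app.py --cpu-only   # run without a GPU (slow for large checkpoints)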