from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Maximum number of prompt tokens to keep; longer conversations are trimmed from the left.
MAX_INPUT_TOKEN_LENGTH = 4096

model_id = 'HuggingFaceH4/zephyr-7b-beta'
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False


def generate(message, chat_history=[], system_prompt='', max_new_tokens=512,
             temperature=0.5, top_p=0.95, top_k=50, repetition_penalty=1.2):
    # Rebuild the conversation in the message format expected by the chat template.
    conversation = []
    if system_prompt:
        conversation.append({'role': 'system', 'content': system_prompt})
    for user, assistant in chat_history:
        conversation.extend([
            {'role': 'user', 'content': user},
            {'role': 'assistant', 'content': assistant},
        ])
    conversation.append({'role': 'user', 'content': message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors='pt')
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the prompt fits the context budget.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Stream generated tokens back to the UI as they are produced.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread; the streamer yields text on the main thread.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    examples=[
        'What is GPT?',
        'What is Life?',
        'Who is Alan Turing?'
    ]
)

chat_interface.queue(max_size=20).launch()