import gradio as gr
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import StoppingCriteria, StoppingCriteriaList

base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(base_model_name, low_cpu_mem_usage=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device=device)


def format_prompt(message, history):
    # Build the Zephyr-style chat prompt that TinyLlama-1.1B-Chat was trained on.
    # Each turn is terminated with the </s> end-of-sequence token.
    prompt = "<|system|>\nYou are TinyLlama, a friendly AI assistant.</s>"
    for user_prompt, bot_response in history:
        prompt += f"\n<|user|>\n{user_prompt}</s>"
        prompt += f"\n<|assistant|>\n{bot_response}</s>"
    prompt += f"\n<|user|>\n{message}</s>\n<|assistant|>\n"
    return prompt


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the model emits the EOS token (</s>, id 2 for the Llama tokenizer).
        stop_ids = [tokenizer.eos_token_id]
        return input_ids[0][-1] in stop_ids


def generate(prompt, history):
    formatted_prompt = format_prompt(prompt, history)
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to(device)

    stop_criteria = StoppingCriteriaList([StopOnTokens()])
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,  # input_ids and attention_mask from the tokenizer
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.5,
        num_beams=1,
        stopping_criteria=stop_criteria,
    )

    # Run generation in a background thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        # skip_special_tokens normally strips </s>; this guards against it leaking through.
        if "</s>" in generated_text:
            generated_text = generated_text.replace("</s>", "")
            yield generated_text
            break
        yield generated_text


# Note: the Chatbot/ChatInterface arguments below target Gradio 4.x.
mychatbot = gr.Chatbot(
    avatar_images=["user.png", "botl.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=True,
)

demo = gr.ChatInterface(
    fn=generate,
    chatbot=mychatbot,
    title="Tomoniai's TinyLlama Chat",
    description="Tiny but awesome. Responses may be slow on CPU; try a GPU for faster answers.",
    retry_btn=None,
    undo_btn=None,
)

demo.queue().launch(show_api=False)