# TinyLlama-1.1B streaming chat demo (Hugging Face Space).
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import StoppingCriteria, StoppingCriteriaList, MaxLengthCriteria  # MaxLengthCriteria imported but unused below

# Chat-tuned 1.1B-parameter TinyLlama checkpoint.
base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# low_cpu_mem_usage streams weights in rather than materializing a second full copy.
model = AutoModelForCausalLM.from_pretrained(base_model_name, low_cpu_mem_usage=True)

# Prefer GPU when available; fall back to CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device=device)
def format_prompt(message, history):
    """Build the TinyLlama chat prompt from prior turns plus the new message.

    Uses the model's ``<|user|>`` / ``<|assistant|>`` chat markers with
    ``</s>`` closing each completed turn, and leaves the prompt open after
    the final ``<|assistant|>`` marker so the model continues from there.

    Args:
        message: The new user message.
        history: Sequence of ``(user_prompt, bot_response)`` pairs from
            earlier turns of the conversation.

    Returns:
        The fully formatted prompt string.
    """
    parts = []
    for user_turn, assistant_turn in history:
        parts.append(f"\n<|user|>\n{user_turn}</s>")
        parts.append(f"\n<|assistant|>\n{assistant_turn}</s>")
    parts.append(f"\n<|user|>\n{message}</s>\n<|assistant|>\n")
    # join avoids the quadratic cost of repeated string +=
    return "".join(parts)
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the most recently sampled token is a stop id.

    Token id 2 is the ``</s>`` end-of-sequence id for the TinyLlama tokenizer.
    """

    # Ids that terminate generation; only </s> (id 2) for this model.
    STOP_IDS = (2,)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the newest token (last position of the single batch row) can
        # newly trigger a stop, so that is the only one we inspect.
        return input_ids[0][-1].item() in self.STOP_IDS
def generate(prompt, history):
    """Stream a chat completion for *prompt* given the conversation *history*.

    Generation runs in a background thread via ``model.generate``; this
    generator yields the accumulated text each time the streamer produces a
    new chunk, so Gradio can render the reply incrementally.

    Args:
        prompt: The new user message.
        history: List of ``(user, assistant)`` turn pairs (Gradio format).

    Yields:
        The text generated so far, growing with each chunk.
    """
    formatted_prompt = format_prompt(prompt, history)
    # BatchEncoding holding input_ids and attention_mask, on the model device.
    # (Renamed from the original misleading local name `input_ids`.)
    model_inputs = tokenizer([formatted_prompt], return_tensors="pt").to(device)
    stop_criteria = StoppingCriteriaList([StopOnTokens()])
    # skip_prompt drops the echoed prompt from the stream; the timeout keeps
    # this loop from hanging forever if the worker thread dies.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        model_inputs,  # expands to input_ids=... and attention_mask=...
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.5,
        num_beams=1,
        stopping_criteria=stop_criteria,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        # Defensive: with skip_special_tokens=True the literal '</s>' should
        # never appear, but stop early if it leaks through anyway.
        if '</s>' in generated_text:
            break
        yield generated_text
# Chat widget with custom avatars and a copy button on bot responses.
mychatbot = gr.Chatbot(
    avatar_images=["user.png", "botl.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=True,
)

demo = gr.ChatInterface(
    fn=generate,
    chatbot=mychatbot,
    title=" Tomoniai's Tinyllama Chat ",
    description=" Tiny but an awesome model. The response may be slow for cpu environments. Try with gpu for faster answers.",
    retry_btn=None,
    undo_btn=None,
)

# Queueing is required for generator (streaming) handlers; hide the API docs page.
demo.queue().launch(show_api=False)