# Tinyllama_Chat / app.py
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import StoppingCriteria, StoppingCriteriaList
from threading import Thread

# Load the TinyLlama chat model and tokenizer, then move the model to GPU if one is available.
base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(base_model_name, low_cpu_mem_usage=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device=device)
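
# Optional (hedged suggestion, not in the original app): on a CUDA device the
# model can be loaded in half precision to roughly halve memory use, e.g.
#   model = AutoModelForCausalLM.from_pretrained(
#       base_model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16
#   )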

def format_prompt(message, history):
    """Build the Zephyr-style chat prompt that TinyLlama-Chat was trained on."""
    prompt = "<|system|>\nYou are TinyLlama, a friendly AI assistant.</s>"
    for user_prompt, bot_response in history:
        prompt += f"\n<|user|>\n{user_prompt}</s>"
        prompt += f"\n<|assistant|>\n{bot_response}</s>"
    prompt += f"\n<|user|>\n{message}</s>\n<|assistant|>\n"
    return prompt
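
# For reference, format_prompt produces TinyLlama's Zephyr-style chat layout:
#   <|system|>
#   You are TinyLlama, a friendly AI assistant.</s>
#   <|user|>
#   Hello!</s>
#   <|assistant|>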

class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last generated token is a stop id."""
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [2]  # 2 is </s>, the end-of-sequence id in the Llama tokenizer
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
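
# Note: the hard-coded id 2 is </s> for the Llama tokenizer family. A more
# portable (hedged) alternative would be to look the id up at runtime:
#   stop_ids = [tokenizer.eos_token_id]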

def generate(prompt, history):
    """Stream a chat completion for the latest user message."""
    formatted_prompt = format_prompt(prompt, history)
    inputs = tokenizer([formatted_prompt], return_tensors="pt").to(device)
    stop_criteria = StoppingCriteriaList([StopOnTokens()])
    # model.generate() runs in a background thread and pushes decoded tokens
    # into the streamer, so this generator can yield partial text to the UI.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,  # BatchEncoding: contributes input_ids and attention_mask
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.5,
        num_beams=1,
        stopping_criteria=stop_criteria,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        # skip_special_tokens usually strips </s>; this check is a safety net
        if "</s>" in generated_text:
            break
        yield generated_text
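
# Hedged smoke test (not part of the app): generate() is a plain Python
# generator, so it can be driven without the UI, e.g.:
#   for partial in generate("Hello!", history=[]):
#       print(partial)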

# Chat UI: avatars, copy button, and like buttons on each bot message.
mychatbot = gr.Chatbot(
    avatar_images=["user.png", "botl.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=True,
)

demo = gr.ChatInterface(
    fn=generate,
    chatbot=mychatbot,
    title="Tomoniai's Tinyllama Chat",
    description="A tiny but capable model. Responses may be slow on CPU; try a GPU for faster answers.",
    retry_btn=None,
    undo_btn=None,
)

demo.queue().launch(show_api=False)
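
# Launch locally with `python app.py`. queue() enables request queuing, so by
# default concurrent users are served one generation at a time by the single
# model instance.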