Spaces:
Sleeping
Sleeping
File size: 2,479 Bytes
afa22d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import StoppingCriteria, StoppingCriteriaList, MaxLengthCriteria
from threading import Thread
# Base checkpoint: the 1.1B-parameter chat-tuned TinyLlama, pulled from the Hugging Face Hub.
base_model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# low_cpu_mem_usage streams weights during load to keep peak RAM down.
model = AutoModelForCausalLM.from_pretrained(base_model_name, low_cpu_mem_usage=True)
# Prefer GPU when present; CPU generation works but is noticeably slower.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device=device)
def format_prompt(message, history):
    """Render the chat history plus the new user message in TinyLlama's
    chat template (``<|user|>`` / ``<|assistant|>`` turns terminated by </s>).

    history is an iterable of (user_text, assistant_text) pairs; the returned
    string ends with an open assistant turn for the model to complete.
    """
    turns = []
    for user_text, assistant_text in history:
        turns.append(f"\n<|user|>\n{user_text}</s>")
        turns.append(f"\n<|assistant|>\n{assistant_text}</s>")
    turns.append(f"\n<|user|>\n{message}</s>\n<|assistant|>\n")
    return "".join(turns)
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation when the most recently
    generated token is one of a configurable set of stop-token ids.

    The original hard-coded the list ``[2]`` (TinyLlama's </s>/EOS id) inside
    ``__call__``; it is now a constructor parameter with the same default, so
    ``StopOnTokens()`` behaves exactly as before while other tokenizers'
    EOS ids can be supplied.
    """

    def __init__(self, stop_ids=(2,)):
        super().__init__()
        # Materialize once at construction instead of rebuilding per call.
        self.stop_ids = list(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the last token of the (single) sequence can trigger a stop.
        last_token = input_ids[0][-1]
        for stop_id in self.stop_ids:
            if last_token == stop_id:
                return True
        return False
def generate(prompt, history):
    """Stream a chat completion for *prompt* given prior (user, bot) turns.

    Yields the accumulated response text after each decoded chunk, which is
    the incremental-update contract gr.ChatInterface expects from a generator.
    """
    formatted_prompt = format_prompt(prompt, history)
    # BatchEncoding (input_ids + attention_mask), moved to the model's device.
    input_ids = tokenizer([formatted_prompt], return_tensors="pt").to(device)
    stop_criteria = StoppingCriteriaList([StopOnTokens()])
    # The streamer yields decoded text pieces as generation proceeds;
    # skip_prompt suppresses echoing the input. NOTE(review): timeout=10s may
    # be tight on slow CPU hosts — the iterator raises if no token arrives in time.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    # dict(input_ids, ...) spreads the BatchEncoding mapping into kwargs.
    generation_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=512, do_sample=True, top_p=0.95, top_k=50,
                             temperature=0.5, num_beams=1, stopping_criteria=stop_criteria)
    # model.generate blocks, so run it on a background thread and consume the
    # streamer from this one. The thread is never joined; it ends when
    # generation finishes. NOTE(review): consider thread.join() after the loop.
    thread = Thread(target=model.generate, kwargs=generation_kwargs )
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        # Stop streaming to the UI once the end-of-sequence marker shows up
        # in the accumulated text; the chunk containing '</s>' is not yielded.
        if '</s>' in generated_text:
            break
        yield generated_text
# Chat display widget: custom avatars, compact bubbles, a copy button on
# responses, and like/dislike feedback enabled.
mychatbot = gr.Chatbot(
    avatar_images=["user.png", "botl.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=True,
)

# Wire the streaming generator into a ChatInterface; retry/undo buttons are
# disabled by passing None.
demo = gr.ChatInterface(
    fn=generate,
    chatbot=mychatbot,
    title=" Tomoniai's Tinyllama Chat ",
    description=" Tiny but an awesome model. The response may be slow for cpu environments. Try with gpu for faster answers.",
    retry_btn=None,
    undo_btn=None,
)

# Queueing is required for generator (streaming) handlers.
demo.queue().launch(show_api=False)
|