# Captured from a Hugging Face Space page whose status was shown as "Runtime error".
import spaces | |
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer | |
from threading import Thread | |
# Hugging Face Hub repo id for both the tokenizer and the causal LM weights.
model_path = 'wannaphong/tongyi-model-v1.1-1b-enth'

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# trust_remote_code=True permits custom tokenizer/model code shipped with the repo.
# Weights are loaded in bfloat16 and then moved onto `device`.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)
# Custom stopping criterion for the model's text generation.
class StopOnTokens(StoppingCriteria):
    """Stop generation once the newest token of the first sequence is a stop token.

    Args:
        stop_ids: Token ids that end generation. Defaults to ``(151645,)``, the
            id this app has always stopped on (presumably ``<|im_end|>`` in this
            model's vocabulary -- TODO confirm against the tokenizer).
    """

    def __init__(self, stop_ids=(151645,)):
        super().__init__()
        # Normalize to plain ints once so each check is a single set lookup.
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of row 0 of ``input_ids`` is a stop id.

        Only row 0 is inspected, which matches this app's single-prompt usage
        (presumably batch size 1 -- confirm if batching is ever added).
        """
        return int(input_ids[0][-1]) in self.stop_ids
# Role names that open each ChatML turn.
system_role = 'system'
user_role = 'user'
assistant_role = 'assistant'

# ChatML delimiters: a turn is written as <|im_start|>{role}\n...<|im_end|>.
sft_start_token = "<|im_start|>"
sft_end_token = "<|im_end|>"
# End-of-text token string (not used by the prompt construction in this file).
ct_end_token = "<|endoftext|>"

# The system turn that begins every prompt, already wrapped in ChatML delimiters.
system_prompt = (
    f"{sft_start_token}{system_role}\n"
    'You are an AI assistant named TongYip (ทองหยิบ), created by PyThaiNLP. As an AI assistant, you can answer questions in Thai and English. Your responses should be friendly, unbiased, informative, detailed, and faithful.'
    f"{sft_end_token}"
)
def _as_text(value):
    """Return ``value`` unchanged if it is a string, otherwise ``""``.

    Gradio tuple-format chat history may hold ``None`` for a missing message,
    or a non-string for file/media messages. Neither can be placed in this
    text-only prompt, so both become an empty string instead of raising.
    """
    return value if isinstance(value, str) else ""


def _build_prompt(message, history):
    """Render the system prompt, the prior turns and ``message`` as one ChatML string.

    Each turn is an ``<|im_start|>user`` block followed by an
    ``<|im_start|>assistant`` block, separated by ``<|im_end|>``. Consecutive
    turns are also joined by ``<|im_end|>``. The new message is paired with an
    empty reply, so the prompt ends with an open assistant turn for the model
    to continue.
    """
    # Gradio may pass None for an empty conversation.
    turns = list(history or []) + [[message, ""]]
    return system_prompt + sft_end_token.join(
        f"\n{sft_start_token}{user_role}\n{_as_text(turn[0])}"
        f"{sft_end_token}"
        f"\n{sft_start_token}{assistant_role}\n{_as_text(turn[1])}"
        for turn in turns
    )


# `spaces` is imported at the top of the file. On ZeroGPU hardware, GPU work must
# run inside a @spaces.GPU function; outside ZeroGPU the decorator has no effect.
@spaces.GPU
def predict(message, history):
    """Stream the assistant's reply to ``message`` given the prior chat ``history``.

    Args:
        message: The new user message text.
        history: Earlier turns as ``[user, assistant]`` pairs, or ``None``.
            Entries that are ``None`` or not strings are treated as ``""``.

    Yields:
        The reply generated so far, as cumulative text rather than a delta.
    """
    prompt = _build_prompt(message, history)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

    # If no new text arrives within `timeout` seconds, the streamer raises
    # queue.Empty. A failed generate() thread therefore surfaces as an error
    # instead of hanging.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.8,
        top_k=20,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
        repetition_penalty=1.1,
    )
    # generate() blocks until it finishes, so it runs on a worker thread while
    # this generator consumes the streamer and yields text incrementally.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        # skip_special_tokens=True normally strips <|im_end|>. If it does come
        # through as text, keep everything before it and stop streaming.
        end = partial_message.find(sft_end_token)
        if end != -1:
            yield partial_message[:end]
            # NOTE: the generate() thread is not cancelled here. It runs until
            # its own stopping criteria or max_new_tokens.
            break
        yield partial_message
# Custom page CSS. It is passed to the top-level gr.Blocks below, because css
# given to a ChatInterface rendered inside another Blocks may not be applied.
# The original selector `full-height` was an element-type selector that matches
# nothing; `.full-height` is the class form.
# NOTE(review): no component here sets elem_classes=["full-height"] yet, so this
# rule currently has no effect -- confirm the intended target.
css = """
.full-height {
    height: 100%;
}
"""

# Example prompts offered by the ChatInterface.
# NOTE(review): the system prompt promises Thai and English, but two examples are
# Indonesian and Vietnamese (likely carried over from the Sailor demo) -- confirm.
prompt_examples = [
    'How to cook a fish?',
    'Cara memanggang ikan',
    'วิธีย่างปลา',
    'Cách nướng cá'
]

placeholder = ""
# NOTE(review): the label and banner still say "Sailor", while the system prompt
# names the assistant TongYip -- confirm this branding is intended.
chatbot = gr.Chatbot(label='Sailor', placeholder=placeholder)

with gr.Blocks(theme=gr.themes.Soft(), fill_height=True, css=css) as demo:
    # Centered banner image. The wrapping paragraph is now properly closed.
    gr.Markdown("""<p align="center"><img src="https://github.com/sail-sg/sailor2/raw/main/misc/sailor2_wide_banner.jpg" style="height: 110px"/></p>""")
    gr.ChatInterface(predict, chatbot=chatbot, fill_height=True, examples=prompt_examples)

demo.launch()  # Launching the web interface.