# https://www.gradio.app/guides/using-hugging-face-integrations
import gradio as gr
import logging
import html
import time
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Model
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v0.3"
model_name = "/models/llm/hf/mistralai_Mistral-7B-Instruct-v0.1"
# UI Settings
title = "Shisa 7B"
description = "Test out Shisa 7B in either English or Japanese."
placeholder = "Type Here / ここにε…₯εŠ›γ—γ¦γγ γ•γ„"
examples = [
    "Hello, how are you?",
    "こんにちは、元気ですか？",            # "Hello, how are you?" (standard Japanese)
    "おっす、元気？",                      # "Hey, how's it going?" (familiar Japanese)
    "こんにちは、いかがお過ごしですか？",  # "Hello, how are you doing?" (formal Japanese)
]
# LLM Settings
system_prompt = 'You are a helpful, friendly assistant.'
chat_history = [{"role": "system", "content": system_prompt}]
tokenizer = AutoTokenizer.from_pretrained(model_name)
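# Override the default chat template with a Llama-2 style one so a system prompt can be included:
# the system message is wrapped in <<SYS>> ... <</SYS>> inside the first [INST] block, and each
# later user turn opens a fresh bos_token + [INST] ... [/INST] pair. A [system, user, assistant,
# user] history renders roughly as:
#   [INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user 1} [/INST] {assistant 1} </s><s>[INST] {user 2} [/INST]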
tokenizer.chat_template = "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <<SYS>>\\n' + messages[idx]['content'] + '\\n<</SYS>>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}\n"
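# Load the weights 8-bit quantized (bitsandbytes) with bf16 compute, letting accelerate
# handle device placement automatically.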
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    load_in_8bit=True,
)
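# TextIteratorStreamer pushes decoded tokens onto a queue as generate() runs in a background
# thread, so chat() below can yield partial responses to the Gradio UI.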
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
def chat(message, history):
    # Rebuild the conversation each turn from Gradio's history (a list of
    # [user, assistant] pairs) so the model also sees its own earlier replies.
    messages = list(chat_history)
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    generate_kwargs = dict(
        inputs=input_ids,
        streamer=streamer,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
    )
    # https://www.gradio.app/main/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face
    # Run generation in a background thread and stream partial output back to the UI.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token  # html.escape(new_token)
        yield partial_message
    '''
    # https://www.gradio.app/main/guides/creating-a-chatbot-fast#streaming-chatbots
    for i in range(len(message)):
        time.sleep(0.3)
        yield message[: i+1]
    '''
chat_interface = gr.ChatInterface(
    chat,
    chatbot=gr.Chatbot(height=400),
    textbox=gr.Textbox(placeholder=placeholder, container=False, scale=7),
    title=title,
    description=description,
    theme="soft",
    examples=examples,
    cache_examples=False,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)
# https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219 - we use this with construction b/c Gradio barfs on autoreload otherwise
with gr.Blocks() as demo:
    chat_interface.render()
    gr.Markdown("You can try these greetings in English, Japanese, familiar Japanese, or formal Japanese. We limit output to 200 tokens.")

demo.queue().launch()
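# Everything below is earlier scratch code kept for reference; it lives inside
# triple-quoted strings and never runs.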
'''
# Works for Text input...
demo = gr.Interface.from_pipeline(pipe)
'''
'''
def chat(message, history):
    print("foo")
    for i in range(len(message)):
        time.sleep(0.3)
        yield "You typed: " + message[: i+1]
    # print('history:', history)
    # print('message:', message)
    # for new_text in streamer:
    #     yield new_text
'''
'''
# Docs: https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/conversational.py
conversation = Conversation()
conversation.add_message({"role": "system", "content": system_prompt})
device = torch.device('cuda')
pipe = pipeline(
    'conversational',
    model=model,
    tokenizer=tokenizer,
    streamer=streamer,
)

def chat(input, history):
    conversation.add_message({"role": "user", "content": input})
    # do this shuffle so a local shadow response doesn't get created
    response_conversation = pipe(conversation)
    print("foo:", response_conversation.messages[-1]["content"])
    conversation.add_message(response_conversation.messages[-1])
    print("boo:", response_conversation.messages[-1]["content"])
    response = conversation.messages[-1]["content"]
    response = "ping"
    return response

demo = gr.ChatInterface(
    chat,
    chatbot=gr.Chatbot(height=400),
    textbox=gr.Textbox(placeholder=placeholder, container=False, scale=7),
    title=title,
    description=description,
    theme="soft",
    examples=examples,
    cache_examples=False,
    undo_btn="Delete Previous",
    clear_btn="Clear",
).launch()
# For async
# ).queue().launch()
'''