# Taiwan-LLaMa2 / app.py
# Latest commit: 9077757 "Update app.py" (author: yentinglin)
# Captured from the Hugging Face Spaces file view (raw · history · blame · 9.1 kB).
import logging
import os
from urllib.parse import quote_plus

import gradio as gr
from pymongo import MongoClient
from pymongo.errors import PyMongoError
from text_generation import Client
from transformers import AutoTokenizer

from conversation import get_default_conv_template
# MongoDB storage for chat transcripts. DB_NAME is used both as the database
# name and as the Atlas host prefix in the URI below.
DB_NAME = os.getenv("MONGO_DBNAME", "taiwan-llm")
USER = os.getenv("MONGO_USER")
PASSWORD = os.getenv("MONGO_PASSWORD")
# Credentials must be percent-escaped: a ':', '@' or '/' in a username or
# password would otherwise be parsed as a URI delimiter. An unset variable
# is left as None and still formats as the literal "None", as before.
# TODO(review): consider failing fast when MONGO_USER / MONGO_PASSWORD are unset.
_escaped_user = quote_plus(USER) if USER is not None else USER
_escaped_password = quote_plus(PASSWORD) if PASSWORD is not None else PASSWORD
uri = f"mongodb+srv://{_escaped_user}:{_escaped_password}@{DB_NAME}.kvwjiok.mongodb.net/?retryWrites=true&w=majority"
mongo_client = MongoClient(uri)
db = mongo_client[DB_NAME]
# Each finished chat turn is inserted here by the generation handler.
conversations_collection = db['conversations']
# Markdown/HTML header rendered at the top of the demo page. Blank lines are
# significant: the centred <p> HTML block only ends at a blank line, and list
# items absorb following lines unless a blank line separates them.
DESCRIPTION = """
# Language Models for Taiwanese Culture

<p align="center">
✍️ <a href="https://huggingface.co/spaces/yentinglin/Taiwan-LLaMa2" target="_blank">Online Demo</a>
•
🤗 <a href="https://huggingface.co/yentinglin" target="_blank">HF Repo</a> • 🐦 <a href="https://twitter.com/yentinglin56" target="_blank">Twitter</a> • 📃 <a href="https://arxiv.org/pdf/2305.13711.pdf" target="_blank">[Paper Coming Soon]</a>
• 👨️ <a href="https://github.com/MiuLab/Taiwan-LLaMa/tree/main" target="_blank">GitHub Repo</a>
<br/><br/>
<img src="https://www.csie.ntu.edu.tw/~miulab/taiwan-llama/logo-v2.png" width="100"> <br/>
</p>

Taiwan-LLaMa is a fine-tuned model specifically designed for Traditional Mandarin applications. It is built upon the LLaMa 2 architecture and includes a pretraining phase with over 5 billion tokens and fine-tuning with over 490k multi-turn conversational data in Traditional Mandarin.

## Key Features

1. **Traditional Mandarin Support**: The model is fine-tuned to understand and generate text in Traditional Mandarin, making it suitable for Taiwanese culture and related applications.
2. **Instruction-Tuned**: Further fine-tuned on conversational data to offer context-aware and instruction-following responses.
3. **Performance on Vicuna Benchmark**: Taiwan-LLaMa's relative performance on Vicuna Benchmark is measured against models like GPT-4 and ChatGPT. It's particularly optimized for Taiwanese culture.
4. **Flexible Customization**: Advanced options for controlling the model's behavior like system prompt, temperature, top-p, and top-k are available in the demo.

## Model Versions

Different versions of Taiwan-LLaMa are available:

- **Taiwan-LLaMa v2.0 (This demo)**: Cleaner pretraining, better post-training
- **Taiwan-LLaMa v1.0**: Optimized for Taiwanese Culture
- **Taiwan-LLaMa v0.9**: Partial instruction set
- **Taiwan-LLaMa v0.0**: No Traditional Mandarin pretraining

The models can be accessed from the provided links in the Hugging Face repository.

Try out the demo to interact with Taiwan-LLaMa and experience its capabilities in handling Traditional Mandarin!
"""
# Markdown footer rendered at the end of the page: licensing terms and
# acknowledgements. Built line by line; the leading and trailing empty
# entries reproduce the surrounding newlines of the text.
LICENSE = "\n".join([
    "",
    "## Licenses",
    "- Code is licensed under Apache 2.0 License.",
    "- Models are licensed under the LLAMA 2 Community License.",
    "- By using this model, you agree to the terms and conditions specified in the license.",
    "- By using this demo, you agree to share your input utterances with us to improve the model.",
    "## Acknowledgements",
    "Taiwan-LLaMa project acknowledges the efforts of the [Meta LLaMa team](https://github.com/facebookresearch/llama) and [Vicuna team](https://github.com/lm-sys/FastChat) in democratizing large language models.",
    "",
])
# Initial value of the "System prompt" textbox in the advanced options.
DEFAULT_SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
    "You are built by NTU Miulab by Yen-Ting Lin for research purpose."
)

# Generation backend (text_generation Client); override with ENDPOINT_URL.
endpoint_url = os.getenv("ENDPOINT_URL", "http://127.0.0.1:8080")
client = Client(endpoint_url, timeout=120)

eos_token = "</s>"  # NOTE(review): not referenced elsewhere in this file.

# Token budget. The 4096 figure is presumably the model's context length
# (TODO confirm); the prompt gets whatever remains after reserving the largest
# allowed completion plus a 10-token safety margin.
MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 1024
max_prompt_length = 4096 - (MAX_MAX_NEW_TOKENS + 10)

# Checkpoint name: its tokenizer measures/truncates prompts, and the name is
# recorded alongside each stored conversation.
model_name = "yentinglin/Taiwan-LLM-7B-v2.0-chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
logger = logging.getLogger(__name__)


def user(user_message, history):
    """Append the submitted message as a new, not-yet-answered turn.

    Returns ``("", new_history)``. The empty string clears the input textbox.
    ``new_history`` is ``history`` plus ``[user_message, None]``, where ``None``
    marks the assistant reply as pending.
    """
    # Defensively treat a missing history as an empty chat.
    return "", (history or []) + [[user_message, None]]


def _build_prompt(turns, system_prompt):
    """Render chat turns into one prompt string using the "twllm_v2" template.

    Args:
        turns: sequence of ``[user_message, assistant_message]`` pairs.
        system_prompt: text assigned to the template's ``system`` attribute.

    The final turn's assistant slot is always passed as ``None``, because that
    reply is the one being generated. This makes the submit path (which stores
    ``None``) and the retry path (which stores ``''``) build the same prompt.
    """
    conv = get_default_conv_template("twllm_v2").copy()
    # map human to USER and gpt to ASSISTANT
    human_role, gpt_role = conv.roles[0], conv.roles[1]
    conv.system = system_prompt
    last_index = len(turns) - 1
    for index, (user_message, bot_message) in enumerate(turns):
        conv.append_message(human_role, user_message)
        conv.append_message(gpt_role, None if index == last_index else bot_message)
    return conv.get_prompt()


def _fit_prompt(history, system_prompt):
    """Return a prompt of at most ``max_prompt_length`` tokens for ``history``.

    The oldest whole turns are dropped first, so the system prompt and template
    framing survive. If even the newest turn alone is too long, fall back to
    keeping only the trailing tokens of the rendered prompt.
    """
    start = 0
    while True:
        prompt = _build_prompt(history[start:], system_prompt)
        prompt_tokens = tokenizer.encode(prompt)
        if len(prompt_tokens) <= max_prompt_length:
            return prompt
        if start >= len(history) - 1:
            # Last resort: raw left-truncation.
            return tokenizer.decode(prompt_tokens[-max_prompt_length + 1:])
        start += 1


def bot(history, max_new_tokens, temperature, top_p, top_k, system_prompt):
    """Stream the assistant reply for the last turn of ``history``.

    Yields ``history`` after every non-special token, with the reply
    accumulated in ``history[-1][1]``. After streaming completes, the
    transcript and sampling settings are inserted into
    ``conversations_collection`` on a best-effort basis.
    """
    prompt = _fit_prompt(history, system_prompt)
    history[-1][1] = ""
    for response in client.generate_stream(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        # The text_generation client rejects top_p >= 1.0. A value of 1.0 means
        # "no nucleus filtering", so omit the parameter (None) in that case.
        top_p=top_p if top_p < 1.0 else None,
        top_k=top_k,
    ):
        if not response.token.special:
            history[-1][1] += response.token.text
            yield history
    # After generating the response, store the conversation history in MongoDB
    conversation_document = {
        "model_name": model_name,
        "history": history,
        "system_prompt": system_prompt,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
    }
    try:
        conversations_collection.insert_one(conversation_document)
    except PyMongoError:
        # A database outage must not turn a completed reply into a UI error.
        logger.exception("Failed to store conversation in MongoDB")


def delete_prev_fn(history: list[list]) -> tuple[list[list], str]:
    """Pop the last turn and return ``(history, removed_user_message)``.

    An empty (or missing) history yields an empty list and ``''``. A ``None``
    user message is normalised to ``''``.
    """
    history = history if history is not None else []
    try:
        message, _ = history.pop()
    except IndexError:
        message = ''
    return history, message or ''


def display_input(message: str, history: list[list]) -> list[list]:
    """Re-append ``message`` as a new turn with an empty reply (used by Retry).

    A list rather than a tuple is appended, because ``bot`` writes the reply
    into ``history[-1][1]`` in place.
    """
    history = history if history is not None else []
    history.append([message, ''])
    return history


with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    chatbot = gr.Chatbot()
    with gr.Row():
        msg = gr.Textbox(
            container=False,
            show_label=False,
            placeholder='Type a message...',
            scale=10,
        )
        submit_button = gr.Button('Submit',
                                  variant='primary',
                                  scale=1,
                                  min_width=0)
    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear = gr.Button('🗑️ Clear', variant='secondary')
    # Holds the user message removed by Retry/Undo between chained steps.
    saved_input = gr.State()
    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=6)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.7,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        )

    # Inputs shared by every event that triggers generation. The order must
    # match bot()'s parameters.
    bot_inputs = [
        chatbot,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        system_prompt,
    ]

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
    )
    submit_button.click(
        user, [msg, chatbot], [msg, chatbot], queue=False
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
    )

    # Retry: drop the last turn, re-add its user message, regenerate.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=bot,
        inputs=bot_inputs,
        outputs=chatbot,
    )
    # Undo: drop the last turn and put its user message back in the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=msg,
        api_name=False,
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    gr.Markdown(LICENSE)
# Enable the request queue (needed for streaming generator handlers), allowing
# 4 concurrent workers and at most 128 waiting requests, then start the server.
# Blocks.queue() returns the Blocks itself, so the calls chain.
demo.queue(concurrency_count=4, max_size=128).launch()