wannaphong's picture
Update app.py
ed433b4 verified
raw
history blame contribute delete
No virus
2.6 kB
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
# Hub identifier of the checkpoint served by this app.
model_id = "mistralai/Mistral-Nemo-Instruct-2407"
# Upper bound on prompt length (in tokens) that is sent to the model.
MAX_INPUT_TOKEN_LENGTH = 4096

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_id)
# 8-bit loading is requested through an explicit quantization config: passing
# `load_in_8bit=True` straight to `from_pretrained` is deprecated in
# transformers in favour of `quantization_config`.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior conversation.

    The prompt is built from a fixed system message, every
    ``(user, assistant)`` pair in ``chat_history`` and finally ``message``,
    then rendered with the tokenizer's chat template. Generation runs in a
    background thread while this generator relays the text as it arrives.

    Args:
        message: The newest user message.
        chat_history: Earlier turns as ``(user, assistant)`` pairs.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.

    Yields:
        The full reply accumulated so far, once per chunk produced by the
        streamer (each yield is a superset of the previous one).
    """
    conversation = [{"role": "system", "content": "You are helpful assistant. Your answer are Thai language."}]
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the trailing tokens, i.e. drop the oldest part of the prompt.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt / skip_special_tokens: only the newly generated, human-readable
    # text should reach the UI.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "num_beams": 1,
    }
    # model.generate blocks until it is done, so it runs in a worker thread
    # while this generator consumes the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # The stream is exhausted; wait for the worker so no generation thread
    # outlives this call.
    worker.join()
# Set up Gradio interface
# Components are created up front, in the same order as they are handed to
# ChatInterface, so the interface definition below stays short.
chat_window = gr.Chatbot(height=600)
message_box = gr.Textbox(placeholder="Enter your message here...", container=False, scale=7)
# One control per extra `generate` parameter, in order:
# max_new_tokens, temperature, top_p.
sampling_controls = [
    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Maximum number of new tokens"),
    gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
]

iface = gr.ChatInterface(
    generate,
    chatbot=chat_window,
    textbox=message_box,
    title="Chat with Mistral Nemo",
    description="This is a chat interface for the Mistral Nemo model. Ask questions and get answers!",
    retry_btn="Retry",
    undo_btn="Undo Last",
    clear_btn="Clear",
    additional_inputs=sampling_controls,
)

# Launch the interface
iface.launch()