# 0x7o's picture
# Update app.py
# 09c998a verified
# raw
# history blame
# 2.8 kB
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
# Hugging Face Hub repository of the chat model served by this app.
model_id = "mistralai/Mistral-Nemo-Instruct-2407"
# Prompts longer than this many tokens are trimmed before generation.
MAX_INPUT_TOKEN_LENGTH = 4096

# Load the tokenizer and the model.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    # Passing `load_in_8bit=True` directly to `from_pretrained` is deprecated;
    # request 8-bit loading through an explicit quantization config instead.
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior chat turns.

    Args:
        message: The newest user message.
        chat_history: Earlier turns as ``(user, assistant)`` pairs, oldest
            first.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far; each yielded string extends the
        previous one with the newly decoded text.
    """
    # Rebuild the conversation in the role/content format expected by the
    # tokenizer's chat template.
    conversation = []
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    # Ask the template to end the prompt where an assistant reply starts.
    # Templates that do not use this flag ignore it.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )

    # Keep only the most recent tokens when the prompt is too long.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=10.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "num_beams": 1,
    }

    # Run generation in a background thread so decoded text can be consumed
    # from the streamer while the model is still producing tokens.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # The streamer is exhausted, so generation is finishing; wait for the
    # worker so it does not outlive this call.
    worker.join()
# Configure the Gradio chat interface.
# NOTE(review): the title and description below name "Aeonium v1.1", while the
# file loads the model referenced by `model_id` — confirm which one is intended.
iface = gr.ChatInterface(
    # The handler must be `generate`; the previous `predict` is not defined
    # anywhere in this file and raised NameError at startup.
    generate,
    chatbot=gr.Chatbot(height=600),
    textbox=gr.Textbox(placeholder="Введите ваше сообщение здесь...", container=False, scale=7),
    title="Чат с Aeonium v1.1",
    description="Это чат-интерфейс для модели Aeonium v1.1 Chat 4B. Задавайте вопросы и получайте ответы!",
    theme="soft",
    retry_btn="Повторить",
    undo_btn="Отменить последнее",
    clear_btn="Очистить",
    # Extra controls, passed to the handler positionally after
    # (message, chat_history): max_new_tokens, temperature, top_p.
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Максимальное количество новых токенов"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Температура"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

# Launch the interface.
iface.launch()