# shape changes v2.0 (#8) — commit a059d9b
# (Hugging Face Space page header residue removed; kept as a comment for provenance.)
import os
import torch
import gradio as gr
import requests
from typing import List, Dict, Iterator
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from peft import PeftModel
import json
# --- Configuration ---
# Base causal-LM checkpoint; the LoRA adapter below is applied on top of it.
BASE_MODEL_ID = "Tweeties/tweety-7b-tatar-v24a"
# Required secrets, supplied via environment variables (e.g. HF Space secrets).
ADAPTER_ID = os.getenv("ADAPTER_ID")
YANDEX_API_KEY = os.getenv("YANDEX_API_KEY")
YANDEX_FOLDER_ID= os.getenv("YANDEX_FOLDER_ID")
# Fail fast at import time when any secret is missing.
if not all([ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID]):
    raise ValueError("Необходимо установить переменные окружения: ADAPTER_ID, YANDEX_API_KEY, YANDEX_FOLDER_ID")
# Generation settings; TEMPERATURE/TOP_P are currently unused because
# generate_tt_reply_stream() runs with do_sample=False (greedy decoding).
MAX_NEW_TOKENS = 1024
TEMPERATURE = 1
TOP_P = 0.9
REPETITION_PENALTY = 1.05
# System prompt in Tatar: an ESG digital-assistant persona (runtime string — do not translate).
SYS_PROMPT_TT = ("Син-цифрлы ярдәмче (ир-ат нәселе). Сине Сбербанк дирекциясенең ESG да уйлап таптылар. Син барлык өлкәләрдә, бигрәк тә Татарстанга кагылышлы өлкәләрдә кызыклы кулланучы эксперты! Ул сезгә бик күп сораулар бирәчәк, ә сезнең эшегез-шәрехләр бирү, кулланучының сорауларына җавап бирү, адымлап киңәшләр, мисаллар бирү һәм, кирәк булганда, кулланучыга аныклаучы сораулар бирү.")
# --- Model loading (runs once at import time; downloads from the HF Hub) ---
# Tokenizer comes from the adapter repo so any tokens added during fine-tuning match.
tok = AutoTokenizer.from_pretrained(ADAPTER_ID, use_fast=False)
if tok.pad_token is None:
    # Reuse EOS as the padding token for models that ship without one.
    tok.pad_token = tok.eos_token
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",          # let accelerate place layers on available devices
    low_cpu_mem_usage=True
)
print("Применяем LoRA адаптер...")
# Apply the LoRA adapter on top of the frozen base weights.
model = PeftModel.from_pretrained(base, ADAPTER_ID, torch_dtype=torch.float16)
# NOTE(review): use_cache=False disables the KV cache and slows generation
# considerably — presumably left over from training config; confirm intent.
model.config.use_cache = False
model.eval()
print("✅ Модель успешно загружена!")
# Yandex Cloud Translate API v2 endpoints (translation and language detection).
YANDEX_TRANSLATE_URL = "https://translate.api.cloud.yandex.net/translate/v2/translate"
YANDEX_DETECT_URL = "https://translate.api.cloud.yandex.net/translate/v2/detect"
def detect_language(text: str) -> str:
    """Detect the language of *text* with the Yandex Translate detect endpoint.

    Returns the detected language code (e.g. "tt", "ru"). Falls back to
    "ru" when the request fails or the response carries no language code,
    so the caller can always proceed.
    """
    try:
        response = requests.post(
            YANDEX_DETECT_URL,
            headers={"Authorization": f"Api-Key {YANDEX_API_KEY}"},
            json={"folderId": YANDEX_FOLDER_ID, "text": text},
            timeout=10,
        )
        response.raise_for_status()
        return response.json().get("languageCode", "ru")
    except requests.exceptions.RequestException:
        # Best-effort: network/HTTP problems degrade to assuming Russian.
        return "ru"
def ru2tt(text: str) -> str:
    """Translate Russian *text* to Tatar via the Yandex Translate API.

    Best-effort: on any network error OR an unexpected response payload the
    original text is returned unchanged so the chat keeps working.
    """
    headers = {"Authorization": f"Api-Key {YANDEX_API_KEY}"}
    payload = {
        "folderId": YANDEX_FOLDER_ID,
        "texts": [text],
        "sourceLanguageCode": "ru",
        "targetLanguageCode": "tt",
    }
    try:
        resp = requests.post(YANDEX_TRANSLATE_URL, headers=headers, json=payload, timeout=30)
        resp.raise_for_status()
        # Fix: a 200 response with a malformed/empty body used to raise an
        # uncaught KeyError/IndexError and crash the handler; treat any
        # schema surprise the same as a network failure.
        return resp.json()["translations"][0]["text"]
    except (requests.exceptions.RequestException, KeyError, IndexError, TypeError, ValueError):
        return text
def render_prompt(messages: List[Dict[str, str]]) -> str:
    """Serialize chat *messages* into a single prompt string.

    Applies the tokenizer's chat template without tokenizing and appends
    the generation prompt so the model continues as the assistant.
    """
    rendered = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return rendered
# --- 4) Streaming generation (no trimming) ---
@torch.inference_mode()
def generate_tt_reply_stream(messages: List[Dict[str, str]]) -> Iterator[str]:
    """Stream the model's reply to *messages*, yielding the accumulated text.

    Each yielded string is the full reply so far (not just the new delta),
    which is what the Gradio handler renders on every update.
    """
    prompt = render_prompt(messages)
    enc = tok(prompt, return_tensors="pt")
    # Move the encoded inputs to the same device as the model.
    enc = {k: v.to(model.device) for k, v in enc.items()}
    # skip_prompt=True: only newly generated tokens are streamed back.
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **enc,
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,  # greedy decoding; sampling knobs below intentionally disabled
        # temperature=TEMPERATURE,
        # top_p=TOP_P,
        repetition_penalty=REPETITION_PENALTY,
        eos_token_id=tok.eos_token_id,
        pad_token_id=tok.pad_token_id,
    )
    # generate() blocks until done, so it runs in a worker thread while this
    # generator consumes the streamer on the calling thread.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    acc = ""
    for chunk in streamer:
        acc += chunk
        yield acc
def chat_fn(message: str, ui_history: list, messages_state: List[Dict[str, str]]):
    """Gradio streaming handler for a user message.

    Translates non-Tatar input to Tatar, streams the model reply, and yields
    ``(ui_history, messages_state)`` tuples for Gradio to render.

    Args:
        message: raw user input (Russian or Tatar).
        ui_history: Chatbot history as ``[user, bot]`` pairs.
        messages_state: chat-template message dicts, first entry is the system prompt.

    Yields:
        Updated ``(ui_history, messages_state)`` on every streamed chunk and
        once more with the final state after generation finishes.
    """
    # Re-seed the state if it is empty or has lost its system prompt.
    if not messages_state or messages_state[0].get("role") != "system":
        messages_state = [{"role": "system", "content": SYS_PROMPT_TT}]
    # The model is fine-tuned on Tatar, so translate anything else first.
    detected = detect_language(message)
    user_tt = ru2tt(message) if detected != "tt" else message
    messages = messages_state + [{"role": "user", "content": user_tt}]
    ui_history = ui_history + [[user_tt, ""]]
    last = ""
    for partial in generate_tt_reply_stream(messages):
        last = partial
        ui_history[-1][1] = partial
        # `messages` already ends with the user turn; just append the draft reply.
        yield ui_history, messages + [{"role": "assistant", "content": partial}]
    final_state = messages + [{"role": "assistant", "content": last}]
    print("STATE:", json.dumps(final_state, ensure_ascii=False))
    # Fix: the original never yielded after the loop, so an empty stream left
    # Gradio with stale history/state. Always emit the final state once.
    yield ui_history, final_state
# --- Gradio UI ---
with gr.Blocks(
    theme=gr.themes.Soft(),
    # Custom CSS: enlarge bot-message and input fonts across Gradio versions
    # (both old .gr-chatbot_* and newer .message selectors are targeted).
    css="""
#chatbot .message.bot,
#chatbot .message.bot .markdown,
#chatbot .message.bot .prose,
#chatbot .message.bot p,
#chatbot .message.bot li,
#chatbot .message.bot pre,
#chatbot .message.bot code {
font-size: 22px !important;
line-height: 1.7 !important;
}
#chatbot .gr-chatbot_message.gr-chatbot_message__bot,
#chatbot .gr-chatbot_message.gr-chatbot_message__bot .gr-chatbot_markdown > *,
#chatbot .gr-chatbot_message--assistant,
#chatbot .gr-chatbot_message--assistant .gr-chatbot_markdown > * {
font-size: 22px !important;
line-height: 1.7 !important;
}
#chatbot .gr-chatbot { font-size: 18px !important; line-height: 1.5; }
#chatbot .gr-chatbot_message { font-size: 18px !important; }
#chatbot .gr-chatbot_markdown > * { font-size: 18px !important; line-height: 1.6; }
#msg textarea { font-size: 24px !important; }
#clear { font-size: 16px !important; }
#title h2 { font-size: 28px !important; }
"""
) as demo:
    gr.Markdown("## Татарский чат-бот от команды Сбера", elem_id="title")
    # Per-session conversation state, pre-seeded with the system prompt.
    messages_state = gr.State([{"role": "system", "content": SYS_PROMPT_TT}])
    chatbot = gr.Chatbot(
        label="Диалог",
        height=500,
        bubble_full_width=False,
        elem_id="chatbot"
    )
    msg = gr.Textbox(
        label="Хәбәрегезне рус яки татар телендә языгыз",
        placeholder="Татарстанның башкаласы нинди шәһәр? / Какая столица Татарстана?",
        elem_id="msg"
    )
    clear = gr.Button("🗑️ Чистарту", elem_id="clear")
    # Submitting the textbox streams chat_fn output into the chatbot and state...
    msg.submit(
        chat_fn,
        inputs=[msg, chatbot, messages_state],
        outputs=[chatbot, messages_state],
    )
    # ...and a second handler clears the input box.
    msg.submit(lambda: "", None, msg)
    def _reset():
        # Clear the UI and restore the state to just the system prompt.
        return [], [{"role": "system", "content": SYS_PROMPT_TT}]
    clear.click(_reset, inputs=None, outputs=[chatbot, messages_state], queue=False)
    clear.click(lambda: "", None, msg, queue=False)
if __name__ == "__main__":
    # Bind to all interfaces; port configurable via PORT (HF Spaces default 7860).
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))