File size: 3,466 Bytes
c3dfc09 4d87bab 2edf1f0 e143294 c3dfc09 1859ba2 e143294 1859ba2 3d556e6 2edf1f0 e143294 4d87bab 1859ba2 4d87bab 3d556e6 e143294 23be90a 4d87bab e143294 2edf1f0 e143294 2edf1f0 c3dfc09 3d556e6 5b1faf8 2f225f8 e143294 4d87bab 6e9275a 4d87bab 3d556e6 92bf9aa 6e9275a 1859ba2 5b1faf8 3d556e6 5b1faf8 2edf1f0 e143294 5b1faf8 85dd489 3d556e6 6e9275a 1859ba2 85dd489 3d556e6 1859ba2 5b1faf8 85dd489 c3dfc09 133324c 999bf4d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import gradio as gr
from transformers import pipeline
import torch
import logging
# Настройка логирования
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Загружаем модель
model_name = "sberbank-ai/rugpt3large_based_on_gpt2"
try:
logger.info(f"Попытка загрузки модели {model_name}...")
generator = pipeline(
"text-generation",
model=model_name,
device=-1, # Используем CPU
framework="pt",
max_length=80, # Уменьшен для стабильности на CPU
truncation=True,
model_kwargs={"torch_dtype": torch.float32}
)
logger.info("Модель успешно загружена.")
except Exception as e:
logger.error(f"Ошибка загрузки модели: {e}")
exit(1)
def respond(message, max_tokens=80, temperature=0.5, top_p=0.7):
# Промпт с акцентом на медицинский ответ
prompt = f"Вы медицинский чат-бот. Пользователь говорит: '{message}'. Дайте краткий ответ только с диагнозом и лечением на русском языке в формате: Диагноз: [диагноз]. Лечение: [лечение]."
try:
logger.info(f"Генерация ответа для: {message}")
outputs = generator(
prompt,
max_length=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
num_return_sequences=1,
no_repeat_ngram_size=2 # Предотвращаем повторы
)
response = outputs[0]["generated_text"].replace(prompt, "").strip()
logger.info(f"Ответ сгенерирован: {response}")
# Проверка и форматирование ответа
if "Диагноз:" in response and "Лечение:" in response:
return response
else:
# Если формат не соблюден, извлекаем диагноз и добавляем базовое лечение
diagnosis = response.split(".")[0].strip() if response else "Неизвестно"
return f"Диагноз: {diagnosis}. Лечение: Обратитесь к врачу для точной помощи."
except Exception as e:
logger.error(f"Ошибка генерации ответа: {e}")
return "Ошибка генерации. Проконсультируйтесь с врачом."
demo = gr.Interface(
fn=respond,
inputs=[
gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы (например, 'Болит горло')..."),
gr.Slider(minimum=50, maximum=150, value=80, step=10, label="Макс. токенов"),
gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Температура"),
gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Top-p")
],
outputs="text",
title="Медицинский чат-бот на базе RuGPT-3 Large",
theme=gr.themes.Soft(),
description="Введите симптомы, и чат-бот предложит диагноз и лечение. Для точной помощи обратитесь к врачу."
)
if __name__ == "__main__":
demo.launch() |