File size: 3,466 Bytes
c3dfc09
4d87bab
2edf1f0
e143294
c3dfc09
1859ba2
e143294
 
 
1859ba2
3d556e6
2edf1f0
e143294
4d87bab
 
 
1859ba2
4d87bab
3d556e6
e143294
23be90a
4d87bab
e143294
2edf1f0
e143294
2edf1f0
c3dfc09
3d556e6
 
5b1faf8
2f225f8
e143294
4d87bab
6e9275a
4d87bab
 
 
 
3d556e6
 
92bf9aa
6e9275a
1859ba2
5b1faf8
 
 
 
 
3d556e6
5b1faf8
 
2edf1f0
e143294
5b1faf8
85dd489
 
 
 
 
3d556e6
6e9275a
1859ba2
85dd489
 
3d556e6
1859ba2
5b1faf8
85dd489
c3dfc09
133324c
999bf4d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
from transformers import pipeline
import torch
import logging

# Настройка логирования
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Загружаем модель
model_name = "sberbank-ai/rugpt3large_based_on_gpt2"
try:
    logger.info(f"Попытка загрузки модели {model_name}...")
    generator = pipeline(
        "text-generation",
        model=model_name,
        device=-1,  # Используем CPU
        framework="pt",
        max_length=80,  # Уменьшен для стабильности на CPU
        truncation=True,
        model_kwargs={"torch_dtype": torch.float32}
    )
    logger.info("Модель успешно загружена.")
except Exception as e:
    logger.error(f"Ошибка загрузки модели: {e}")
    exit(1)

def respond(message, max_tokens=80, temperature=0.5, top_p=0.7):
    # Промпт с акцентом на медицинский ответ
    prompt = f"Вы медицинский чат-бот. Пользователь говорит: '{message}'. Дайте краткий ответ только с диагнозом и лечением на русском языке в формате: Диагноз: [диагноз]. Лечение: [лечение]."
    try:
        logger.info(f"Генерация ответа для: {message}")
        outputs = generator(
            prompt,
            max_length=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=1,
            no_repeat_ngram_size=2  # Предотвращаем повторы
        )
        response = outputs[0]["generated_text"].replace(prompt, "").strip()
        logger.info(f"Ответ сгенерирован: {response}")

        # Проверка и форматирование ответа
        if "Диагноз:" in response and "Лечение:" in response:
            return response
        else:
            # Если формат не соблюден, извлекаем диагноз и добавляем базовое лечение
            diagnosis = response.split(".")[0].strip() if response else "Неизвестно"
            return f"Диагноз: {diagnosis}. Лечение: Обратитесь к врачу для точной помощи."
    except Exception as e:
        logger.error(f"Ошибка генерации ответа: {e}")
        return "Ошибка генерации. Проконсультируйтесь с врачом."

demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Ваше сообщение", placeholder="Опишите симптомы (например, 'Болит горло')..."),
        gr.Slider(minimum=50, maximum=150, value=80, step=10, label="Макс. токенов"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, label="Температура"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Top-p")
    ],
    outputs="text",
    title="Медицинский чат-бот на базе RuGPT-3 Large",
    theme=gr.themes.Soft(),
    description="Введите симптомы, и чат-бот предложит диагноз и лечение. Для точной помощи обратитесь к врачу."
)

if __name__ == "__main__":
    demo.launch()