datasets:
- kuznetsoffandrey/sberquad
language:
- ru
metrics:
- bleu
- chrf
base_model:
- ai-forever/ruT5-base
pipeline_tag: question-answering
library_name: transformers
Проект: Чат-бот с использованием модели ruT5-base для ответов на вопросы
Описание
Этот проект представляет собой систему, которая использует предобученную модель ruT5-base для генерации ответов на вопросы, основанных на предоставленном контексте. Я дообучаю модель на датасете SberQUAD, адаптируя её для задач вопросно-ответного взаимодействия на русском языке.
Датасет
Я использую датасет SberQUAD, который содержит примеры вопросов и ответов на них в контексте различных текстов. Датасет разбит на тренировочные, валидационные и тестовые части.
Архитектура модели
В качестве базовой модели используется ruT5-base — Encoder-Decoder модель, оптимизированная для задач на русском языке. Модель была дополнительно дообучена на кастомных данных для улучшения генерации ответов на основе предоставленного контекста.
Параметры обучения
Для обучения использовались следующие параметры:
output_dir="./models",
optim="adafactor",
num_train_epochs=1, # в идеале 2 эпохи
do_train=True,
gradient_checkpointing=True,
bf16=True,
per_device_train_batch_size=8,
per_device_eval_batch_size=12,
gradient_accumulation_steps=4,
logging_dir="./logs",
report_to="wandb",
logging_steps=10,
save_strategy="steps",
save_steps=5000,
evaluation_strategy="steps",
eval_steps=300,
learning_rate=3e-5,
predict_with_generate=False,
generation_max_length=64
К сожалению, мне не хватило вычислительного времени на Google Collab, поэтому модель была обучена только на одной эпохе с ~1416 шагами.
Результаты обучения
Шаг | Loss на валидации | Sbleu | Chr F | Rouge1 | Rouge2 | Rougel |
---|---|---|---|---|---|---|
300 | 1.025008 | 18.206400 | 62.316300 | 0.110400 | 0.035200 | 0.109800 |
600 | 1.007530 | 18.523100 | 62.564700 | 0.113300 | 0.036500 | 0.112800 |
900 | 0.959073 | 18.869000 | 63.001700 | 0.115100 | 0.035600 | 0.114600 |
1200 | 0.944776 | 18.656300 | 62.819800 | 0.115400 | 0.035800 | 0.115000 |