tinek_sample_model / README.md
RichelieuGVG's picture
Update README.md
0c97014 verified
metadata
datasets:
  - kuznetsoffandrey/sberquad
language:
  - ru
metrics:
  - bleu
  - chrf
base_model:
  - ai-forever/ruT5-base
pipeline_tag: question-answering
library_name: transformers

Проект: Чат-бот с использованием модели ruT5-base для ответов на вопросы

Описание

Этот проект представляет собой систему, которая использует предобученную модель ruT5-base для генерации ответов на вопросы, основанных на предоставленном контексте. Я дообучаю модель на датасете SberQUAD, адаптируя её для задач вопросно-ответного взаимодействия на русском языке.

Датасет

Я использую датасет SberQUAD, который содержит примеры вопросов и ответов на них в контексте различных текстов. Датасет разбит на тренировочные, валидационные и тестовые части.

Архитектура модели

В качестве базовой модели используется ruT5-base — Encoder-Decoder модель, оптимизированная для задач на русском языке. Модель была дополнительно дообучена на кастомных данных для улучшения генерации ответов на основе предоставленного контекста.

Параметры обучения

Для обучения использовались следующие параметры:

output_dir="./models",
optim="adafactor",
num_train_epochs=1, # в идеале 2 эпохи
do_train=True,
gradient_checkpointing=True,
bf16=True,
per_device_train_batch_size=8,
per_device_eval_batch_size=12,
gradient_accumulation_steps=4,
logging_dir="./logs",
report_to="wandb",
logging_steps=10,
save_strategy="steps",
save_steps=5000,
evaluation_strategy="steps",
eval_steps=300,
learning_rate=3e-5,
predict_with_generate=False,
generation_max_length=64

К сожалению, мне не хватило вычислительного времени на Google Collab, поэтому модель была обучена только на одной эпохе с ~1416 шагами.

Результаты обучения

Шаг Loss на валидации Sbleu Chr F Rouge1 Rouge2 Rougel
300 1.025008 18.206400 62.316300 0.110400 0.035200 0.109800
600 1.007530 18.523100 62.564700 0.113300 0.036500 0.112800
900 0.959073 18.869000 63.001700 0.115100 0.035600 0.114600
1200 0.944776 18.656300 62.819800 0.115400 0.035800 0.115000