Model Card for Grpp/T5_spell-base
Модель для исправления опечаток на русском языке
Model Details
Model Description
- Language(s) (NLP): ['ru']
- License: mit
- Finetuned from model [optional]: Grpp/rut5-base
Uses
Direct Use
from transformers import pipeline
from transformers import T5ForConditionalGeneration, T5Tokenizer, Text2TextGenerationPipeline
# Создание pipeline для генерации текста
PIPELINE = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer, device=0)
def answer_m(list_texts):
texts = []
for txt in tqdm(list_texts):
texts.append(
PIPELINE(
txt,
max_length=256,
repetition_penalty=1.5,
temperature=0.7,
top_k=50,
num_return_sequences=1
)[0]['generated_text'])
return texts
text = 'нападавше иты кроме того при наадении на отдел уиполицииранение получилаи женщина из гражчданских сообщилон анронимныйистточни агентста тасс со ссылкой на источник пишет что у одногао из преступников быиевзрычатычгтка полицейские потребовали чтобнападавшие останвеились после чего те дотали ножи'
prefix = 'Исправь: '
text_to_model = prefix + text
answer_m([text_to_model])
# ['Нападавшие иты Кроме того, при нападении на отдел полиции ранение получила женщина из гражданских сообщил один аналогичный источник. Агентство ТАСС со ссылкой на источник пишет, что у одного из преступников были взрывчатка: полицейские потребовали, чтобы напавшие остановились после чего те достали ножы.']
Training Details
Training Procedure
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW
from tqdm.auto import tqdm
raw_model = 'Grpp/T5_spell-base' # предобученная модель
DATASET = "Grpp/t5-russian-spell_I" # Введите наазвание название датасета
model = T5ForConditionalGeneration.from_pretrained(raw_model).cuda();
tokenizer = T5Tokenizer.from_pretrained(raw_model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# Загрузка датасета
new_dataset = load_dataset(DATASET)
model.to('cuda')
_ = model.config
batch_size = 8 # сколько примеров показываем модели за один шаг
report_steps = 1000 # раз в сколько шагов печатаем результат
epochs = 1 # сколько раз мы покажем данные модели
class TextDataset(Dataset):
def __init__(self, tokenizer, pairs):
self.tokenizer = tokenizer
self.pairs = pairs
def __len__(self):
return len(self.pairs)
def __getitem__(self, idx):
question = self.pairs[idx]['input_text'].replace('Spell correct: ', 'Исправь: ')
answer = self.pairs[idx]['label_text']
source = self.tokenizer(question, padding='max_length', truncation=True, max_length=256, return_tensors='pt')
target = self.tokenizer(answer, padding='max_length', truncation=True, max_length=256, return_tensors='pt')
target.input_ids[target.input_ids == 0] = -100
return source, target
def train_epoch(model, dataloader, optimizer):
model.train()
losses = []
for i, (x, y) in enumerate(tqdm(dataloader)):
optimizer.zero_grad()
outputs = model(
input_ids=x['input_ids'].squeeze().to(model.device),
attention_mask=x['attention_mask'].squeeze().to(model.device),
labels=y['input_ids'].squeeze().to(model.device),
decoder_attention_mask=y['attention_mask'].squeeze().to(model.device),
)
loss = outputs.loss
loss.backward()
optimizer.step()
losses.append(loss.item())
if i % report_steps == 0:
print('step', i, 'loss', np.mean(losses[-report_steps:]))
return np.mean(losses)
# Создаем датасет и даталоадер
dataset = TextDataset(tokenizer, new_dataset['train'])
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Оптимизатор
optimizer = AdamW(model.parameters(), lr=5e-5)
model_name_t5 = 'T5_spell-base'
# Обучение модели
for epoch in range(epochs):
print('EPOCH', epoch + 1)
epoch_loss = train_epoch(model, dataloader, optimizer)
print(f'Epoch {epoch + 1} Loss: {epoch_loss}')
# Сохранение модели после каждой эпохи
print('saving')
model.save_pretrained(model_name_t5)
tokenizer.save_pretrained(model_name_t5)
- Downloads last month
- 10
Inference Providers
NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API:
The model has no library tag.
Model tree for Grpp/T5_spell-base
Base model
Grpp/rut5-base