Spaces:
Build error
Build error
#TextAugmentation.py | |
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, MarianMTModel, MarianTokenizer | |
import torch | |
import os | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
class TextAugmentation: | |
def __init__(self, | |
paraphrase_model_name="cointegrated/rut5-base-paraphraser", | |
ru_en_model_name="Helsinki-NLP/opus-mt-ru-en", | |
en_ru_model_name="Helsinki-NLP/opus-mt-en-ru"): | |
# Инициализация модели для перефразирования | |
self.paraphrase_tokenizer = T5Tokenizer.from_pretrained(paraphrase_model_name, legacy=False) | |
self.paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(paraphrase_model_name) | |
# Инициализация моделей для обратного перевода | |
self.ru_en_tokenizer = MarianTokenizer.from_pretrained(ru_en_model_name) | |
self.ru_en_model = MarianMTModel.from_pretrained(ru_en_model_name) | |
self.en_ru_tokenizer = MarianTokenizer.from_pretrained(en_ru_model_name) | |
self.en_ru_model = MarianMTModel.from_pretrained(en_ru_model_name) | |
def paraphrase(self, text, num_return_sequences=1): | |
""" | |
Перефразирование текста с использованием модели. | |
Args: | |
text (str): Исходный текст для перефразирования. | |
num_return_sequences (int): Количество вариантов перефразирования. | |
Returns: | |
list[str]: Список вариантов перефразирования текста. | |
""" | |
inputs = self.paraphrase_tokenizer([text], max_length=512, truncation=True, return_tensors="pt") | |
outputs = self.paraphrase_model.generate( | |
**inputs, | |
max_length=128, | |
num_return_sequences=num_return_sequences, | |
do_sample=True, | |
temperature=1.2, | |
top_k=50, | |
top_p=0.90 | |
) | |
return [self.paraphrase_tokenizer.decode(output, skip_special_tokens=True) for output in outputs] | |
def back_translate(self, text): | |
""" | |
Выполняет обратный перевод текста: русский -> английский -> русский. | |
Args: | |
text (str): Исходный текст для обратного перевода. | |
Returns: | |
str: Текст после обратного перевода. | |
""" | |
# Перевод с русского на английский | |
inputs = self.ru_en_tokenizer(text, return_tensors="pt", truncation=True, max_length=512) | |
with torch.no_grad(): | |
outputs = self.ru_en_model.generate(**inputs) | |
translated_text = self.ru_en_tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# Перевод с английского обратно на русский | |
inputs = self.en_ru_tokenizer(translated_text, return_tensors="pt", truncation=True, max_length=512) | |
with torch.no_grad(): | |
outputs = self.en_ru_model.generate(**inputs) | |
back_translated_text = self.en_ru_tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return back_translated_text | |