Edit model card

language:

  • ru

Model Card for DeepPavlov/mbart-large-50-ru-persona-chat

Model Details

Model Description

  • Developed by: DeepPavlov team
  • Model type: text generation
  • Language(s) (NLP): Russian
  • License: OpenRAIL
  • Finetuned from model: facebook/mbart-large-50

Uses

Direct Use

from typing import List, TypedDict
from dataclasses import dataclass
from itertools import chain

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


@dataclass
class H2PersonaChatHyperparametersV1:
    """
    Settings that control how a persona-chat sample is built for the model.

    chat_history_pair_length: int - number of dialog pairs kept, counted from
        the end of the history
    """

    # Model identifier; the sample builder checks it for the substring "t5"
    # when deciding whether to prepend a BOS token.
    model_name: str = "facebook/bart-base"
    # The sample builder keeps the last `chat_history_pair_length * 2 + 1`
    # history utterances.
    chat_history_pair_length: int = 7

    # Token limit applied to each persona sentence when it is tokenized.
    persona_max_length: int = 14
    # Token limit applied to each history utterance when it is tokenized.
    chat_max_length: int = 25

    # Copied onto DialogBotV1 as `debug_status`; no code visible in this file
    # reads it afterwards.
    debug_status: int = 0


class PersonaChatDatasetSampleV1(TypedDict):
    """
    One persona-chat sample as consumed by the sample builder.

    persona: List[str] - sentences stating facts about the persona
    history: List[str] - sentences of the conversation history
    """

    persona: List[str]
    history: List[str]
    # NOTE(review): presumably an identifier of the sample in the dataset;
    # the inference code in this file never sets or reads it.
    sample_id: str


class H2Seq2SeqInferenceSampleDictV1(TypedDict):
    """Model input for one sample, held as plain Python lists of ints."""

    # Token ids of the assembled input sequence.
    input_ids: List[int]
    # One mask entry per token id.
    attention_mask: List[int]


class H2Seq2SeqInferenceSampleDictV2(TypedDict):
    """Model input for one sample, held as torch tensors."""

    # Token ids of the assembled input sequence.
    input_ids: torch.Tensor
    # Attention mask matching `input_ids`.
    attention_mask: torch.Tensor


def flat_list(list_of_lists: List[List]) -> List:
    """Concatenate the inner lists, in order, into one new flat list."""
    flattened: List = []
    for inner in list_of_lists:
        flattened.extend(inner)
    return flattened


class H2Seq2SeqInferencePersonaSampleV1:
    """
    Builds the encoder input for one persona-chat inference sample.

    The token sequence returned by `get_sample` is laid out as::

        [BOS] " [CONTEXT]" <history> " [KNOWLEDGE] " <persona> [EOS]

    where BOS/EOS are only present when the tokenizer defines them.
    """

    def __init__(
        self,
        dataset_sample: PersonaChatDatasetSampleV1,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
    ) -> None:
        """Store the raw sample, the tokenizer and the build settings."""
        self.dataset_sample = dataset_sample
        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters

    def add_spaces_after(
        self,
        items: List[str],
    ) -> List[str]:
        """Return a new list with a single space appended to every item."""
        return [item + " " for item in items]

    @property
    def bos_token_id(self):
        """
        BOS token id wrapped in a list so it can be unpacked into a sequence.

        The list is empty when the model name contains "t5" or when the
        tokenizer has no BOS token.
        """
        if "t5" in self.hyperparameters.model_name:
            return []

        if self.tokenizer.bos_token_id is None:
            return []

        return [self.tokenizer.bos_token_id]

    @property
    def eos_token_id(self):
        """
        EOS token id wrapped in a list so it can be unpacked into a sequence.

        The list is empty when the tokenizer has no EOS token.
        """
        if self.tokenizer.eos_token_id is None:
            return []

        return [self.tokenizer.eos_token_id]

    def add_sep_beetween(self, items: List[str], sep=" EOS ") -> List[str]:
        """
        Return a new list where every item except the first is prefixed
        with `sep`.

        The input list is left untouched. Mutating it in place would add
        another separator to the caller's persona sentences each time a
        sample is built from the same list.
        """
        return [
            item if index == 0 else sep + item
            for index, item in enumerate(items)
        ]

    def add_spaces_between(self, items: List[str]) -> List[str]:
        """
        Return a new list with a trailing space on every item; the last item
        is stripped of surrounding whitespace instead.

        An empty input yields an empty list.
        """
        if not items:
            return []

        spaced = self.add_spaces_after(items)
        spaced[-1] = spaced[-1].strip()
        return spaced

    def _encode_sentences(
        self,
        sentences: List[str],
        max_length: int,
    ) -> List[int]:
        """
        Tokenize each sentence separately (truncated to `max_length` tokens,
        no special tokens) and return all ids as one flat list.
        """
        # Nothing to encode: skip the tokenizer call for an empty batch.
        if not sentences:
            return []

        encoded = self.tokenizer.batch_encode_plus(
            sentences,
            add_special_tokens=False,
            truncation=True,
            max_length=max_length,
        )
        return flat_list(encoded["input_ids"])

    def get_sample(self) -> H2Seq2SeqInferenceSampleDictV1:
        """
        Assemble `input_ids` and an all-ones `attention_mask` for the sample.

        Only the last `chat_history_pair_length * 2 + 1` history utterances
        are used. Utterances are joined with " EOS ", persona sentences with
        a single space.
        """
        pair_length = self.hyperparameters.chat_history_pair_length
        dialog_history = self.dataset_sample["history"][-pair_length * 2 - 1:]
        dialog_history = self.add_sep_beetween(dialog_history)

        persona = self.add_sep_beetween(
            self.dataset_sample["persona"],
            sep=" ",
        )

        # Section markers are encoded as plain text, without special tokens.
        knowledge_ids = self.tokenizer.encode(
            " [KNOWLEDGE] ",
            add_special_tokens=False,
        )
        context_ids = self.tokenizer.encode(
            " [CONTEXT]",
            add_special_tokens=False,
        )

        encoded_history = self._encode_sentences(
            dialog_history,
            max_length=self.hyperparameters.chat_max_length,
        )
        encoded_persona = self._encode_sentences(
            persona,
            max_length=self.hyperparameters.persona_max_length,
        )

        input_ids = [
            *self.bos_token_id,
            *context_ids,
            *encoded_history,
            *knowledge_ids,
            *encoded_persona,
            *self.eos_token_id,
        ]

        # A single unpadded sequence: every position is attended to.
        attention_mask = [1] * len(input_ids)

        return H2Seq2SeqInferenceSampleDictV1(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )


class DialogBotV1:
    """
    Stateful chat wrapper around a seq2seq model.

    Holds the persona facts and the running dialog history, turns them into
    model input and appends each generated reply to the history.
    """

    def __init__(
        self,
        model: AutoModelForSeq2SeqLM,
        tokenizer: AutoTokenizer,
        hyperparameters: H2PersonaChatHyperparametersV1,
        history: List[str] = None,
        persona: List[str] = None,
        device: str = "cuda",
        shuffle_persona: bool = True,
    ):
        """
        model, tokenizer - generation model and its tokenizer
        hyperparameters - sample-building settings; `debug_status` is copied
            onto the bot
        history - dialog utterances so far; the same list object is kept and
            grows as replies are generated. Defaults to a new empty list.
        persona - persona fact sentences. Defaults to a new empty list.
        device - where input tensors are moved before generation
        shuffle_persona - stored on the instance; not read anywhere in this
            class
        """
        self.model = model

        self.tokenizer = tokenizer
        self.hyperparameters = hyperparameters
        self.device = device
        self.shuffle_persona = shuffle_persona

        self.debug_status = hyperparameters.debug_status

        # A None argument must fall back to an empty list; assigning the
        # argument unconditionally would leave None here and break
        # `next_response`.
        self.history = [] if history is None else history
        self.persona = [] if persona is None else persona

    def _get_sample(
        self,
        persona: List[str],
        history: List[str],
    ) -> H2Seq2SeqInferenceSampleDictV2:
        """
        Build the model input for the given persona and history.

        Returns a dict of tensors with a leading batch dimension of 1,
        already moved to `self.device`.
        """
        dataset_sample = PersonaChatDatasetSampleV1(
            persona=persona,
            history=history,
        )

        builder = H2Seq2SeqInferencePersonaSampleV1(
            tokenizer=self.tokenizer,
            hyperparameters=self.hyperparameters,
            dataset_sample=dataset_sample,
        )
        sample = builder.get_sample()
        # Show the exact text the model will see.
        print(self.tokenizer.decode(sample['input_ids']))

        # unsqueeze(0) turns each flat id list into a batch of one sequence.
        return {
            key: torch.tensor(value).unsqueeze(0).to(self.device)
            for key, value in sample.items()
        }

    def next_response(
        self,
        **generation_params,
    ) -> str:
        """
        Generate a reply from the current history and persona.

        The decoded reply is appended to `self.history` and returned.
        `generation_params` are forwarded to `model.generate`.
        """

        sample = self._get_sample(
            persona=self.persona,
            history=self.history,
        )
        answer = self.generate_response(
            sample,
            **generation_params,
        )
        answer = self.tokenizer.batch_decode(
            answer,
            skip_special_tokens=True,
        )
        self.history.append(answer[0])
        return answer[0]

    def generate_response(
        self,
        sample: H2Seq2SeqInferenceSampleDictV2,
        **generation_params,
    ):
        """
        Run `model.generate` on the sample with gradients disabled and return
        its raw output.

        generation_params - https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/text_generation
        """
        with torch.no_grad():
            return self.model.generate(
                **sample,
                **generation_params,
            )


# Checkpoint to load; it was fine-tuned from facebook/mbart-large-50.
PRETRAINED_MODEL_NAME_OR_PATH = "DeepPavlov/mbart-large-50-ru-persona-chat"

# Number of dialog pairs kept from the end of the history.
PAIR_DIALOG_HISTORY_LENGTH = 2

# CHAT_MAX_LENGTH for single sentence
CHAT_MAX_LENGTH = 25
# PERSONA_MAX_LENGTH for single sentence
PERSONA_MAX_LENGTH = 19

# Use the GPU when one is available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSeq2SeqLM.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)
model.to(device)
# Inference only: switch off training-mode behaviour.
model.eval()

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME_OR_PATH)

# Half precision is enabled only on GPU.
if torch.cuda.is_available():
    model.half()

hyperparameters = H2PersonaChatHyperparametersV1(
    chat_history_pair_length=PAIR_DIALOG_HISTORY_LENGTH,
    persona_max_length=PERSONA_MAX_LENGTH,
    chat_max_length=CHAT_MAX_LENGTH,
    model_name=PRETRAINED_MODEL_NAME_OR_PATH,
)


# Persona facts (Russian): "I love playing with cute doggies",
# "I hate onions and broccoli".
persona = [
     "Я люблю играть с милыми песиками",
    "Я ненавижу лук и броколли"
]

# Opening user message (Russian): "Hi. Do you like onions?"
history = [
    "Привет. Ты любишь лук?"
]
            
# NOTE: `device` here is a torch.device, although the parameter is annotated
# as str; it is only ever passed to Tensor.to().
persona_bot = DialogBotV1(
        model=model,
        tokenizer=tokenizer,
        hyperparameters=hyperparameters,
        history=history,
        persona=persona,
        device=device,
    )

# Forwarded unchanged to model.generate().
# NOTE(review): penalty_alpha together with top_k looks like the contrastive
# search setup from the transformers generation docs - confirm.
GENERATION_PARAMS = {
    "max_new_tokens": 60,
    "penalty_alpha": 0.15,
    "top_k": 10
}
# Generates one reply; it is also appended to `history` by the bot.
response = persona_bot.next_response(
    **GENERATION_PARAMS,
)

print(response)

Recommendations

Training Details

Training Data

[More Information Needed]

Preprocessing

  • The initial data was split with this script:
def ru_persona_chat_dataset_tranformer_v1(
    initial_dataset_path: str,
    output_folder: str,
) -> None:
    """
    Split a tab-separated dialogue file into train and validation CSV files.

    The first 95% of the rows (in file order, no shuffling) are written to
    `<output_folder>/train.csv` and the remaining rows to
    `<output_folder>/valid.csv`. The output folder is created if missing.

    initial_dataset_path - path to the source .tsv file
    output_folder - directory that receives train.csv and valid.csv

    example
        ru_persona_chat_dataset_tranformer_v1(
            initial_dataset_path="./datasets/ru_persona_chat/dialogues.tsv",
            output_folder="./datasets/ru_persona_chat",
        )
    """
    # Imported locally so the snippet runs on its own.
    import os

    import pandas as pd

    assert initial_dataset_path is not None, "initial_dataset_path is None"
    assert output_folder is not None, "output_folder is None"

    dataset = pd.read_csv(initial_dataset_path, sep="\t")
    # Index of the first validation row.
    split_index = int(len(dataset) * 0.95)
    train_dataset = dataset[:split_index]
    valid_dataset = dataset[split_index:]

    print(f"Dataset lengths: train {len(train_dataset)}, valid {len(valid_dataset)}")

    # to_csv does not create missing directories, so make sure the target
    # folder exists before writing.
    os.makedirs(output_folder, exist_ok=True)
    train_dataset.to_csv(os.path.join(output_folder, "train.csv"), index=False)
    valid_dataset.to_csv(os.path.join(output_folder, "valid.csv"), index=False)
    print("Datasets saved.")

Evaluation

Metrics

  • BLEU
  • chrF
  • ROUGE-L
Downloads last month
45