Overview

A cross-encoder for the Russian language, trained primarily for RAG purposes. It takes two strings and assesses whether they are related (a question-answer pair).

Usage

import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

!wget https://huggingface.co/GrigoryT22/cross-encoder-ru/resolve/main/model.pt  # or simply download the file via a browser

model = Model()  # copy-paste the Model class code (see below) and run it first
model.load_state_dict(torch.load('./model.pt', map_location=device), strict=False)  # path to the downloaded weights
# missing_keys=['labse.embeddings.position_ids'] - this is OK, see https://github.com/huggingface/transformers/issues/16353
model.to(device)
model.eval()
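
Alternatively, the weights can be pulled through huggingface_hub instead of wget (a minimal sketch, assuming the huggingface_hub package is installed):

from huggingface_hub import hf_hub_download

# downloads model.pt into the local HF cache and returns the file path
weights_path = hf_hub_download(repo_id='GrigoryT22/cross-encoder-ru', filename='model.pt')
model.load_state_dict(torch.load(weights_path, map_location=device), strict=False)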

string_1 = """
Компания судится с артистом
""".strip()

string_2 = """
По заявлению инвесторов, компания знала о рисках заключения подобного контракта задолго до антисемитских высказываний Уэста, 
которые он озвучил в октябре 2022 года. Однако, несмотря на то, что Adidas прекратил сотрудничество с артистом, 
избежать судебного разбирательства не удалось. После расторжения контракта с рэпером компания потеряет 1,3 миллиарда долларов.
""".strip()

model([[string_1, string_2]])
# expected output --->>> tensor([[-4.0403,  3.8442]], grad_fn=<AddmmBackward0>)
# the model is fairly confident these two strings are related: the second number is bigger
# (logits of a binary classification; batch size is one in this case)
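
To turn the raw logits into a probability that the pair is related, apply a softmax over the two classes (a minimal sketch; it assumes index 1 is the "related" class, as the example output above suggests):

import torch.nn.functional as F

with torch.no_grad():
    logits = model([[string_1, string_2]])
probs = F.softmax(logits, dim=-1)  # [[p_unrelated, p_related]]
print(probs[0, 1].item())          # probability that the pair is related, close to 1.0 here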

Model class

from collections import OrderedDict

import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer


class Model(nn.Module):
    """
    LaBSE is the base BERT-like encoder; its pooler output feeds
    a classification head that performs binary classification:
    predict whether the pair is a TRUE question-answer pair.
    """
    def __init__(self):
        super().__init__()
        self.labse_config = AutoConfig.from_pretrained('cointegrated/LaBSE-en-ru')
        self.labse = AutoModel.from_config(self.labse_config)
        self.tokenizer = AutoTokenizer.from_pretrained('cointegrated/LaBSE-en-ru')
        self.cls = nn.Sequential(OrderedDict([
            ('dropout_in', nn.Dropout(0.0)),
            ('layernorm_in', nn.LayerNorm(768, eps=1e-05)),

            ('fc_1', nn.Linear(768, 768 * 2)),
            ('act_1', nn.GELU()),
            ('layernorm_1', nn.LayerNorm(768 * 2, eps=1e-05)),

            ('fc_2', nn.Linear(768 * 2, 768 * 2)),
            ('act_2', nn.GELU()),
            ('layernorm_2', nn.LayerNorm(768 * 2, eps=1e-05)),

            ('fc_3', nn.Linear(768 * 2, 768)),
            ('act_3', nn.GELU()),
            ('layernorm_3', nn.LayerNorm(768, eps=1e-05)),

            ('fc_4', nn.Linear(768, 256)),
            ('act_4', nn.GELU()),
            ('layernorm_4', nn.LayerNorm(256, eps=1e-05)),

            ('fc_5', nn.Linear(256, 2, bias=True)),
        ]))
    def forward(self, text):
        # `device` is the global set in the Usage section above
        token = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt').to(device)
        model_output = self.labse(**token)
        return self.cls(model_output.pooler_output)  # logits of shape (batch, 2)
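
Since the model is aimed at RAG, a common pattern is reranking retrieved passages against a query. A minimal sketch built on the class above (the rerank helper and its names are illustrative, not part of the original code):

import torch.nn.functional as F

def rerank(query, passages):
    # score all (query, passage) pairs in one batch, then sort by relatedness
    with torch.no_grad():
        logits = model([[query, p] for p in passages])
    scores = F.softmax(logits, dim=-1)[:, 1]  # probability of the "related" class
    order = torch.argsort(scores, descending=True).tolist()
    return [(passages[i], scores[i].item()) for i in order]

# best_passage, best_score = rerank("Компания судится с артистом", candidate_passages)[0]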