# coding=utf-8
# Copyright 2020-present the AI Algorithm Research Team.
# http://agilesoda.ai
# contact@agilesoda.ai

# Model
import torch
import torch.nn.functional as F

from transformers import AutoModel, MistralModel

from .config import MistralForEmbeddingConfig


class MistralForEmbeddingModel(MistralModel):
    config_class = MistralForEmbeddingConfig

    def forward(self, *args, **kwargs):
        outputs = super().forward(*args, **kwargs)
        last_hidden_states = outputs.last_hidden_state

        # attention_mask must be passed as a keyword argument.
        attention_mask = kwargs.get("attention_mask")

        # If every sequence ends with a real (non-padding) token, the batch is
        # left-padded and position -1 holds the last token for every row.
        left_padding = torch.equal(
            attention_mask[:, -1],
            torch.ones(attention_mask.shape[0], dtype=torch.int64, device=attention_mask.device),
        )
        if left_padding:
            # -1 is the last token
            output_embeddings = last_hidden_states[:, -1]
        else:
            # Right padding: find the last real token of each sequence.
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            output_embeddings = last_hidden_states[
                torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths
            ]

        # Normalization and scoring should be performed outside of this model
        # because of batching:
        # output_embeddings = F.normalize(output_embeddings, p=2, dim=1)
        # scores = (output_embeddings[:2] @ output_embeddings[2:].T) * 100
        # scores = scores.tolist()
        return output_embeddings


AutoModel.register(MistralForEmbeddingConfig, MistralForEmbeddingModel)
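

# Minimal usage sketch (illustrative, not part of the original module). The
# checkpoint path below is a placeholder and is assumed to ship a
# MistralForEmbeddingConfig, so that AutoModel resolves to
# MistralForEmbeddingModel via the registration above. Normalization and
# scoring are done outside the model, as noted in forward().
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/mistral-embedding-checkpoint"  # placeholder path (assumption)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint)
    model.eval()

    texts = [
        "query: what is the capital of France?",
        "passage: Paris is the capital of France.",
    ]
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        embeddings = model(**batch)  # (batch_size, hidden_size)

    # Normalize and score outside the model, mirroring the commented example
    # in forward().
    embeddings = F.normalize(embeddings, p=2, dim=1)
    scores = (embeddings[:1] @ embeddings[1:].T) * 100
    print(scores.tolist())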