---
license: cc-by-nc-sa-4.0
datasets:
- Blablablab/ALOE
---

### Model Description

The model classifies whether two *appraisals* are aligned or not, and is trained on the [ALOE](https://huggingface.co/datasets/Blablablab/ALOE) dataset.

**Input:** two appraisals (see the `forward` method of the `SNN` class)

**Output:** the cosine similarity between the mean-pooled embeddings of the two appraisals

**Model architecture:** Siamese network + `all-mpnet-base-v2`

**Developed by:** Jiamin Yang

### Model Performance

| F1 | Recall | Precision |
| :--: | :----: | :-------: |
| 0.46 | 0.45 | 0.46 |

### Getting Started

```python
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer


class SNN(nn.Module):
    """Siamese network: one shared encoder embeds both appraisals, and the
    output is the cosine similarity of their mean-pooled embeddings."""

    def __init__(self, model_name):
        super().__init__()
        # Device placement is left to the caller (e.g. `SNN(name).to(device)`),
        # so the class also works on CPU-only machines.
        # from_pretrained returns the encoder in eval mode; .train() matches the
        # default training mode of the freshly constructed SNN module.
        self.model = AutoModel.from_pretrained(model_name).train()
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-4)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average token embeddings over the positions where attention_mask is 1.

        token_embeddings is expected to be (batch, seq_len, hidden) and
        attention_mask (batch, seq_len); returns (batch, hidden).
        """
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # clamp avoids division by zero for an all-padding row
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def forward(self, input_ids_a, attention_a, input_ids_b, attention_b):
        """Return the cosine similarity between the two batches of appraisals."""
        # [0] selects the per-token embeddings from the encoder output
        encoding_a = self.model(input_ids_a, attention_mask=attention_a)[0]
        encoding_b = self.model(input_ids_b, attention_mask=attention_b)[0]
        pooled_a = self.mean_pooling(encoding_a, attention_a)
        pooled_b = self.mean_pooling(encoding_b, attention_b)
        return self.cos(pooled_a, pooled_b)


device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "sentence-transformers/all-mpnet-base-v2"
checkpoint_path = "your_path_to/empathy-appraisal-alignment.pt"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = SNN(model_name).to(device)

# NOTE: torch.load unpickles the file -- only load checkpoints from a source you trust.
# map_location allows a checkpoint saved on GPU to be loaded on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint["model_state_dict"]
# Whether `position_ids` is saved in the state dict depends on the installed
# transformers version. Drop it only when the current model does not expect it,
# so strict loading works with both older and newer versions.
position_ids_key = "model.embeddings.position_ids"
if position_ids_key not in model.state_dict():
    state_dict.pop(position_ids_key, None)
model.load_state_dict(state_dict)
model.eval()


def encode(texts):
    """Tokenize a list of strings and return (input_ids, attention_mask) on `device`."""
    encodings = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    return encodings["input_ids"].to(device), encodings["attention_mask"].to(device)


# use the model
target = ["I'm so sad that my cat died yesterday."]
observer = ["It's ok to feel sad."]

target_input_ids, target_attention_mask = encode(target)
observer_input_ids, observer_attention_mask = encode(observer)

# no_grad: inference only, so skip building the autograd graph
with torch.no_grad():
    output = model(target_input_ids, target_attention_mask,
                   observer_input_ids, observer_attention_mask)
print(output.tolist())  # approx. [0.5755]
```