File size: 2,867 Bytes
496cffd f798a31 496cffd 713cf03 496cffd 713cf03 496cffd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
---
license: cc-by-nc-sa-4.0
datasets:
- Blablablab/ALOE
---
### Model Description
The model classifies whether two *appraisals* are aligned. It was trained on the [ALOE](https://huggingface.co/datasets/Blablablab/ALOE) dataset.
- **Input:** two appraisals (see the `forward` method of the `SNN` class below)
- **Output:** cosine similarity between the two appraisal embeddings
- **Model architecture:** Siamese network built on `all-mpnet-base-v2`
- **Developed by:** Jiamin Yang
### Model Performance
| F1 | Recall | Precision |
| :--: | :----: | :-------: |
| 0.46 | 0.45 | 0.46 |
### Getting Started
```python
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
class SNN(nn.Module):
    """Siamese network that scores how well two appraisals align.

    Both inputs are encoded by the same pretrained transformer, mean-pooled
    over non-padding tokens, and compared with cosine similarity.
    """

    def __init__(self, model_name, device=None):
        """Load the shared encoder.

        Args:
            model_name: Model id or local path passed to
                ``AutoModel.from_pretrained``.
            device: Device to place the encoder on. Defaults to ``"cuda"``
                when a GPU is available, otherwise ``"cpu"``, so the class
                can also be constructed on CPU-only machines.
        """
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # The encoder is put in training mode (dropout active); call .eval()
        # on the SNN before running inference.
        self.model = AutoModel.from_pretrained(model_name).to(device).train()
        self.cos = torch.nn.CosineSimilarity(dim=1, eps=1e-4)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average token embeddings over positions where the mask is non-zero.

        Args:
            token_embeddings: per-token embeddings; the mask is broadcast to
                its size, so presumably (batch, seq_len, hidden).
            attention_mask: (batch, seq_len) mask, 1 for real tokens and 0
                for padding.

        Returns:
            The masked mean over the sequence dimension (dim 1).
        """
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * mask, 1)
        # Clamp avoids division by zero for rows whose mask is all zeros.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def forward(self, input_ids_a, attention_a, input_ids_b, attention_b):
        """Return the cosine similarity between the pooled encodings of A and B.

        Args:
            input_ids_a, attention_a: token ids and attention mask for side A.
            input_ids_b, attention_b: token ids and attention mask for side B.

        Returns:
            One cosine similarity per batch row.
        """
        # Element [0] of the encoder output holds all token embeddings.
        encoding1 = self.model(input_ids_a, attention_mask=attention_a)[0]
        encoding2 = self.model(input_ids_b, attention_mask=attention_b)[0]
        mean_pooled1 = self.mean_pooling(encoding1, attention_a)
        mean_pooled2 = self.mean_pooling(encoding2, attention_b)
        return self.cos(mean_pooled1, mean_pooled2)
checkpoint_path = 'your_path_to/empathy-appraisal-alignment.pt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model = SNN('sentence-transformers/all-mpnet-base-v2').to(device)

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint['model_state_dict']

# Whether the encoder registers `embeddings.position_ids` in its state dict
# depends on the installed library version. Drop the saved entry only when the
# freshly built model does not expect it; an unconditional `del` either raises
# KeyError (key absent) or breaks strict loading (key still expected).
position_ids_key = 'model.embeddings.position_ids'
if position_ids_key not in model.state_dict():
    state_dict.pop(position_ids_key, None)
model.load_state_dict(state_dict)
model.eval()

# use the model
target = ["I'm so sad that my cat died yesterday."]
observer = ["It's ok to feel sad."]

target_encodings = tokenizer(target, padding=True, truncation=True, return_tensors='pt').to(device)
observer_encodings = tokenizer(observer, padding=True, truncation=True, return_tensors='pt').to(device)

# No gradients are needed for inference.
with torch.no_grad():
    output = model(
        target_encodings['input_ids'],
        target_encodings['attention_mask'],
        observer_encodings['input_ids'],
        observer_encodings['attention_mask'],
    )
print(output)  # approximately tensor([0.5755])
```
|