---
license: cc-by-nc-sa-4.0
datasets:
- Blablablab/ALOE
---

### Model Description

The model classifies whether two *appraisals* are aligned and is trained on the [ALOE](https://huggingface.co/datasets/Blablablab/ALOE) dataset.

**Input:** two appraisals (see the `forward` function of the `SNN` class)

**Output:** cosine similarity between the embeddings of the two appraisals (higher values indicate closer alignment)

**Model architecture:** Siamese network with a shared `all-mpnet-base-v2` encoder

**Developed by:** Jiamin Yang

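The Siamese setup means a single shared encoder embeds both appraisals, and the score is simply the cosine between the two mean-pooled sentence embeddings. A minimal sketch of that core computation on dummy tensors (`all-mpnet-base-v2` produces 768-dimensional embeddings; the random values are placeholders for the pooled encoder outputs):

```python
import torch

# Stand-ins for the mean-pooled encoder outputs of the two appraisals
# (batch_size=1, hidden_size=768 for all-mpnet-base-v2)
emb_a = torch.randn(1, 768)
emb_b = torch.randn(1, 768)

cos = torch.nn.CosineSimilarity(dim=1, eps=1e-4)
print(cos(emb_a, emb_b))  # one score per pair, in [-1, 1]
```
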
### Model Performance

| F1 | Recall | Precision |
| :--: | :----: | :-------: |
| 0.46 | 0.45 | 0.46 |

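Note that the model outputs a continuous cosine similarity, so producing aligned / not-aligned labels for scores like those above requires a decision threshold. The cutoff used for the reported numbers is not documented here; the sketch below assumes a hypothetical threshold of 0.5 purely for illustration:

```python
# Hypothetical post-processing: binarize the cosine similarity.
# THRESHOLD = 0.5 is an assumption for illustration, not the value
# used to compute the scores reported above.
THRESHOLD = 0.5

def is_aligned(similarity: float, threshold: float = THRESHOLD) -> bool:
    """Map a cosine-similarity score to an aligned / not-aligned label."""
    return similarity >= threshold

print(is_aligned(0.5755))  # True under this assumed cutoff
```
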
### Getting Started

```python
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer


class SNN(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name).to('cuda')
        self.cos = torch.nn.CosineSimilarity(dim=1, eps=1e-4)

    def mean_pooling(self, token_embeddings, attention_mask):
        # Average the token embeddings, ignoring padded positions
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def forward(self, input_ids_a, attention_a, input_ids_b, attention_b):
        # Encode each appraisal with the shared encoder and mean-pool the token embeddings
        encoding1 = self.model(input_ids_a, attention_mask=attention_a)[0]  # all token embeddings
        encoding2 = self.model(input_ids_b, attention_mask=attention_b)[0]

        meanPooled1 = self.mean_pooling(encoding1, attention_a)
        meanPooled2 = self.mean_pooling(encoding2, attention_b)

        # Cosine similarity between the two pooled sentence representations
        pred = self.cos(meanPooled1, meanPooled2)
        return pred


checkpoint_path = 'your_path_to/empathy-appraisal-alignment.pt'

tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model = SNN('sentence-transformers/all-mpnet-base-v2').to('cuda')
checkpoint = torch.load(checkpoint_path, map_location='cuda')
state_dict = checkpoint['model_state_dict']

# Depending on the transformers version, `position_ids` may be registered as a
# buffer; drop the key if present so the checkpoint loads either way
state_dict.pop('model.embeddings.position_ids', None)

model.load_state_dict(state_dict)

# Use the model
target = ["I'm so sad that my cat died yesterday."]
observer = ["It's ok to feel sad."]

target_encodings = tokenizer(target, padding=True, truncation=True)
target_input_ids = torch.LongTensor(target_encodings['input_ids']).to('cuda')
target_attention_mask = torch.LongTensor(target_encodings['attention_mask']).to('cuda')
observer_encodings = tokenizer(observer, padding=True, truncation=True)
observer_input_ids = torch.LongTensor(observer_encodings['input_ids']).to('cuda')
observer_attention_mask = torch.LongTensor(observer_encodings['attention_mask']).to('cuda')

model.eval()
with torch.no_grad():
    output = model(target_input_ids, target_attention_mask, observer_input_ids, observer_attention_mask)
print(output)  # [0.5755]
```

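To score several appraisal pairs at once, the steps above can be wrapped in a small batched helper. `score_alignment` below is a hypothetical convenience function (not part of the released code) that reuses the `tokenizer` and `model` objects created above:

```python
def score_alignment(targets, observers):
    """Hypothetical helper: cosine-similarity scores for equal-length
    lists of target and observer appraisals."""
    enc_a = tokenizer(targets, padding=True, truncation=True, return_tensors='pt').to('cuda')
    enc_b = tokenizer(observers, padding=True, truncation=True, return_tensors='pt').to('cuda')
    model.eval()
    with torch.no_grad():
        scores = model(enc_a['input_ids'], enc_a['attention_mask'],
                       enc_b['input_ids'], enc_b['attention_mask'])
    return scores.tolist()

print(score_alignment(["I'm so sad that my cat died yesterday."],
                      ["It's ok to feel sad."]))
```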