---
license: cc-by-nc-sa-4.0
datasets:
- Blablablab/ALOE
---

### Model Description

The model predicts which *appraisal* a sentence expresses. It was trained on the [ALOE](https://huggingface.co/datasets/Blablablab/ALOE) dataset.

**Input:** a sentence

**Labels:** No Label, Pleasantness, Anticipated Effort, Certainty, Objective Experience, Self-Other Agency, Situational Control, Advice, Trope

**Output:** logits, one per label, in the order listed above

**Model architecture:** OpenPrompt + RoBERTa (`roberta-large`)

**Developed by:** Jiamin Yang

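
Because the output is a vector of nine logits in label order, the predicted appraisal is the label at the index of the largest logit. The following is a minimal sketch in plain Python, assuming the logits for one sentence have already been converted to a list of floats. The helper name `decode_logits` and the example values are illustrative and not part of the released code. The softmax is only a convenient way to rescale the scores so they sum to 1, and the results should not be read as calibrated probabilities.

```python
import math

# Label order as listed above; index i of the logits corresponds to LABELS[i].
LABELS = [
    "No Label", "Pleasantness", "Anticipated Effort", "Certainty",
    "Objective Experience", "Self-Other Agency", "Situational Control",
    "Advice", "Trope",
]


def decode_logits(logits):
    """Return (predicted_label, rescaled_scores) for one sentence's logits."""
    if len(logits) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} logits, got {len(logits)}")
    # Numerically stable softmax: subtract the maximum before exponentiating.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    scores = [e / total for e in exps]
    # The prediction is the label whose logit is largest.
    best = max(range(len(LABELS)), key=lambda i: logits[i])
    return LABELS[best], scores


# Made-up logits, for illustration only.
label, scores = decode_logits([0.1, 2.3, -0.5, 0.0, 0.4, -1.2, 0.2, 1.1, -0.3])
print(label)  # prints "Pleasantness" (index 1 holds the largest value)
```
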
### Model Performance

##### Overall performance

| Macro-F1 | Recall | Precision |
| :------: | :----: | :-------: |
|   0.56   |  0.57  |   0.58    |

##### Per-label performance

| Label                | Recall | Precision |
| -------------------- | :----: | :-------: |
| No Label             |  0.34  |   0.64    |
| Pleasantness         |  0.69  |   0.54    |
| Anticipated Effort   |  0.46  |   0.46    |
| Certainty            |  0.58  |   0.47    |
| Objective Experience |  0.58  |   0.69    |
| Self-Other Agency    |  0.62  |   0.55    |
| Situational Control  |  0.31  |   0.55    |
| Advice               |  0.72  |   0.66    |
| Trope                |  0.80  |   0.67    |

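
The per-label table reports recall and precision but not F1. As a rough sketch, a per-label F1 can be derived from those two columns with the usual harmonic-mean formula, and the unweighted means over the nine labels can be compared with the overall figures. These derived values are computed here from the rounded numbers in the table, so they are approximations rather than scores reported by the author.

```python
# (recall, precision) per label, copied from the per-label table above.
PER_LABEL = {
    "No Label": (0.34, 0.64),
    "Pleasantness": (0.69, 0.54),
    "Anticipated Effort": (0.46, 0.46),
    "Certainty": (0.58, 0.47),
    "Objective Experience": (0.58, 0.69),
    "Self-Other Agency": (0.62, 0.55),
    "Situational Control": (0.31, 0.55),
    "Advice": (0.72, 0.66),
    "Trope": (0.80, 0.67),
}


def f1(recall, precision):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if recall + precision == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


per_label_f1 = {name: f1(r, p) for name, (r, p) in PER_LABEL.items()}

# Unweighted (macro) averages over the nine labels.
macro_f1 = sum(per_label_f1.values()) / len(per_label_f1)
macro_recall = sum(r for r, _ in PER_LABEL.values()) / len(PER_LABEL)
macro_precision = sum(p for _, p in PER_LABEL.values()) / len(PER_LABEL)

for name, score in per_label_f1.items():
    print(f"{name:<22}{score:.2f}")
print(f"macro F1={macro_f1:.2f} recall={macro_recall:.2f} precision={macro_precision:.2f}")
```

Rounded to two decimals, the three averages come out at 0.56, 0.57, and 0.58, which agrees with the overall table and suggests that all three overall figures are macro-averaged.
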
### Getting Started

```python
import torch
from openprompt import PromptForClassification
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate, ManualVerbalizer

checkpoint_file = 'your_path_to/empathy-appraisal-span.pt'

# Use a GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Backbone language model and prompt template.
plm, tokenizer, model_config, WrapperClass = load_plm('roberta', 'roberta-large')
template_text = 'The sentence {"placeholder":"text_a"} has the label {"mask"}.'
template = ManualTemplate(tokenizer=tokenizer, text=template_text)

# One label word per class, in the same order as the output logits.
num_classes = 9
label_words = [['No Label'], ['Pleasantness'], ['Anticipated Effort'], ['Certainty'], ['Objective Experience'], ['Self-Other Agency'], ['Situational Control'], ['Advice'], ['Trope']]
verbalizer = ManualVerbalizer(tokenizer, num_classes=num_classes, label_words=label_words)
prompt_model = PromptForClassification(plm=plm, template=template, verbalizer=verbalizer, freeze_plm=False).to(device)

# Load the fine-tuned weights onto the selected device.
checkpoint = torch.load(checkpoint_file, map_location=device)
state_dict = checkpoint['model_state_dict']

# Depending on the installed library versions, the model may no longer expect
# the saved `position_ids` buffer; drop it from the checkpoint only in that case.
position_ids_key = 'prompt_model.plm.roberta.embeddings.position_ids'
if position_ids_key not in prompt_model.state_dict():
    state_dict.pop(position_ids_key, None)

prompt_model.load_state_dict(state_dict)
prompt_model.eval()  # inference mode (disables dropout)
```

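
Once the weights are loaded, sentences can be classified by wrapping them in OpenPrompt `InputExample` objects and batching them with `PromptDataLoader`. The following is a minimal sketch, assuming the `prompt_model`, `template`, `tokenizer`, and `WrapperClass` objects created above. The helper name `classify_sentences`, the `max_seq_length` of 256, and the batch size of 8 are illustrative choices, not settings confirmed by the author.

```python
# Label order from the model description; logits follow this order.
LABELS = [
    "No Label", "Pleasantness", "Anticipated Effort", "Certainty",
    "Objective Experience", "Self-Other Agency", "Situational Control",
    "Advice", "Trope",
]


def classify_sentences(sentences, prompt_model, template, tokenizer,
                       wrapper_class, batch_size=8, max_seq_length=256):
    """Return one predicted label per input sentence (hypothetical helper)."""
    # Imported inside the function so the helper can be defined even where
    # torch and OpenPrompt are not installed.
    import torch
    from openprompt import PromptDataLoader
    from openprompt.data_utils import InputExample

    examples = [InputExample(guid=i, text_a=s) for i, s in enumerate(sentences)]
    loader = PromptDataLoader(
        dataset=examples,
        template=template,
        tokenizer=tokenizer,
        tokenizer_wrapper_class=wrapper_class,
        max_seq_length=max_seq_length,  # assumed value, adjust as needed
        batch_size=batch_size,
        shuffle=False,  # keep predictions aligned with the input order
    )

    # Run on whichever device the model's weights already live on.
    device = next(prompt_model.parameters()).device
    prompt_model.eval()
    predictions = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            logits = prompt_model(batch)  # shape: (batch size, 9)
            predictions.extend(LABELS[i] for i in logits.argmax(dim=-1).tolist())
    return predictions
```

A call would then look like `classify_sentences(["I have no idea how this will turn out."], prompt_model, template, tokenizer, WrapperClass)`, where the sentence is a made-up example.
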