jessicayjm commited on
Commit
6334b7a
1 Parent(s): 1b7fa3a

Add model usage code

Browse files
Files changed (1) hide show
  1. README.md +35 -2
README.md CHANGED
@@ -14,7 +14,7 @@ The model classifies an *appraisal* given a sentence and is trained on [ALOE](ht
14
 
15
  **Output:** logits (in order of labels)
16
 
17
- **Model architecture**: OpenPrompt+RoBERTa
18
 
19
  **Developed by:** Jiamin Yang
20
 
@@ -48,8 +48,12 @@ from openprompt.plms import load_plm
48
  from openprompt.prompts import ManualTemplate
49
  from openprompt.prompts import ManualVerbalizer
50
  from openprompt import PromptForClassification
 
 
51
 
52
- checkpoint_file = 'your_path_to/empathy-appraisal-span.pt'
 
 
53
 
54
  plm, tokenizer, model_config, WrapperClass = load_plm('roberta', 'roberta-large')
55
  template_text = 'The sentence {"placeholder":"text_a"} has the label {"mask"}.'
@@ -67,5 +71,34 @@ state_dict = checkpoint['model_state_dict']
67
  del state_dict['prompt_model.plm.roberta.embeddings.position_ids']
68
 
69
  prompt_model.load_state_dict(state_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  ```
71
 
 
14
 
15
  **Output:** logits (in order of labels)
16
 
17
+ **Model architecture**: OpenPrompt+RoBERTa
18
 
19
  **Developed by:** Jiamin Yang
20
 
 
48
  from openprompt.prompts import ManualTemplate
49
  from openprompt.prompts import ManualVerbalizer
50
  from openprompt import PromptForClassification
51
+ from openprompt.data_utils import InputExample
52
+ from openprompt import PromptDataLoader
53
 
54
+
55
+ torch.cuda.set_device(0)  # set to the index of an available GPU on your machine
56
+ checkpoint_file = 'your_path_to/empathy-appraisal-span.pt'
57
 
58
  plm, tokenizer, model_config, WrapperClass = load_plm('roberta', 'roberta-large')
59
  template_text = 'The sentence {"placeholder":"text_a"} has the label {"mask"}.'
 
71
  del state_dict['prompt_model.plm.roberta.embeddings.position_ids']
72
 
73
  prompt_model.load_state_dict(state_dict)
74
+
75
+ # use the model
76
+ dataset = [
77
+ InputExample(
78
+ guid = 0,
79
+ text_a = "I am sorry for your loss",
80
+ ),
81
+ InputExample(
82
+ guid = 1,
83
+ text_a = "It's not your fault",
84
+ ),
85
+ ]
86
+
87
+ data_loader = PromptDataLoader(dataset=dataset,
88
+ template=template,
89
+ tokenizer=tokenizer,
90
+ tokenizer_wrapper_class=WrapperClass,
91
+ max_seq_length=512,
92
+ batch_size=2,
93
+ shuffle=False,
94
+ teacher_forcing=False,
95
+ predict_eos_token=False,
96
+ truncate_method='head')
97
+ prompt_model.eval()
98
+ with torch.no_grad():
99
+ for batch in data_loader:
100
+ logits = prompt_model(batch.to('cuda'))
101
+ preds = torch.argmax(logits, dim = -1)
102
+ print(preds) #[8, 5]
103
  ```
104