sileod commited on
Commit
2e28741
1 Parent(s): 2834991

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +21 -1
README.md CHANGED
@@ -188,6 +188,7 @@ Results:
188
 
189
  For more information, see: [Model Recycling](https://ibm.github.io/model-recycling/)
190
 
 
191
  # Citation
192
 
193
  More details on this [article:](https://arxiv.org/abs/2301.05948)
@@ -200,7 +201,26 @@ More details on this [article:](https://arxiv.org/abs/2301.05948)
200
  year={2023}
201
  }
202
  ```
203
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
 
206
  # Model Card Contact
 
188
 
189
  For more information, see: [Model Recycling](https://ibm.github.io/model-recycling/)
190
 
191
+
192
  # Citation
193
 
194
  More details on this [article:](https://arxiv.org/abs/2301.05948)
 
201
  year={2023}
202
  }
203
  ```
204
+
205
+ # Loading a specific classifier
206
+
207
+ ```
208
+ from torch import nn
209
+
210
+ TASK_NAME = "hh-rlhf"
211
+
212
+ class MultiTask(transformers.DebertaV2ForMultipleChoice):
213
+ def __init__(self, *args, **kwargs):
214
+ super().__init__(*args)
215
+ n=len(self.config.tasks)
216
+ cs=self.config.classifiers_size
217
+ self.Z = nn.Embedding(n,768)
218
+ self.classifiers = nn.ModuleList([torch.nn.Linear(*size) for size in cs])
219
+
220
+ model = MultiTask.from_pretrained("sileod/deberta-v3-base-tasksource-nli",ignore_mismatched_sizes=True)
221
+ task_index = {k:v for v,k in dict(enumerate(model.config.tasks)).items()}[TASK_NAME]
222
+ model.classifier = model.classifiers[task_index] # model is ready for $TASK_NAME !
223
+ ```
224
 
225
 
226
  # Model Card Contact