Update README.md
Browse files
README.md
CHANGED
@@ -270,25 +270,7 @@ More details on this [article:](https://arxiv.org/abs/2301.05948)
|
|
270 |
```
|
271 |
|
272 |
# Loading a specific classifier
|
273 |
-
|
274 |
-
Classifiers for all tasks available.
|
275 |
-
```python
|
276 |
-
from torch import nn
|
277 |
-
|
278 |
-
TASK_NAME = "hh-rlhf"
|
279 |
-
|
280 |
-
class MultiTask(transformers.DebertaV2ForMultipleChoice):
|
281 |
-
def __init__(self, *args, **kwargs):
|
282 |
-
super().__init__(*args)
|
283 |
-
n=len(self.config.tasks)
|
284 |
-
cs=self.config.classifiers_size
|
285 |
-
self.Z = nn.Embedding(n,768)
|
286 |
-
self.classifiers = nn.ModuleList([torch.nn.Linear(*size) for size in cs])
|
287 |
-
|
288 |
-
model = MultiTask.from_pretrained("sileod/deberta-v3-base-tasksource-nli",ignore_mismatched_sizes=True)
|
289 |
-
task_index = {k:v for v,k in dict(enumerate(model.config.tasks)).items()}[TASK_NAME]
|
290 |
-
model.classifier = model.classifiers[task_index] # model is ready for $TASK_NAME ! (RLHF) !
|
291 |
-
```
|
292 |
|
293 |
|
294 |
# Model Card Contact
|
|
|
270 |
```
|
271 |
|
272 |
# Loading a specific classifier
|
273 |
+
Classifiers for all tasks available. See https://huggingface.co/sileod/deberta-v3-base-tasksource-adapters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
|
275 |
|
276 |
# Model Card Contact
|