oi-fdibaldassarre
commited on
Upload custom model code
Browse files- modeling_custom.py +78 -0
modeling_custom.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""This file contains the custom code needed to make the causal_classification models
|
2 |
+
compatible with huggingface Auto classes.
|
3 |
+
|
4 |
+
tokenizer = AutoTokenizer.from_pretrained("my/repo")
|
5 |
+
model = AutoModelForSequenceClassification.from_pretrained("my/repo", trust_remote_code=True)
|
6 |
+
|
7 |
+
classifier = pipeline("text-classification", "my/repo", trust_remote_code=True)
|
8 |
+
"""
|
9 |
+
|
10 |
+
from typing import Optional, Union
|
11 |
+
|
12 |
+
from transformers import PreTrainedModel, AutoModelForCausalLM, PretrainedConfig
|
13 |
+
import torch
|
14 |
+
|
15 |
+
import os
|
16 |
+
|
17 |
+
|
18 |
+
def load_head(
|
19 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
20 |
+
config: PretrainedConfig,
|
21 |
+
device,
|
22 |
+
) -> torch.nn.Linear:
|
23 |
+
head = torch.nn.Linear(config.vocab_size, config.num_labels, bias=False)
|
24 |
+
classification_head = os.path.join(
|
25 |
+
pretrained_model_name_or_path, "classification_head.pth"
|
26 |
+
)
|
27 |
+
head.weight.data = torch.load(classification_head, map_location=device)
|
28 |
+
return head
|
29 |
+
|
30 |
+
|
31 |
+
class CustomModelForSequenceClassification(PreTrainedModel):
|
32 |
+
# Suppress the warning "Some weights were not initialized...You should probably TRAIN this model..."
|
33 |
+
_keys_to_ignore_on_load_missing = ["model.*", "head.*"]
|
34 |
+
|
35 |
+
def __init__(self, config, backbone: torch.nn.Module, head: torch.nn.Linear):
|
36 |
+
super().__init__(config)
|
37 |
+
self.model_backbone = backbone
|
38 |
+
self.head = head
|
39 |
+
|
40 |
+
def forward(self, **kwargs):
|
41 |
+
r = self.model_backbone(**kwargs).logits
|
42 |
+
out_last = r[:, -1].float()
|
43 |
+
logits = self.head(out_last)
|
44 |
+
return {"logits": logits}
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def from_pretrained(
|
48 |
+
cls,
|
49 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
50 |
+
*model_args,
|
51 |
+
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
|
52 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
53 |
+
ignore_mismatched_sizes: bool = False,
|
54 |
+
force_download: bool = False,
|
55 |
+
local_files_only: bool = False,
|
56 |
+
token: Optional[Union[str, bool]] = None,
|
57 |
+
revision: str = "main",
|
58 |
+
use_safetensors: bool = None,
|
59 |
+
**kwargs,
|
60 |
+
):
|
61 |
+
model_backbone: torch.nn.Module = AutoModelForCausalLM.from_pretrained(
|
62 |
+
pretrained_model_name_or_path,
|
63 |
+
*model_args,
|
64 |
+
config=config,
|
65 |
+
cache_dir=cache_dir,
|
66 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
67 |
+
force_download=force_download,
|
68 |
+
local_files_only=local_files_only,
|
69 |
+
token=token,
|
70 |
+
revision=revision,
|
71 |
+
use_safetensors=use_safetensors,
|
72 |
+
trust_remote_code=True,
|
73 |
+
**kwargs,
|
74 |
+
)
|
75 |
+
device = next(model_backbone.parameters()).device
|
76 |
+
head = load_head(pretrained_model_name_or_path, config, device=device)
|
77 |
+
|
78 |
+
return cls(config, model_backbone, head)
|