oi-fdibaldassarre committed on
Commit
e9ffcc0
·
verified ·
1 Parent(s): 8580cad

Upload custom model code

Browse files
Files changed (1) hide show
  1. modeling_custom.py +78 -0
modeling_custom.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """This file contains the custom code needed to make the causal_classification models
2
+ compatible with huggingface Auto classes.
3
+
4
+ tokenizer = AutoTokenizer.from_pretrained("my/repo")
5
+ model = AutoModelForSequenceClassification.from_pretrained("my/repo", trust_remote_code=True)
6
+
7
+ classifier = pipeline("text-classification", "my/repo", trust_remote_code=True)
8
+ """
9
+
10
+ from typing import Optional, Union
11
+
12
+ from transformers import PreTrainedModel, AutoModelForCausalLM, PretrainedConfig
13
+ import torch
14
+
15
+ import os
16
+
17
+
18
def load_head(
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    config: "PretrainedConfig",
    device,
) -> torch.nn.Linear:
    """Build the classification head and load its weights from ``classification_head.pth``.

    The head is a bias-free linear layer created with ``config.vocab_size`` input
    features and ``config.num_labels`` output features; its weight is then replaced
    by the tensor stored in ``classification_head.pth``.

    Args:
        pretrained_model_name_or_path: Local model directory, or a Hugging Face Hub
            repo id. When it is not an existing local directory, the head file is
            resolved (and downloaded if needed) through the Hub cache.
        config: Object exposing ``vocab_size`` and ``num_labels``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The linear head carrying the loaded weight.

    Raises:
        FileNotFoundError: If a local directory is given and it does not contain
            ``classification_head.pth``.
        OSError: If the file cannot be resolved from the Hub.
    """
    head = torch.nn.Linear(config.vocab_size, config.num_labels, bias=False)
    classification_head = os.path.join(
        pretrained_model_name_or_path, "classification_head.pth"
    )
    if not os.path.isfile(classification_head) and not os.path.isdir(
        pretrained_model_name_or_path
    ):
        # Not a local checkout: treat the identifier as a Hub repo id. Imported
        # lazily so purely local loading never touches the Hub utilities.
        from transformers.utils import cached_file

        classification_head = cached_file(
            pretrained_model_name_or_path, "classification_head.pth"
        )
    # NOTE(security): torch.load unpickles the file. Only load heads from trusted
    # repositories (consider ``weights_only=True`` where the installed torch has it).
    head.weight.data = torch.load(classification_head, map_location=device)
    return head
29
+
30
+
31
class CustomModelForSequenceClassification(PreTrainedModel):
    """Sequence classifier made of a causal-LM backbone followed by a linear head.

    The backbone's output logits at the last sequence position are passed through
    the head to obtain the classification logits.
    """

    # Suppress the warning "Some weights were not initialized...You should probably TRAIN this model..."
    _keys_to_ignore_on_load_missing = ["model.*", "head.*"]

    def __init__(self, config, backbone: torch.nn.Module, head: torch.nn.Linear):
        """Wrap an already-loaded backbone and classification head.

        Args:
            config: Configuration forwarded to ``PreTrainedModel.__init__``.
            backbone: Module whose call result exposes a ``logits`` attribute.
            head: Linear layer applied to the backbone logits of the last position.
        """
        super().__init__(config)
        self.model_backbone = backbone
        self.head = head

    def forward(self, **kwargs):
        """Run the backbone and classify from its last-position logits.

        Args:
            **kwargs: Forwarded unchanged to the backbone.

        Returns:
            A dict with a single ``"logits"`` entry holding the head output.
        """
        backbone_logits = self.model_backbone(**kwargs).logits
        # NOTE(review): index -1 is the last position of the padded sequence, which
        # presumably relies on left padding for batched inputs - confirm against
        # the tokenizer configuration.
        # The head may sit on a different device than the backbone output when the
        # backbone is sharded; moving is a no-op when both already match.
        out_last = backbone_logits[:, -1].float().to(self.head.weight.device)
        logits = self.head(out_last)
        return {"logits": logits}

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        **kwargs,
    ):
        """Load the causal-LM backbone and the classification head, then wrap them.

        All arguments are forwarded to ``AutoModelForCausalLM.from_pretrained``.
        ``trust_remote_code`` defaults to ``True`` unless the caller supplies it.

        Returns:
            A new instance of this class holding the backbone and the head.
        """
        # A caller may pass trust_remote_code explicitly; hard-coding it in the call
        # below as well would raise "got multiple values for keyword argument".
        kwargs.setdefault("trust_remote_code", True)
        model_backbone: torch.nn.Module = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            **kwargs,
        )
        # ``config`` may legitimately be None or a path; the head and the base class
        # both need a real config object, so fall back to the one the backbone
        # resolved for itself.
        if not isinstance(config, PretrainedConfig):
            config = model_backbone.config
        device = next(model_backbone.parameters()).device
        head = load_head(pretrained_model_name_or_path, config, device=device)

        return cls(config, model_backbone, head)