Update custom_model.py
custom_model.py CHANGED (+12 -3)
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from transformers import PreTrainedModel, AutoConfig
+from transformers import PreTrainedModel, AutoConfig, AutoModel
 
 class CustomModel(PreTrainedModel):
     config_class = AutoConfig  # Use AutoConfig to dynamically load the configuration class
@@ -8,8 +8,18 @@ class CustomModel(PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         # Implement your model architecture here
+        self.encoder = AutoModel.from_config(config)  # Load the base model
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
+    def forward(self, input_ids, attention_mask=None):
+        # Pass inputs through the encoder
+        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+        # Get the pooled output (e.g., CLS token for classification tasks)
+        pooled_output = outputs.last_hidden_state[:, 0, :]
+        # Pass through the classifier
+        logits = self.classifier(pooled_output)
+        return logits
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         try:
@@ -18,8 +28,7 @@ class CustomModel(PreTrainedModel):
             # Initialize the model with the configuration
             model = cls(config)
             # Load the model weights using the transformers library
-
-            model.load_state_dict(state_dict)
+            model.encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
             return model
         except Exception as e:
             print(f"Failed to load model from {pretrained_model_name_or_path}. Error: {e}")
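For reference, a minimal usage sketch of the updated class after this change. The checkpoint name ("bert-base-uncased"), num_labels value, tokenizer call, and input text below are illustrative assumptions, not part of the commit.

# Usage sketch; "bert-base-uncased" and num_labels=2 are placeholder assumptions.
import torch
from transformers import AutoConfig, AutoTokenizer

from custom_model import CustomModel  # the class defined in this file

config = AutoConfig.from_pretrained("bert-base-uncased", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Direct construction: the encoder comes from AutoModel.from_config, so its weights are randomly initialized.
model = CustomModel(config)

inputs = tokenizer("Hello, world!", return_tensors="pt")
with torch.no_grad():
    logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(logits.shape)  # torch.Size([1, 2])

Calling CustomModel.from_pretrained("bert-base-uncased") instead would additionally replace model.encoder with pretrained weights, per the updated from_pretrained; the classifier head starts from random initialization either way.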