Spaces:
Runtime error
Runtime error
Delete custom_model.py
Browse files- custom_model.py +0 -35
custom_model.py
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn as nn
|
3 |
-
from transformers import PreTrainedModel, AutoConfig, AutoModel
|
4 |
-
|
5 |
-
class CustomModel(PreTrainedModel):
|
6 |
-
config_class = AutoConfig # Use AutoConfig to dynamically load the configuration class
|
7 |
-
|
8 |
-
def __init__(self, config):
|
9 |
-
super().__init__(config)
|
10 |
-
# Implement your model architecture here
|
11 |
-
self.encoder = AutoModel.from_config(config) # Load the base model
|
12 |
-
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
13 |
-
|
14 |
-
def forward(self, input_ids, attention_mask=None):
|
15 |
-
# Pass inputs through the encoder
|
16 |
-
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
17 |
-
# Get the pooled output (e.g., CLS token for classification tasks)
|
18 |
-
pooled_output = outputs.last_hidden_state[:, 0, :]
|
19 |
-
# Pass through the classifier
|
20 |
-
logits = self.classifier(pooled_output)
|
21 |
-
return logits
|
22 |
-
|
23 |
-
@classmethod
|
24 |
-
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
25 |
-
try:
|
26 |
-
# Load the configuration
|
27 |
-
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
28 |
-
# Initialize the model with the configuration
|
29 |
-
model = cls(config)
|
30 |
-
# Optionally, you can load the state_dict here if needed
|
31 |
-
# model.load_state_dict(torch.load(os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")))
|
32 |
-
return model
|
33 |
-
except Exception as e:
|
34 |
-
print(f"Failed to load model from {pretrained_model_name_or_path}. Error: {e}")
|
35 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|