sapthesh committed on
Commit
ccfaaf5
·
verified ·
1 Parent(s): 0e79c88

Update custom_model.py

Browse files
Files changed (1) hide show
  1. custom_model.py +12 -3
custom_model.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import torch.nn as nn
3
- from transformers import PreTrainedModel, AutoConfig
4
 
5
  class CustomModel(PreTrainedModel):
6
  config_class = AutoConfig # Use AutoConfig to dynamically load the configuration class
@@ -8,8 +8,18 @@ class CustomModel(PreTrainedModel):
8
  def __init__(self, config):
9
  super().__init__(config)
10
  # Implement your model architecture here
 
11
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
12
 
 
 
 
 
 
 
 
 
 
13
  @classmethod
14
  def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
15
  try:
@@ -18,8 +28,7 @@ class CustomModel(PreTrainedModel):
18
  # Initialize the model with the configuration
19
  model = cls(config)
20
  # Load the model weights using the transformers library
21
- state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location="cpu")
22
- model.load_state_dict(state_dict)
23
  return model
24
  except Exception as e:
25
  print(f"Failed to load model from {pretrained_model_name_or_path}. Error: {e}")
 
1
  import torch
2
  import torch.nn as nn
3
+ from transformers import PreTrainedModel, AutoConfig, AutoModel
4
 
5
  class CustomModel(PreTrainedModel):
6
  config_class = AutoConfig # Use AutoConfig to dynamically load the configuration class
 
8
  def __init__(self, config):
9
  super().__init__(config)
10
  # Implement your model architecture here
11
+ self.encoder = AutoModel.from_config(config) # Load the base model
12
  self.classifier = nn.Linear(config.hidden_size, config.num_labels)
13
 
14
+ def forward(self, input_ids, attention_mask=None):
15
+ # Pass inputs through the encoder
16
+ outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
17
+ # Get the pooled output (e.g., CLS token for classification tasks)
18
+ pooled_output = outputs.last_hidden_state[:, 0, :]
19
+ # Pass through the classifier
20
+ logits = self.classifier(pooled_output)
21
+ return logits
22
+
23
  @classmethod
24
  def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
25
  try:
 
28
  # Initialize the model with the configuration
29
  model = cls(config)
30
  # Load the model weights using the transformers library
31
+ model.encoder = AutoModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
32
  return model
33
  except Exception as e:
34
  print(f"Failed to load model from {pretrained_model_name_or_path}. Error: {e}")