picekl commited on
Commit
b759dcb
·
verified ·
1 Parent(s): 4d99f29

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +21 -0
model.py CHANGED
@@ -109,6 +109,27 @@ class MultiModalEnsembleC(nn.Module):
109
  with open(config_path, "w") as f:
110
  json.dump(config, f)
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  def forward(self, x, y, z):
113
  """
114
  Forward pass of the MultiModalEnsembleC model.
 
109
  with open(config_path, "w") as f:
110
  json.dump(config, f)
111
 
112
+ @classmethod
113
+ def from_pretrained(cls, load_directory):
114
+ """
115
+ Load model weights and configuration from a directory.
116
+
117
+ Args:
118
+ load_directory (str): Path to the directory containing model weights and configuration.
119
+
120
+ Returns:
121
+ MultiModalEnsembleC: A model instance with loaded weights.
122
+ """
123
+ config_path = os.path.join(load_directory, "config.json")
124
+ with open(config_path, "r") as f:
125
+ config = json.load(f)
126
+
127
+ # Create the model instance
128
+ model = cls(num_classes=config["num_classes"])
129
+ model_path = os.path.join(load_directory, "pytorch_model.bin")
130
+ model.load_state_dict(torch.load(model_path))
131
+ return model
132
+
133
  def forward(self, x, y, z):
134
  """
135
  Forward pass of the MultiModalEnsembleC model.