amazingvince committed on
Commit: 4ebea54
1 Parent(s): d572741

Update modeling_custom_seq2seq_llm.py

Files changed (1):
  1. modeling_custom_seq2seq_llm.py +27 -27
modeling_custom_seq2seq_llm.py CHANGED
@@ -1228,33 +1228,33 @@ class CustomSeq2SeqLLM(PreTrainedModel):
         torch_filepath = os.path.join(save_directory, "pytorch_model.bin")
         torch.save(cpu_state_dict, torch_filepath)
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        config = kwargs.pop("config", None)
-        state_dict = kwargs.pop("state_dict", None)
-
-        if config is None:
-            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
-
-        model = cls(config)
-
-        if state_dict is None:
-            # Try loading safetensors first
-            safe_filepath = os.path.join(pretrained_model_name_or_path, "model.safetensors")
-            if os.path.exists(safe_filepath):
-                from safetensors.torch import load_file
-                state_dict = load_file(safe_filepath)
-            else:
-                # Fall back to PyTorch format
-                torch_filepath = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
-                state_dict = torch.load(torch_filepath, map_location="cpu")
-
-        # Handle shared weights
-        if config.tie_word_embeddings and "lm_head.weight" not in state_dict:
-            state_dict["lm_head.weight"] = state_dict["shared.weight"]
-
-        model.load_state_dict(state_dict)
-        return model
+    # @classmethod
+    # def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    #     config = kwargs.pop("config", None)
+    #     state_dict = kwargs.pop("state_dict", None)
+
+    #     if config is None:
+    #         config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+    #     model = cls(config)
+
+    #     if state_dict is None:
+    #         # Try loading safetensors first
+    #         safe_filepath = os.path.join(pretrained_model_name_or_path, "model.safetensors")
+    #         if os.path.exists(safe_filepath):
+    #             from safetensors.torch import load_file
+    #             state_dict = load_file(safe_filepath)
+    #         else:
+    #             # Fall back to PyTorch format
+    #             torch_filepath = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+    #             state_dict = torch.load(torch_filepath, map_location="cpu")
+
+    #     # Handle shared weights
+    #     if config.tie_word_embeddings and "lm_head.weight" not in state_dict:
+    #         state_dict["lm_head.weight"] = state_dict["shared.weight"]
+
+    #     model.load_state_dict(state_dict)
+    #     return model
 
 class CustomEncoder(nn.Module):
     def __init__(self, config):
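
Net effect of the hunk: the hand-rolled from_pretrained override is commented out, so loading falls through to the stock PreTrainedModel.from_pretrained inherited from transformers, which already prefers model.safetensors over pytorch_model.bin and re-ties the embedding/LM-head weights when config.tie_word_embeddings is set. A minimal usage sketch under that assumption; the import path and checkpoint directory below are hypothetical, not taken from the commit:

    # Hypothetical module path and checkpoint directory, for illustration only.
    from modeling_custom_seq2seq_llm import CustomSeq2SeqLLM

    # With the override gone, the inherited PreTrainedModel.from_pretrained
    # resolves the checkpoint file (model.safetensors first, then
    # pytorch_model.bin) and applies weight tying per config.tie_word_embeddings.
    model = CustomSeq2SeqLLM.from_pretrained("path/to/checkpoint")
    model.eval()

Deferring to the inherited loader avoids maintaining a parallel copy of the library's checkpoint-resolution logic (Hub downloads, sharded checkpoints, safetensors handling), which is presumably why the override was commented out rather than kept.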