amazingvince committed
Commit 4ebea54
1 Parent(s): d572741
Update modeling_custom_seq2seq_llm.py

modeling_custom_seq2seq_llm.py CHANGED (+27 -27)
@@ -1228,33 +1228,33 @@ class CustomSeq2SeqLLM(PreTrainedModel):
         torch_filepath = os.path.join(save_directory, "pytorch_model.bin")
         torch.save(cpu_state_dict, torch_filepath)
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        config = kwargs.pop("config", None)
-        state_dict = kwargs.pop("state_dict", None)
-
-        if config is None:
-            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
-
-        model = cls(config)
-
-        if state_dict is None:
-            # Try loading safetensors first
-            safe_filepath = os.path.join(pretrained_model_name_or_path, "model.safetensors")
-            if os.path.exists(safe_filepath):
-                from safetensors.torch import load_file
-                state_dict = load_file(safe_filepath)
-            else:
-                # Fall back to PyTorch format
-                torch_filepath = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
-                state_dict = torch.load(torch_filepath, map_location="cpu")
-
-        # Handle shared weights
-        if config.tie_word_embeddings and "lm_head.weight" not in state_dict:
-            state_dict["lm_head.weight"] = state_dict["shared.weight"]
-
-        model.load_state_dict(state_dict)
-        return model
+    # @classmethod
+    # def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+    #     config = kwargs.pop("config", None)
+    #     state_dict = kwargs.pop("state_dict", None)
+
+    #     if config is None:
+    #         config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+    #     model = cls(config)
+
+    #     if state_dict is None:
+    #         # Try loading safetensors first
+    #         safe_filepath = os.path.join(pretrained_model_name_or_path, "model.safetensors")
+    #         if os.path.exists(safe_filepath):
+    #             from safetensors.torch import load_file
+    #             state_dict = load_file(safe_filepath)
+    #         else:
+    #             # Fall back to PyTorch format
+    #             torch_filepath = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+    #             state_dict = torch.load(torch_filepath, map_location="cpu")
+
+    #     # Handle shared weights
+    #     if config.tie_word_embeddings and "lm_head.weight" not in state_dict:
+    #         state_dict["lm_head.weight"] = state_dict["shared.weight"]
+
+    #     model.load_state_dict(state_dict)
+    #     return model
 
 class CustomEncoder(nn.Module):
     def __init__(self, config):