Tom Aarsen commited on
Commit
c7dfa27
·
1 Parent(s): 0cf3c13

Use snapshot_download to avoid PEFT Windows issue

Browse files
Files changed (1) hide show
  1. splade.py +17 -9
splade.py CHANGED
@@ -48,19 +48,29 @@ class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
48
 
49
  @classmethod
50
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
 
51
  from peft import PeftConfig, PeftModel
52
 
53
  token = kwargs.get("token")
54
- try:
55
- peft_config = PeftConfig.from_pretrained(
56
- pretrained_model_name_or_path, subfolder=ADAPTER_SUBFOLDER, token=token
 
 
 
 
 
 
 
 
57
  )
58
- except Exception:
59
- peft_config = None
60
 
61
- if peft_config is None:
62
  return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
63
 
 
 
64
  # Use provided splade config (has is_causal=False) or load it from the adapter repo
65
  config = kwargs.pop("config", None)
66
  if config is None or not isinstance(config, PretrainedConfig):
@@ -73,9 +83,7 @@ class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
73
  **kwargs,
74
  )
75
 
76
- return PeftModel.from_pretrained(
77
- base_model, pretrained_model_name_or_path, subfolder=ADAPTER_SUBFOLDER, token=token
78
- )
79
 
80
 
81
  class SpladeConfig(PretrainedConfig):
 
48
 
49
  @classmethod
50
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
51
+ from huggingface_hub import snapshot_download
52
  from peft import PeftConfig, PeftModel
53
 
54
  token = kwargs.get("token")
55
+
56
+ # Resolve the adapter to a local path before handing it to PEFT.
57
+ # PEFT's `subfolder=` kwarg uses `os.path.join` on Windows, producing
58
+ # backslashed hub paths that break the safetensors-vs-bin fallback.
59
+ if os.path.isdir(pretrained_model_name_or_path):
60
+ adapter_path = os.path.join(pretrained_model_name_or_path, ADAPTER_SUBFOLDER)
61
+ else:
62
+ local_repo = snapshot_download(
63
+ pretrained_model_name_or_path,
64
+ allow_patterns=[f"{ADAPTER_SUBFOLDER}/*"],
65
+ token=token,
66
  )
67
+ adapter_path = os.path.join(local_repo, ADAPTER_SUBFOLDER)
 
68
 
69
+ if not os.path.isfile(os.path.join(adapter_path, "adapter_config.json")):
70
  return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
71
 
72
+ peft_config = PeftConfig.from_pretrained(adapter_path, token=token)
73
+
74
  # Use provided splade config (has is_causal=False) or load it from the adapter repo
75
  config = kwargs.pop("config", None)
76
  if config is None or not isinstance(config, PretrainedConfig):
 
83
  **kwargs,
84
  )
85
 
86
+ return PeftModel.from_pretrained(base_model, adapter_path, token=token)
 
 
87
 
88
 
89
  class SpladeConfig(PretrainedConfig):