Tom Aarsen committed
Commit 0cf3c13 · Parent: ebcd7f4

Move adapters back into `lora/` subfolder to avoid inconvenient auto-PEFT trigger

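For context: when PEFT is installed, `transformers`' `from_pretrained` probes the repo root for an `adapter_config.json` (via `find_adapter_config_file`) and, if it finds one, loads the base model named in that config instead of routing through the repo's `auto_map` classes. A minimal sketch of the probe this commit sidesteps, assuming a recent transformers release that exports `find_adapter_config_file` (exact internals vary by version):

```python
from transformers.utils import find_adapter_config_file

# Before this commit, adapter_config.json sat at the repo root, so this probe
# returned a path and from_pretrained took its auto-PEFT branch: it loaded the
# base model from the adapter's base_model_name_or_path (Qwen/Qwen3-8B), and
# the custom classes in splade.py were never reached.
adapter_cfg = find_adapter_config_file("naver/splade-code-8B")

# With the adapter moved into lora/, the root probe finds nothing, so loading
# falls through to the auto_map / trust_remote_code routing as intended.
print(adapter_cfg)  # expected: None once this commit lands
```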
adapter_config.json → lora/adapter_config.json RENAMED
File without changes
adapter_model.safetensors → lora/adapter_model.safetensors RENAMED
File without changes
modules.json CHANGED
@@ -3,7 +3,7 @@
     "idx": 0,
     "name": "0",
     "path": "",
-    "type": "splade.SpladeCodeMLMTransformer"
+    "type": "sentence_transformers.sparse_encoder.models.MLMTransformer"
   },
   {
     "idx": 1,
splade.py CHANGED
@@ -3,19 +3,27 @@ Compared to standard Qwen3, we're using bidirectional attention and not causal a
 with `is_causal=False` in the config.
 
 This file supports two loading paths:
-1. Sentence Transformers: `SparseEncoder("naver/splade-code-8B", trust_remote_code=True)` via SpladeCodeMLMTransformer -> AutoModelForMaskedLM -> Qwen3ForCausalLM
+1. Sentence Transformers: `SparseEncoder("naver/splade-code-8B", trust_remote_code=True)` via AutoModelForMaskedLM -> Qwen3ForCausalLM
 2. Transformers: `AutoModelForCausalLM.from_pretrained("naver/splade-code-8B", trust_remote_code=True)` -> Splade
 
-The checkpoint is distributed as a LoRA adapter on top of Qwen/Qwen3-8B; `Qwen3ForCausalLM.from_pretrained`
-loads the base model and applies the adapter.
+The checkpoint is distributed as a LoRA adapter on top of Qwen/Qwen3-8B in the `lora/` subfolder;
+`Qwen3ForCausalLM.from_pretrained` loads the base model and applies the adapter.
 """
 
+import os
+
 import torch
 from transformers import Qwen3ForCausalLM as TransformersQwen3ForCausalLM
 from transformers import PretrainedConfig, PreTrainedModel, AutoConfig
 from transformers.utils import is_flash_attn_2_available
 from .utils import prepare_tokenizer, splade_max, similarity, encode
 
+# The adapter lives in this subfolder rather than at the repo root so that
+# `find_adapter_config_file` doesn't trigger transformers' auto-PEFT path,
+# which would otherwise redirect hub loads to `Qwen/Qwen3-8B` and lose the
+# `auto_map` routing to the classes in this file.
+ADAPTER_SUBFOLDER = "lora"
+
 
 class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
     def tie_weights(self, *args, **kwargs):
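Loading path 2 from the docstring above, as a hedged sketch (the dtype argument is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM

# auto_map in the repo's config.json routes this load to splade.Splade.
model = AutoModelForCausalLM.from_pretrained(
    "naver/splade-code-8B",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # illustrative
)
```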
@@ -42,9 +50,10 @@ class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         from peft import PeftConfig, PeftModel
 
+        token = kwargs.get("token")
         try:
             peft_config = PeftConfig.from_pretrained(
-                pretrained_model_name_or_path, token=kwargs.get("token")
+                pretrained_model_name_or_path, subfolder=ADAPTER_SUBFOLDER, token=token
             )
         except Exception:
             peft_config = None
@@ -55,12 +64,7 @@
         # Use provided splade config (has is_causal=False) or load it from the adapter repo
         config = kwargs.pop("config", None)
         if config is None or not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, token=kwargs.get("token")
-            )
-
-        # We apply the adapter manually below, so drop any auto-PEFT hints to avoid double loading
-        kwargs.pop("adapter_kwargs", None)
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, token=token)
 
         base_model = super().from_pretrained(
             peft_config.base_model_name_or_path,
@@ -70,7 +74,7 @@
         )
 
         return PeftModel.from_pretrained(
-            base_model, pretrained_model_name_or_path, token=kwargs.get("token")
+            base_model, pretrained_model_name_or_path, subfolder=ADAPTER_SUBFOLDER, token=token
         )
 
 
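The three `from_pretrained` edits above boil down to a manual base-plus-adapter load from the new subfolder. A standalone hedged sketch of the equivalent sequence, assuming PEFT's `subfolder` support:

```python
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM

repo = "naver/splade-code-8B"

# Read the adapter config from the lora/ subfolder introduced by this commit.
peft_config = PeftConfig.from_pretrained(repo, subfolder="lora")

# Load the base model the adapter was trained on (Qwen/Qwen3-8B).
base = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    torch_dtype=torch.bfloat16,  # illustrative
)

# Apply the LoRA weights from the same subfolder.
model = PeftModel.from_pretrained(base, repo, subfolder="lora")
```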
@@ -128,7 +132,7 @@ class Splade(PreTrainedModel):
         )
 
     def save_pretrained(self, save_directory, *args, **kwargs):
-        self.model.save_pretrained(save_directory)
+        self.model.save_pretrained(os.path.join(save_directory, ADAPTER_SUBFOLDER))
         self.config.save_pretrained(save_directory)
 
     @classmethod
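The matching `save_pretrained` change keeps the on-disk layout symmetric: adapter files go under `lora/`, the wrapper config stays at the root, so a re-saved checkpoint also avoids the root-level auto-PEFT probe. Expected layout, hedged (exact adapter filenames depend on the PEFT version):

```python
model.save_pretrained("./splade-code-8B-local")
# ./splade-code-8B-local/config.json                     <- Splade wrapper config
# ./splade-code-8B-local/lora/adapter_config.json        <- PEFT adapter config
# ./splade-code-8B-local/lora/adapter_model.safetensors  <- LoRA weights
```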
@@ -166,19 +170,3 @@ class Splade(PreTrainedModel):
 
 
 __all__ = ["Qwen3ForCausalLM", "Splade"]
-
-
-# Override ST's `_load_config` to return our `Qwen3Config` (with `auto_map`)
-# instead of a `PeftConfig`, so hub-path loads route to `splade.Qwen3ForCausalLM`
-# instead of failing in `AutoModelForMaskedLM`. The LoRA is still applied by
-# transformers' built-in PEFT path.
-try:
-    from sentence_transformers.sparse_encoder.models import MLMTransformer
-
-    class SpladeCodeMLMTransformer(MLMTransformer):
-        def _load_config(self, model_name_or_path, backend, config_kwargs):
-            return AutoConfig.from_pretrained(model_name_or_path, **config_kwargs), False
-
-    __all__.append("SpladeCodeMLMTransformer")
-except ImportError:
-    pass
 