"""
Compared to standard Qwen3, this model uses bidirectional attention rather than causal
attention; this is specified with `is_causal=False` in the config.
This file supports two loading paths:
1. Sentence Transformers: `SparseEncoder("naver/splade-code-8B", trust_remote_code=True)` via AutoModelForMaskedLM -> Qwen3ForCausalLM
2. Transformers: `AutoModelForCausalLM.from_pretrained("naver/splade-code-8B", trust_remote_code=True)` -> Splade
The checkpoint is distributed as a LoRA adapter on top of Qwen/Qwen3-8B in the `lora/` subfolder;
`Qwen3ForCausalLM.from_pretrained` loads the base model and applies the adapter.
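
A minimal usage sketch of the two paths described above (the `encode` call and its
input are illustrative):

    # 1. Sentence Transformers
    from sentence_transformers import SparseEncoder
    model = SparseEncoder("naver/splade-code-8B", trust_remote_code=True)
    embeddings = model.encode(["def binary_search(arr, x): ..."])

    # 2. Transformers
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(
        "naver/splade-code-8B", trust_remote_code=True
    )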
"""
import os
import torch
from transformers import Qwen3ForCausalLM as TransformersQwen3ForCausalLM
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig
from transformers.utils import is_flash_attn_2_available
from .utils import prepare_tokenizer, splade_max, similarity, encode

# The adapter lives in this subfolder rather than at the repo root so that
# `find_adapter_config_file` doesn't trigger transformers' auto-PEFT path,
# which would otherwise redirect hub loads to `Qwen/Qwen3-8B` and lose the
# `auto_map` routing to the classes in this file.
ADAPTER_SUBFOLDER = "lora"


class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
def tie_weights(self, *args, **kwargs):
"""Explicitly re-tie lm_head to embed_tokens to hopefully avoid meta tensor errors."""
if (
self.config.tie_word_embeddings
and hasattr(self, "lm_head")
and hasattr(self, "model")
):
self.lm_head.weight = self.model.embed_tokens.weight
missing_keys = kwargs.get("missing_keys")
if missing_keys is not None:
missing_keys.discard("lm_head.weight")
else:
            super().tie_weights(*args, **kwargs)

    def _init_weights(self, module):
"""Skip lm_head init when it will be tied to embed_tokens later."""
if module is getattr(self, "lm_head", None) and self.config.tie_word_embeddings:
return
        super()._init_weights(module)

    @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
from huggingface_hub import snapshot_download
        from peft import PeftConfig, PeftModel

        token = kwargs.get("token")
# Resolve the adapter to a local path before handing it to PEFT.
# PEFT's `subfolder=` kwarg uses `os.path.join` on Windows, producing
# backslashed hub paths that break the safetensors-vs-bin fallback.
if os.path.isdir(pretrained_model_name_or_path):
adapter_path = os.path.join(pretrained_model_name_or_path, ADAPTER_SUBFOLDER)
else:
local_repo = snapshot_download(
pretrained_model_name_or_path,
allow_patterns=[f"{ADAPTER_SUBFOLDER}/*"],
token=token,
)
adapter_path = os.path.join(local_repo, ADAPTER_SUBFOLDER)
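        # No adapter shipped with the checkpoint: fall back to loading it as a
        # plain (non-PEFT) model.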
if not os.path.isfile(os.path.join(adapter_path, "adapter_config.json")):
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
peft_config = PeftConfig.from_pretrained(adapter_path, token=token)
# Use provided splade config (has is_causal=False) or load it from the adapter repo
config = kwargs.pop("config", None)
if config is None or not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, token=token)
base_model = super().from_pretrained(
peft_config.base_model_name_or_path,
*model_args,
config=config,
**kwargs,
)
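        # Wrap the base model with the LoRA adapter (kept unmerged by default).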
        return PeftModel.from_pretrained(base_model, adapter_path, token=token)


class SpladeConfig(PretrainedConfig):
    model_type = "qwen3"

    def __init__(
self,
model_name_or_path: str = "Qwen/Qwen3-8B",
attn_implementation: str = "flash_attention_2",
bidirectional: bool = True, # only for decoder models
padding_side: str = "left",
**kwargs,
):
super().__init__(**kwargs)
self.model_name_or_path = model_name_or_path
self.attn_implementation = attn_implementation
self.bidirectional = bidirectional
        self.padding_side = padding_side


class Splade(PreTrainedModel):
config_class = SpladeConfig
# methods for MTEB's interface
similarity = similarity
    encode = encode

    def __init__(self, config, weights_path=None, token=None):
super().__init__(config)
        self.name = "splade"
        # Resolve the attention implementation before it is baked into the
        # config: fall back to SDPA when flash-attn is not installed.
        if is_flash_attn_2_available():
            config.attn_implementation = "flash_attention_2"
        else:
            config.attn_implementation = "sdpa"
        base_cfg = AutoConfig.from_pretrained(
            weights_path,
            attn_implementation=config.attn_implementation,
            torch_dtype="auto",
            token=token,
        )
        self.tokenizer = prepare_tokenizer(
            weights_path, padding_side=config.padding_side
        )
        self.model = Qwen3ForCausalLM.from_pretrained(
            weights_path,
            config=base_cfg,
            torch_dtype=torch.bfloat16,
            attn_implementation=config.attn_implementation,
            token=token,
        )

    def save_pretrained(self, save_directory, *args, **kwargs):
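        # When the wrapped model is a PeftModel, this writes only the LoRA
        # adapter weights under `lora/`; the Splade config goes to the root.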
self.model.save_pretrained(os.path.join(save_directory, ADAPTER_SUBFOLDER))
        self.config.save_pretrained(save_directory)

    @classmethod
def from_pretrained(cls, model_name_or_path, *args, **kwargs):
token = kwargs.get("token", None)
config = SpladeConfig.from_pretrained(
model_name_or_path,
token=token,
)
model = cls(config, weights_path=model_name_or_path, token=token)
model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
        return model

    def forward(self, **tokens):
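        # splade_max pools the token logits into one sparse lexical vector per
        # sequence (SPLADE-style max pooling under the attention mask).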
output = self.model(**tokens)
splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
        return (splade_reps,)

    def get_width(self):
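        # One weight per vocabulary token, i.e. the sparse vector dimensionality.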
        return self.model.config.vocab_size

    def create_batch_dict(self, input_texts, max_length):
return self.tokenizer(
input_texts,
add_special_tokens=True,
padding="longest",
truncation=True,
max_length=max_length,
return_attention_mask=True,
return_tensors="pt",
        )


__all__ = ["Qwen3ForCausalLM", "Splade"]