# directionality_probe / packaged_probe_model.py
# nikraf's picture
# Upload folder using huggingface_hub
# 714cf46 verified
import os
import sys
from typing import Any, Dict, Optional
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput
# Import the protify helpers.  When the package is not importable as-is, add a
# few directories located relative to this file to ``sys.path`` and retry once.
try:
    from protify.base_models.supported_models import all_presets_with_paths
    from protify.pooler import Pooler
    from protify.probes.get_probe import rebuild_probe_from_saved_config
except ImportError:
    current_dir = os.path.dirname(os.path.abspath(__file__))
    # Search order: this directory, its parent, its grandparent, then ./src.
    candidate_paths = [
        current_dir,
        os.path.dirname(current_dir),
        os.path.dirname(os.path.dirname(current_dir)),
        os.path.join(current_dir, "src"),
    ]
    for candidate in candidate_paths:
        # Skip anything that is not a directory or is already on the path.
        if not os.path.isdir(candidate) or candidate in sys.path:
            continue
        # Each insert goes to the front, so later candidates end up ahead of
        # earlier ones in the final ``sys.path``.
        sys.path.insert(0, candidate)

    # Second attempt; an ImportError raised here propagates to the importer.
    from protify.base_models.supported_models import all_presets_with_paths
    from protify.pooler import Pooler
    from protify.probes.get_probe import rebuild_probe_from_saved_config
class PackagedProbeConfig(PretrainedConfig):
    """Configuration describing a backbone model together with its probe head.

    All arguments are stored on the instance under the same name; anything in
    ``kwargs`` is forwarded to ``PretrainedConfig.__init__``.

    Args:
        base_model_name: Preset name or model path of the backbone.
        probe_type: Identifier of the probe architecture.
        probe_config: Saved probe settings; ``None`` becomes an empty dict.
        tokenwise: Whether the probe works on per-token embeddings.
        matrix_embed: Whether the probe receives the un-pooled embedding matrix.
        pooling_types: Pooling strategy names; ``None`` becomes ``["mean"]``.
        task_type: Task label stored on the config.
        num_labels: Number of output labels.
        ppi: Whether inputs are pairs that are pooled per segment.
        add_token_ids: Whether ``token_type_ids`` may be handed to the probe.
        sep_token_id: Separator token id used to split pairs when no
            ``token_type_ids`` are available.
    """

    model_type = "packaged_probe"

    def __init__(
        self,
        base_model_name: str = "",
        probe_type: str = "linear",
        probe_config: Optional[Dict[str, Any]] = None,
        tokenwise: bool = False,
        matrix_embed: bool = False,
        pooling_types: Optional[list[str]] = None,
        task_type: str = "singlelabel",
        num_labels: int = 2,
        ppi: bool = False,
        add_token_ids: bool = False,
        sep_token_id: Optional[int] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Resolve the ``None`` sentinels so no mutable default is shared.
        if probe_config is None:
            probe_config = {}
        if pooling_types is None:
            pooling_types = ["mean"]

        self.base_model_name = base_model_name
        self.probe_type = probe_type
        self.probe_config = probe_config
        self.tokenwise = tokenwise
        self.matrix_embed = matrix_embed
        self.pooling_types = pooling_types
        self.task_type = task_type
        # Assigned after ``super().__init__`` so the explicit argument wins.
        self.num_labels = num_labels
        self.ppi = ppi
        self.add_token_ids = add_token_ids
        self.sep_token_id = sep_token_id
class PackagedProbeModel(PreTrainedModel):
    """Backbone encoder and probe head packaged behind a single ``forward``.

    The backbone turns token ids into hidden states.  Depending on the config
    these are pooled (optionally per segment for pair inputs) or passed through
    as a matrix, and then handed to the probe, whose output is returned as-is.
    """

    config_class = PackagedProbeConfig
    base_model_prefix = "backbone"
    # NOTE(review): presumably tells transformers that no weights are tied;
    # confirm against the transformers version in use.
    all_tied_weights_keys = {}

    def __init__(
        self,
        config: PackagedProbeConfig,
        base_model: Optional[nn.Module] = None,
        probe: Optional[nn.Module] = None,
    ):
        """Build the model, loading any component that is not supplied.

        Args:
            config: Packaged probe configuration.
            base_model: Ready-made backbone; loaded from
                ``config.base_model_name`` when ``None``.
            probe: Ready-made probe; rebuilt from the saved probe config when
                ``None``.
        """
        super().__init__(config)
        self.config = config
        self.backbone = self._load_base_model() if base_model is None else base_model
        self.probe = self._load_probe() if probe is None else probe
        self.pooler = Pooler(self.config.pooling_types)

    def _load_base_model(self) -> nn.Module:
        """Load the backbone named in the config and put it in eval mode.

        A name found in ``all_presets_with_paths`` is translated to its mapped
        path; any other name is passed to ``AutoModel.from_pretrained`` as-is.
        """
        if self.config.base_model_name in all_presets_with_paths:
            model_path = all_presets_with_paths[self.config.base_model_name]
        else:
            model_path = self.config.base_model_name
        # SECURITY: trust_remote_code=True executes Python shipped with the
        # model repository; only load backbones from sources you trust.
        model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        model.eval()
        return model

    def _load_probe(self) -> nn.Module:
        """Rebuild the probe module from the settings stored in the config."""
        return rebuild_probe_from_saved_config(
            probe_type=self.config.probe_type,
            tokenwise=self.config.tokenwise,
            probe_config=self.config.probe_config,
        )

    @staticmethod
    def _extract_hidden_states(backbone_output: Any) -> torch.Tensor:
        """Pull the hidden-state tensor out of a backbone result.

        Accepts a tuple (first element is used), an object exposing
        ``last_hidden_state``, or a bare tensor.

        Raises:
            ValueError: If the output matches none of those forms.
        """
        if isinstance(backbone_output, tuple):
            return backbone_output[0]
        if hasattr(backbone_output, "last_hidden_state"):
            return backbone_output.last_hidden_state
        if isinstance(backbone_output, torch.Tensor):
            return backbone_output
        raise ValueError("Unsupported backbone output format for packaged probe model")

    @staticmethod
    def _extract_attentions(backbone_output: Any) -> Optional[Any]:
        """Return ``backbone_output.attentions`` if present, else ``None``.

        The value is passed through unchanged; its exact structure depends on
        the backbone.
        """
        if hasattr(backbone_output, "attentions"):
            return backbone_output.attentions
        return None

    def _build_ppi_segment_masks(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build one 0/1 mask per segment (A and B) for pair inputs.

        When ``token_type_ids`` contains at least one ``1``, segment A is the
        attended positions with type 0 and segment B those with type 1.
        Otherwise each row is split using ``config.sep_token_id``:

        * two or more attended separators: A runs up to and including the
          first separator, B from the next position through the second
          separator;
        * exactly one: A as above, B runs through the last attended position;
        * none: the attended positions are split at their midpoint.

        Both masks are restricted to attended positions, so padding on either
        side of the sequence never leaks into a segment.

        Args:
            input_ids: 2-D token id tensor (unpacked as batch, sequence).
            attention_mask: Tensor aligned with ``input_ids``; ``1`` marks
                attended positions.
            token_type_ids: Optional segment ids aligned with ``input_ids``.

        Returns:
            ``(mask_a, mask_b)`` as ``long`` tensors shaped like ``input_ids``.

        Raises:
            AssertionError: If ``sep_token_id`` is needed but unset, or if
                either segment is empty for any row.
        """
        if token_type_ids is not None and torch.any(token_type_ids == 1):
            mask_a = ((token_type_ids == 0) & (attention_mask == 1)).long()
            mask_b = ((token_type_ids == 1) & (attention_mask == 1)).long()
            assert torch.all(mask_a.sum(dim=1) > 0), "PPI token_type_ids produced empty segment A"
            assert torch.all(mask_b.sum(dim=1) > 0), "PPI token_type_ids produced empty segment B"
            return mask_a, mask_b

        assert self.config.sep_token_id is not None, "sep_token_id is required for PPI fallback segmentation"
        sep_token_id = self.config.sep_token_id
        attended = attention_mask == 1
        batch_size, seq_len = input_ids.shape
        mask_a = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
        mask_b = torch.zeros((batch_size, seq_len), dtype=torch.long, device=input_ids.device)
        for batch_idx in range(batch_size):
            valid_positions = torch.where(attended[batch_idx])[0]
            sep_positions = torch.where((input_ids[batch_idx] == sep_token_id) & attended[batch_idx])[0]
            if len(valid_positions) == 0:
                # Nothing attended in this row; the asserts below will report it.
                continue
            if len(sep_positions) >= 2:
                first_sep = int(sep_positions[0].item())
                second_sep = int(sep_positions[1].item())
                mask_a[batch_idx, :first_sep + 1] = 1
                mask_b[batch_idx, first_sep + 1:second_sep + 1] = 1
            elif len(sep_positions) == 1:
                first_sep = int(sep_positions[0].item())
                mask_a[batch_idx, :first_sep + 1] = 1
                mask_b[batch_idx, first_sep + 1: int(valid_positions[-1].item()) + 1] = 1
            else:
                midpoint = len(valid_positions) // 2
                mask_a[batch_idx, valid_positions[:midpoint]] = 1
                mask_b[batch_idx, valid_positions[midpoint:]] = 1

        # The slice assignments above start at index 0 / span contiguous ranges,
        # which sweeps in padding for left-padded rows or masks with holes.
        # Intersect with the attended positions, matching the token_type_ids
        # branch.  Separators themselves are attended, so this never empties a
        # segment that was non-empty before.
        attended_long = attended.to(mask_a.dtype)
        mask_a = mask_a * attended_long
        mask_b = mask_b * attended_long

        assert torch.all(mask_a.sum(dim=1) > 0), "PPI fallback segmentation produced empty segment A"
        assert torch.all(mask_b.sum(dim=1) > 0), "PPI fallback segmentation produced empty segment B"
        return mask_a, mask_b

    def _build_probe_inputs(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
        attentions: Optional[Any],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Turn backbone hidden states into the probe's embeddings and mask.

        * Pair mode (``ppi`` without ``matrix_embed``/``tokenwise``): pool each
          segment separately and concatenate the two vectors on the last axis;
          no mask is returned.
        * ``matrix_embed`` or ``tokenwise``: return the hidden states and the
          attention mask unchanged.
        * Otherwise: pool over the whole attention mask; no mask is returned.

        Returns:
            ``(embeddings, attention_mask_or_None)``.
        """
        if self.config.ppi and (not self.config.matrix_embed) and (not self.config.tokenwise):
            mask_a, mask_b = self._build_ppi_segment_masks(input_ids, attention_mask, token_type_ids)
            vec_a = self.pooler(hidden_states, attention_mask=mask_a, attentions=attentions)
            vec_b = self.pooler(hidden_states, attention_mask=mask_b, attentions=attentions)
            return torch.cat((vec_a, vec_b), dim=-1), None
        if self.config.matrix_embed or self.config.tokenwise:
            return hidden_states, attention_mask
        pooled = self.pooler(hidden_states, attention_mask=attention_mask, attentions=attentions)
        return pooled, None

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> SequenceClassifierOutput | TokenClassifierOutput:
        """Run the backbone, prepare probe inputs, and return the probe output.

        Args:
            input_ids: Token ids fed to the backbone.
            attention_mask: Optional mask; defaults to all ones.
            token_type_ids: Optional segment ids.  Used for pair segmentation
                and, for transformer probes with ``add_token_ids``, forwarded to
                the probe.  They are not passed to the backbone.
            labels: Optional targets forwarded to the probe.

        Returns:
            Whatever the probe returns.

        Raises:
            AssertionError: If "parti" pooling is configured but the backbone
                output exposes no attentions.
            ValueError: If ``config.probe_type`` is not a supported probe type.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.long)

        # Attentions are only requested when "parti" pooling will actually run,
        # i.e. when the embeddings are pooled rather than passed as a matrix.
        requires_attentions = "parti" in self.config.pooling_types and (not self.config.matrix_embed) and (not self.config.tokenwise)
        backbone_kwargs: Dict[str, Any] = {"input_ids": input_ids, "attention_mask": attention_mask}
        if requires_attentions:
            backbone_kwargs["output_attentions"] = True
        backbone_output = self.backbone(**backbone_kwargs)

        hidden_states = self._extract_hidden_states(backbone_output)
        attentions = self._extract_attentions(backbone_output)
        if requires_attentions:
            assert attentions is not None, "parti pooling requires base model attentions"

        probe_embeddings, probe_attention_mask = self._build_probe_inputs(
            hidden_states=hidden_states,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            attentions=attentions,
        )

        if self.config.probe_type == "linear":
            return self.probe(embeddings=probe_embeddings, labels=labels)
        if self.config.probe_type == "transformer":
            forward_kwargs: Dict[str, Any] = {"embeddings": probe_embeddings, "labels": labels}
            if probe_attention_mask is not None:
                forward_kwargs["attention_mask"] = probe_attention_mask
            # Token type ids only accompany un-pooled (matrix) embeddings.
            if self.config.add_token_ids and token_type_ids is not None and probe_attention_mask is not None:
                forward_kwargs["token_type_ids"] = token_type_ids
            return self.probe(**forward_kwargs)
        if self.config.probe_type in ["retrievalnet", "lyra"]:
            return self.probe(embeddings=probe_embeddings, attention_mask=probe_attention_mask, labels=labels)
        raise ValueError(f"Unsupported probe type for packaged model: {self.config.probe_type}")