Spaces:
Runtime error
Runtime error
| """ | |
| TMOS_Classifier: Binary classification head on top of LLaVA's transformer backbone. | |
| Strips the autoregressive lm_head and replaces it with a single nn.Linear(hidden_size, 1) | |
| for binary deepfake detection (0 = Real, 1 = Fake). | |
| Usage: | |
| from tmos_classifier import TMOSClassifier, TMOS_LORA_CONFIG | |
| classifier = TMOSClassifier(base_model_id="llava-hf/llava-1.5-7b-hf") | |
| classifier = get_peft_model(classifier, TMOS_LORA_CONFIG) | |
| out = classifier(input_ids=..., pixel_values=..., attention_mask=..., labels=label) | |
| logit, loss = out["logit"], out["loss"]  # forward returns a dict; loss is BCE-with-logits | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from transformers import LlavaForConditionalGeneration | |
| from peft import LoraConfig | |
| # ─── LoRA Configuration ────────────────────────────────────────────── | |
# LoRA at r=64 on every linear projection of the language-model transformer
# (attention q/k/v/o and MLP gate/up/down).
#
# target_modules is a single regex rather than a list of name suffixes.
# PEFT matches a list entry against the *end* of every module name in the
# wrapped model, and the CLIP vision tower's attention layers are also called
# q_proj / k_proj / v_proj, so a suffix list would silently attach adapters
# to the vision encoder as well. A string is full-matched against the whole
# module path, which lets us require "language_model" in the path and thereby
# leave the vision tower, the multi-modal projector (linear_1 / linear_2) and
# the unused lm_head untouched.
TMOS_LORA_CONFIG = LoraConfig(
    r=64,
    lora_alpha=128,  # 2x rank as a common heuristic
    target_modules=(
        r".*language_model.*\."
        r"(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)"
    ),
    lora_dropout=0.1,
    bias="none",
    task_type=None,  # custom classifier, not a causal LM
    modules_to_save=["classifier"],  # always train the classification head
)
class TMOSClassifier(nn.Module):
    """Binary real/fake classifier built on the LLaVA transformer backbone.

    Data flow::

        pixel_values --> vision tower --> multi-modal projector --+
                                                                  +--> language model --> hidden state of last real token --> classifier --> logit
        input_ids ----> token embedding --------------------------+

    The autoregressive ``lm_head`` is never used. The hidden state of the
    last non-padded token of each sequence is fed to a learned
    ``nn.Linear(hidden_size, 1)`` head that produces one raw logit
    (0 = real, 1 = fake after a sigmoid).
    """

    def __init__(self, base_model_id, torch_dtype=torch.float16, device_map="auto", token=None):
        """Load the LLaVA backbone and attach a fresh classification head.

        Args:
            base_model_id: Model id or path passed to ``from_pretrained``.
            torch_dtype: dtype used to load the backbone weights.
            device_map: Device placement strategy passed to ``from_pretrained``.
            token: Optional access token passed to ``from_pretrained``.
        """
        super().__init__()
        # The full LLaVA model is required: vision tower + projector + LLM.
        self.base = LlavaForConditionalGeneration.from_pretrained(
            base_model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            device_map=device_map,
            token=token,
        )
        hidden_size = self.base.config.text_config.hidden_size

        # The lm_head is unused; freezing it avoids wasted gradient work if an
        # adapter library happens to wrap it.
        for param in self.base.lm_head.parameters():
            param.requires_grad = False

        # The head stays in fp32 for numerical stability, independent of the
        # backbone dtype.
        self.classifier = nn.Linear(hidden_size, 1, dtype=torch.float32)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    @staticmethod
    def _pool_last_token(last_hidden_state, attention_mask):
        """Return the hidden state of the last attended token of each sequence.

        Works for right-padded and left-padded batches alike: the selected
        position is the largest index whose mask entry is non-zero. Without a
        mask, the final position is used. An all-zero mask row selects
        position 0.
        """
        if attention_mask is None:
            return last_hidden_state[:, -1, :]

        # Index on the hidden states' device; with a sharded/offloaded model
        # the mask may live elsewhere.
        attended = attention_mask.to(last_hidden_state.device) != 0
        positions = torch.arange(attended.size(1), device=attended.device)
        # Padding positions contribute 0, so the max is the last real index.
        last_index = (positions * attended).amax(dim=1)
        # Guard against a mask that is longer than the hidden-state sequence.
        last_index = last_index.clamp(min=0, max=last_hidden_state.size(1) - 1)
        batch_index = torch.arange(
            last_hidden_state.size(0), device=last_hidden_state.device
        )
        return last_hidden_state[batch_index, last_index]

    def forward(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        labels=None,  # float tensor of shape (B,) or (B, 1): 0.0 = real, 1.0 = fake
        **kwargs,  # absorb extra keys from the data collator
    ):
        """Run one forward pass and return the logit, plus the loss if labelled.

        Returns:
            dict with keys:
                "logit": (B, 1) raw logit
                "loss":  scalar BCE-with-logits loss (only if labels are given)
        """
        # 1. Backbone: call the inner model (vision + projector + LLM) to get
        #    hidden states rather than language-model logits.
        outputs = self.base.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            return_dict=True,
        )
        last_hidden_state = outputs.last_hidden_state  # (B, seq_len, hidden)

        # 2. Pool the last real token of every sequence.
        pooled = self._pool_last_token(last_hidden_state, attention_mask)

        # Replace non-finite activations defensively before the classifier.
        pooled = torch.nan_to_num(pooled, nan=0.0, posinf=1e4, neginf=-1e4)

        # Keep the head on the same device as the pooled features when the
        # backbone is sharded or offloaded.
        if self.classifier.weight.device != pooled.device:
            self.classifier = self.classifier.to(pooled.device)

        # 3. Classify in fp32.
        logit = self.classifier(pooled.float())  # (B, 1)
        logit = torch.nan_to_num(logit, nan=0.0, posinf=20.0, neginf=-20.0)
        result = {"logit": logit}

        # 4. Optional loss.
        if labels is not None:
            labels = labels.to(logit.dtype).to(logit.device)
            if labels.dim() == 1:
                labels = labels.unsqueeze(1)  # (B,) -> (B, 1)
            loss_fn = nn.BCEWithLogitsLoss()
            result["loss"] = loss_fn(logit, labels)
        return result

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Stub required by PEFT; this model never generates text."""
        raise NotImplementedError("TMOSClassifier does not support generation.")

    def gradient_checkpointing_enable(self, **kwargs):
        """Delegate to the base model for HF Trainer compatibility."""
        self.base.model.gradient_checkpointing_enable(**kwargs)

    # The three accessors below are properties because their consumers read
    # them as attributes (``model.config``, ``model.device``, ``model.dtype``);
    # as plain methods they would hand back a bound method instead of a value.
    @property
    def config(self):
        """Expose the base model config for PEFT."""
        return self.base.config

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """dtype of the first parameter of this module."""
        return next(self.parameters()).dtype
| # ─── Standalone Test ────────────────────────────────────────────────── | |
if __name__ == "__main__":
    # Smoke test: load the real checkpoint, report parameter counts, and push
    # one dummy image/prompt pair through the classifier.
    import os

    from dotenv import load_dotenv

    load_dotenv()
    hf_token = os.getenv("HF_TOKEN")
    model_id = "llava-hf/llava-1.5-7b-hf"

    print("Testing TMOSClassifier...")
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    model = TMOSClassifier(
        base_model_id=model_id,
        torch_dtype=torch.float16,
        token=hf_token,
    )
    model.to(target_device)

    # Parameter counts: everything, trainable only, and the head alone.
    all_params = list(model.parameters())
    n_total = sum(p.numel() for p in all_params)
    n_trainable = sum(p.numel() for p in all_params if p.requires_grad)
    n_head = sum(p.numel() for p in model.classifier.parameters())
    print(f"Total params: {n_total:>12,}")
    print(f"Trainable params: {n_trainable:>12,}")
    print(f"Classifier head: {n_head:,}")

    # Build one dummy multimodal example with the matching processor.
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    processor.patch_size = 14
    processor.vision_feature_select_strategy = "default"

    from PIL import Image

    gray_image = Image.new("RGB", (336, 336), color=(128, 128, 128))
    batch = processor(
        text="USER: <image>\nIs this real?\nASSISTANT:",
        images=gray_image,
        return_tensors="pt",
    )
    batch = batch.to(target_device)
    fake_label = torch.tensor([1.0], device=target_device)  # fake

    with torch.no_grad():
        result = model(**batch, labels=fake_label)

    print(f"Logit: {result['logit'].item():.4f}")
    print(f"Loss: {result['loss'].item():.4f}")
    print(f"Prob: {torch.sigmoid(result['logit']).item():.4f}")
    print("Test passed.")