# TrueFrame / scripts / tmos_classifier.py
# Author: NeelakshSaxena
# Commit 8d017ca (verified): Deploy auto GPU fallback + FastAPI /predict
"""
TMOS_Classifier: Binary classification head on top of LLaVA's transformer backbone.
Strips the autoregressive lm_head and replaces it with a single nn.Linear(hidden_size, 1)
for binary deepfake detection (0 = Real, 1 = Fake).
Usage:
from tmos_classifier import TMOSClassifier, TMOS_LORA_CONFIG
classifier = TMOSClassifier(base_model_id="llava-hf/llava-1.5-7b-hf")
classifier = get_peft_model(classifier, TMOS_LORA_CONFIG)
logit = classifier(input_ids=..., pixel_values=..., attention_mask=...)
loss = nn.BCEWithLogitsLoss()(logit, label)
"""
import torch
import torch.nn as nn
from transformers import LlavaForConditionalGeneration
from peft import LoraConfig
# ─── LoRA Configuration ──────────────────────────────────────────────
# Rank-64 adapters on every linear projection in the LLM backbone:
# attention (q/k/v/o) plus the SwiGLU MLP (gate/up/down) projections.
# Deliberately NOT targeted: lm_head (we discard it), fc1/fc2/out_proj
# (CLIP vision tower), and linear_1/linear_2 (multi-modal projector) —
# so the vision encoder stays frozen and only the language transformer
# is adapted.
TMOS_LORA_CONFIG = LoraConfig(
    r=64,  # LoRA rank
    lora_alpha=128,  # 2x rank as a common heuristic
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=0.1,
    bias="none",  # bias terms stay frozen
    task_type=None,  # Custom classifier — not a causal LM
    modules_to_save=["classifier"],  # Always train the classification head
)
class TMOSClassifier(nn.Module):
    """
    Binary deepfake classifier built on the LLaVA transformer backbone.

    Architecture:
        pixel_values ──► CLIP Vision Tower ──► Multi-Modal Projector ──┐
                      ├──► LLaMA Transformer ──► last-token hidden ──► classifier ──► logit
        input_ids ──► Token Embedding ─────────────────────────────────┘

    The autoregressive lm_head is never used. The hidden state of each
    sequence's final real (non-padded) token is passed through a learned
    nn.Linear(hidden_size, 1) head, producing one raw logit per sample
    (0 = Real, 1 = Fake when trained with BCEWithLogitsLoss).
    """

    def __init__(self, base_model_id, torch_dtype=torch.float16, device_map="auto", token=None):
        """
        Args:
            base_model_id: HF Hub id or local path of a LLaVA checkpoint.
            torch_dtype: dtype for the backbone weights (fp16 by default).
            device_map: accelerate placement policy ("auto" shards/offloads).
            token: optional HF access token for gated repositories.
        """
        super().__init__()
        # Load the full LLaVA model (we need vision tower + projector + LLM).
        self.base = LlavaForConditionalGeneration.from_pretrained(
            base_model_id,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            device_map=device_map,
            token=token,
        )
        hidden_size = self.base.config.text_config.hidden_size  # 4096 for 7B
        # Freeze the lm_head — we won't use it, but freezing prevents
        # wasted gradient computation if PEFT accidentally wraps it.
        for param in self.base.lm_head.parameters():
            param.requires_grad = False
        # Keep the classifier head in fp32 for numerical stability.
        self.classifier = nn.Linear(hidden_size, 1, dtype=torch.float32)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        labels=None,  # float tensor of shape (B,) — 0.0=real, 1.0=fake
        **kwargs,  # absorb extra keys from data collator
    ):
        """
        Single deterministic forward pass → logit + optional BCE loss.

        Returns:
            dict with keys:
                "logit": (B, 1) raw logit
                "loss": scalar BCE loss (only if labels provided)
        """
        # ── 1. Forward through the LLaVA backbone ──
        # Call the internal model (vision + projector + LLM) directly,
        # asking for hidden states, NOT for language-model logits.
        outputs = self.base.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # last_hidden_state: (B, seq_len, hidden_size)
        last_hidden_state = outputs.last_hidden_state
        # ── 2. Pool: extract the final real (non-padded) token per sequence ──
        if attention_mask is not None:
            # (mask * positions).argmax picks the LARGEST position whose
            # mask bit is 1 — correct for BOTH right- and left-padded
            # batches. (The previous `mask.sum(1) - 1` formula was only
            # valid for right padding; under left padding it silently
            # pooled a padding-token hidden state.) An all-zero mask row
            # degrades to index 0, matching the old clamp(min=0) behavior.
            positions = torch.arange(
                last_hidden_state.size(1), device=attention_mask.device
            )
            last_token_idx = (attention_mask.long() * positions).argmax(dim=1)
            last_token_idx = last_token_idx.to(last_hidden_state.device)
            batch_idx = torch.arange(
                last_hidden_state.size(0), device=last_hidden_state.device
            )
            pooled = last_hidden_state[batch_idx, last_token_idx]
        else:
            # No mask → just take the last position.
            pooled = last_hidden_state[:, -1, :]
        # Replace non-finite activations defensively before the classifier.
        pooled = torch.nan_to_num(pooled, nan=0.0, posinf=1e4, neginf=-1e4)
        # Match classifier device to pooled features when model is sharded/offloaded.
        if self.classifier.weight.device != pooled.device:
            self.classifier = self.classifier.to(pooled.device)
        # ── 3. Classify (head runs in fp32 regardless of backbone dtype) ──
        logit = self.classifier(pooled.float())  # (B, 1)
        logit = torch.nan_to_num(logit, nan=0.0, posinf=20.0, neginf=-20.0)
        result = {"logit": logit}
        # ── 4. Loss ──
        if labels is not None:
            labels = labels.to(device=logit.device, dtype=logit.dtype)
            if labels.dim() == 1:
                labels = labels.unsqueeze(1)  # (B,) → (B, 1)
            result["loss"] = nn.BCEWithLogitsLoss()(logit, labels)
        return result

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Stub required by PEFT — we never generate text."""
        raise NotImplementedError("TMOSClassifier does not support generation.")

    def gradient_checkpointing_enable(self, **kwargs):
        """Delegate to the base model for HF Trainer compatibility."""
        self.base.model.gradient_checkpointing_enable(**kwargs)

    @property
    def config(self):
        """Expose the base model config for PEFT."""
        return self.base.config

    @property
    def device(self):
        # Device of the first parameter; NOTE(review): with device_map="auto"
        # the model may be sharded across devices — this reports only one.
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype
# ─── Standalone Test ──────────────────────────────────────────────────
if __name__ == "__main__":
    import os
    from dotenv import load_dotenv

    # Pull HF_TOKEN from a local .env (needed for gated repos).
    load_dotenv()
    hf_token = os.getenv("HF_TOKEN")

    print("Testing TMOSClassifier...")
    run_device = "cuda" if torch.cuda.is_available() else "cpu"

    model = TMOSClassifier(
        base_model_id="llava-hf/llava-1.5-7b-hf",
        torch_dtype=torch.float16,
        token=hf_token,
    )
    model.to(run_device)

    # Report parameter counts.
    param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(n for n, _ in param_info)
    trainable = sum(n for n, grad in param_info if grad)
    print(f"Total params: {total:>12,}")
    print(f"Trainable params: {trainable:>12,}")
    print(f"Classifier head: {sum(p.numel() for p in model.classifier.parameters()):,}")

    # Smoke test on a dummy grey image.
    from transformers import AutoProcessor
    from PIL import Image

    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf", token=hf_token)
    processor.patch_size = 14
    processor.vision_feature_select_strategy = "default"

    grey_img = Image.new("RGB", (336, 336), color=(128, 128, 128))
    batch = processor(
        text="USER: <image>\nIs this real?\nASSISTANT:",
        images=grey_img,
        return_tensors="pt",
    ).to(run_device)
    target = torch.tensor([1.0], device=run_device)  # fake

    with torch.no_grad():
        out = model(**batch, labels=target)

    print(f"Logit: {out['logit'].item():.4f}")
    print(f"Loss: {out['loss'].item():.4f}")
    print(f"Prob: {torch.sigmoid(out['logit']).item():.4f}")
    print("Test passed.")