# Embedding smoke test script (added by aysinghal, commit 824cb05).
"""Cosine similarity smoke test on final GPT-2 LLM2Vec checkpoint."""
from __future__ import annotations
import argparse
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from gpt2_llm2vec.models import get_model_class
def parse_args() -> argparse.Namespace:
    """Return parsed command-line options for the smoke test.

    Options:
        --model-path: directory of the final GPT-2 LLM2Vec checkpoint.
        --device: compute device; defaults to CUDA when available, else CPU.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser = argparse.ArgumentParser(
        description="Print cosine similarities for a few sentence pairs."
    )
    parser.add_argument(
        "--model-path",
        default="gpt2_llm2vec/checkpoints/gpt2_llm2vec_final",
    )
    parser.add_argument("--device", default=default_device)
    return parser.parse_args()
def mean_pool(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings along the sequence axis, skipping padding.

    Positions where the attention mask is 0 contribute nothing to the sum,
    and each row is divided by its count of real tokens. The divisor is
    clamped to a tiny positive value so an all-padding row cannot divide
    by zero.
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden)
    token_sum = torch.sum(last_hidden * weights, dim=1)
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return token_sum / token_count
@torch.no_grad()
def encode(model, tokenizer, texts: list[str], device: str) -> torch.Tensor:
    """Embed *texts* as L2-normalized vectors of mean-pooled hidden states.

    Tokenizes with padding/truncation to at most 256 tokens, runs the
    model's transformer backbone, mean-pools over real tokens, and returns
    unit-norm embeddings (one row per input text).
    """
    batch = tokenizer(
        texts, padding=True, truncation=True, max_length=256, return_tensors="pt"
    )
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    # NOTE(review): assumes the wrapper exposes its backbone as `.transformer`
    # and that it accepts input_ids/attention_mask as keyword args.
    outputs = model.transformer(**batch, return_dict=True)
    pooled = mean_pool(outputs.last_hidden_state, batch["attention_mask"])
    return F.normalize(pooled, p=2, dim=1)
def main() -> None:
    """Load the final checkpoint and print cosine similarity for sample pairs."""
    args = parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # GPT-2 tokenizers ship without a pad token; reuse EOS so batching works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = get_model_class("gpt2-large").from_pretrained(args.model_path)
    model.eval()
    model.to(args.device)

    # A mix of paraphrase pairs (expected high similarity) and unrelated
    # pairs (expected low similarity) for eyeballing checkpoint quality.
    pairs = [
        ("The cat sits on the mat.", "A cat is resting on a rug."),
        ("Python is a programming language.", "Coding in Python is popular."),
        ("The stock market rose today.", "It is raining heavily outside."),
        ("How do I sort a list in Python?", "Use sorted() or list.sort() in Python."),
        ("Neural networks learn from data.", "Pizza tastes best when hot."),
    ]
    for left, right in pairs:
        embeddings = encode(model, tokenizer, [left, right], args.device)
        # Embeddings are unit-norm, so the dot product is cosine similarity.
        similarity = torch.dot(embeddings[0], embeddings[1]).item()
        print(f"cos_sim={similarity:.4f}")
        print(f" A: {left}")
        print(f" B: {right}")
        print()


if __name__ == "__main__":
    main()