| """ |
| Style vector utilities. |
| Helper functions for manipulating, comparing, and persisting style vectors. |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from typing import List, Optional |
|
|
|
|
def cosine_similarity(vec_a: torch.Tensor, vec_b: torch.Tensor) -> float:
    """Return the cosine similarity of two style vectors as a Python float.

    Any 1-D argument is promoted to shape ``(1, D)`` before the comparison.
    Similarity is taken along the last dimension. The result is converted
    with ``.item()``, so it must contain exactly one element. Batched inputs
    with more than one row therefore raise a RuntimeError.

    Args:
        vec_a: First style vector.
        vec_b: Second style vector.

    Returns:
        The cosine similarity as a float.
    """
    lhs = vec_a.unsqueeze(0) if vec_a.dim() == 1 else vec_a
    rhs = vec_b.unsqueeze(0) if vec_b.dim() == 1 else vec_b
    return F.cosine_similarity(lhs, rhs, dim=-1).item()
|
|
|
|
def average_style_vectors(
    vectors: List[torch.Tensor], *, normalize: bool = True
) -> torch.Tensor:
    """Compute the mean of several style vectors, L2-normalized by default.

    The vectors are stacked along a new leading dimension and averaged
    element-wise. By default the mean is then rescaled to unit L2 norm along
    its last dimension, using ``F.normalize``'s default ``eps``. A mean that
    cancels to zero therefore comes back as a zero vector rather than NaN.

    Args:
        vectors: Style vectors of identical shape. Any iterable of tensors is
            accepted; it is materialized into a list before stacking.
        normalize: If True (the default), L2-normalize the mean along the
            last dimension. If False, return the raw element-wise mean.

    Returns:
        The (optionally normalized) mean vector. It has the shape of a
        single input vector.

    Raises:
        ValueError: If ``vectors`` is empty.
    """
    # Materialize first. A generator is truthy even when empty, and
    # torch.stack expects a sequence.
    vectors = list(vectors)
    if not vectors:
        raise ValueError("Cannot average empty list of vectors")
    mean_vec = torch.stack(vectors, dim=0).mean(dim=0)
    if not normalize:
        return mean_vec
    return F.normalize(mean_vec, p=2, dim=-1)
|
|
|
|
def save_style_vector(vector: torch.Tensor, path: str) -> None:
    """Write a style vector to ``path`` with ``torch.save``.

    The tensor is detached from the autograd graph and copied to the CPU
    first. The saved file therefore has no gradient history and no device
    dependency.
    """
    cpu_copy = vector.detach().cpu()
    torch.save(cpu_copy, path)
|
|
|
|
def load_style_vector(path: str) -> torch.Tensor:
    """Read a style vector from ``path`` onto the CPU.

    Tensors are mapped to the CPU regardless of the device they were saved
    from. Loading uses ``weights_only=True``, which restricts unpickling to
    tensors and other allow-listed types.
    """
    loaded = torch.load(path, map_location="cpu", weights_only=True)
    return loaded
|
|