| import torch |
| import pickle |
| import json |
| from huggingface_hub import hf_hub_download |
| from sentence_transformers import SentenceTransformer |
| import numpy as np |
| import importlib.util |
| from pathlib import Path |
|
|
class Anime2Vec:
    """
    A high-level wrapper to easily use the hikka-forge-anime2vec model.
    It automatically downloads all required artifacts from the Hugging Face Hub.

    SECURITY: initialisation executes the ``model.py`` downloaded from ``repo_id``
    and unpickles its ``le_*.pkl`` label encoders; both can run arbitrary code.
    Only load repositories you trust, and pass ``revision`` to pin an audited
    commit.
    """

    # Default Hub id of the SentenceTransformer used to embed the text fields.
    DEFAULT_TEXT_ENCODER_ID = "Lorg0n/hikka-forge-paraphrase-multilingual-MiniLM-L12-v2"

    # Files fetched from ``repo_id``, downloaded in this order.
    _ARTIFACTS = (
        "config.json",
        "pytorch_model.bin",
        "model.py",
        "le_genre.pkl",
        "le_studio.pkl",
        "le_type.pkl",
    )

    # (anime_data key, model batch key) for each text field, in the order the
    # fields are embedded and handed to the model.
    _TEXT_FIELDS = (
        ("ua_description", "precomputed_ua_desc"),
        ("en_description", "precomputed_en_desc"),
        ("ua_title", "precomputed_ua_title"),
        ("en_title", "precomputed_en_title"),
        ("original_title", "precomputed_original_title"),
        ("alternate_names", "precomputed_alternate_names"),
    )

    def __init__(
        self,
        repo_id: str = "Lorg0n/hikka-forge-anime2vec",
        device: str = None,
        *,
        text_encoder_id: str = DEFAULT_TEXT_ENCODER_ID,
        revision: str = None,
    ):
        """Download the model artifacts and build a ready-to-use encoder.

        Args:
            repo_id: Hugging Face Hub repository holding the model artifacts.
            device: Torch device string. Defaults to "cuda" when available,
                otherwise "cpu".
            text_encoder_id: SentenceTransformer model id used to embed the
                text fields.
            revision: Optional Git revision (branch, tag or commit hash) of
                ``repo_id`` to download; ``None`` keeps the library default.
        """
        print(f"🚀 Initializing Anime2Vec from repository: {repo_id}")

        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        print(f" - Using device: {self.device}")

        cache_dir = Path.home() / ".cache" / "hikka-forge"
        paths = {
            filename: hf_hub_download(
                repo_id=repo_id, filename=filename, cache_dir=cache_dir, revision=revision
            )
            for filename in self._ARTIFACTS
        }

        with open(paths["config.json"], "r", encoding="utf-8") as f:
            self.config = json.load(f)

        # SECURITY: pickle.load can execute code embedded in the file.
        self.le_genre = self._load_pickle(paths["le_genre.pkl"])
        self.le_studio = self._load_pickle(paths["le_studio.pkl"])
        self.le_type = self._load_pickle(paths["le_type.pkl"])

        self.model = self._build_model(paths["model.py"], paths["pytorch_model.bin"])

        self.text_encoder = SentenceTransformer(text_encoder_id, device=self.device)
        print("✅ Initialization complete. Model is ready to use.")

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at ``path`` (trusted files only)."""
        with open(path, "rb") as f:
            return pickle.load(f)

    def _build_model(self, code_path, weights_path):
        """Import ``AnimeEmbeddingModel`` from ``code_path``, load its weights, return it in eval mode."""
        spec = importlib.util.spec_from_file_location("AnimeEmbeddingModel", code_path)
        model_module = importlib.util.module_from_spec(spec)
        # SECURITY: this executes the downloaded model.py.
        spec.loader.exec_module(model_module)

        model = model_module.AnimeEmbeddingModel(
            vocab_sizes=self.config["vocab_sizes"],
            embedding_dims=self.config["embedding_dims"],
            text_embedding_size=self.config["text_embedding_size"],
        )
        try:
            # weights_only=True restricts unpickling to tensors/containers.
            state_dict = torch.load(weights_path, map_location=self.device, weights_only=True)
        except TypeError:
            # torch releases without the ``weights_only`` argument reject it.
            state_dict = torch.load(weights_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _text_field(anime_data, key):
        """Return the text for ``key``: "" when missing/None; list-like alternate names joined with "; "."""
        value = anime_data.get(key)
        if value is None:
            return ""
        if key == "alternate_names" and not isinstance(value, str):
            return "; ".join(value)
        return value

    @staticmethod
    def _encode_label(encoder, value, fallback="UNKNOWN"):
        """Return ``encoder``'s id for ``value``, or the id of ``fallback`` when ``value`` is None or unseen.

        Raises:
            ValueError: if ``fallback`` itself is not a class known to ``encoder``.
        """
        if value is None:
            value = fallback
        try:
            return encoder.transform([value])[0]
        except ValueError:
            return encoder.transform([fallback])[0]

    @torch.no_grad()
    def encode(self, anime_data: dict) -> np.ndarray:
        """
        Encodes a dictionary of anime data into a 512-dimensional vector.

        Args:
            anime_data: Record with optional keys ``ua_description``,
                ``en_description``, ``ua_title``, ``en_title``,
                ``original_title``, ``alternate_names`` (list of names or a
                single string), ``genres`` (list), ``studio``, ``type`` and
                ``numerical_features`` (list of floats). Missing keys and
                ``None`` values fall back to empty text, no genres,
                ``"UNKNOWN"`` and six zeros respectively.

        Returns:
            The model output moved to CPU as a numpy array with all size-1
            dimensions squeezed away. Its length is set by the downloaded
            model (the 512 above comes from the original docstring).
        """
        texts = [self._text_field(anime_data, key) for key, _ in self._TEXT_FIELDS]
        text_embeddings = self.text_encoder.encode(texts, convert_to_tensor=True)

        genres = anime_data.get("genres") or []
        known_genres = [g for g in genres if g in self.le_genre.classes_]
        # NOTE(review): 0 is used both as the "no known genre" filler and as the
        # padding marker for genres_mask below. If le_genre is a sklearn
        # LabelEncoder (its classes_/transform API suggests so), its first class
        # also maps to 0 and gets masked out. Confirm against the training code.
        genre_ids = self.le_genre.transform(known_genres) if known_genres else [0]

        studio_id = self._encode_label(self.le_studio, anime_data.get("studio"))
        type_id = self._encode_label(self.le_type, anime_data.get("type"))

        numerical_features = anime_data.get("numerical_features")
        if numerical_features is None:
            numerical_features = [0.0] * 6

        batch = {
            batch_key: text_embeddings[index]
            for index, (_, batch_key) in enumerate(self._TEXT_FIELDS)
        }
        batch["genres"] = torch.tensor(genre_ids, dtype=torch.long)
        batch["studio"] = torch.tensor(studio_id, dtype=torch.long)
        batch["type"] = torch.tensor(type_id, dtype=torch.long)
        batch["numerical"] = torch.tensor(numerical_features, dtype=torch.float32)

        # Add a leading batch dimension of 1 and move every input to the model's device.
        batch = {key: tensor.unsqueeze(0).to(self.device) for key, tensor in batch.items()}
        batch["genres_mask"] = (batch["genres"] != 0).long()

        embedding = self.model(batch)
        return embedding.squeeze().cpu().numpy()