import os
import json
import math
import urllib.request
from io import BytesIO
from typing import Any, Dict, List, Optional

import numpy as np
from PIL import Image

try:
    import torch
    from transformers import CLIPModel, CLIPProcessor
    import faiss  # type: ignore
    from huggingface_hub import hf_hub_download, InferenceClient
except Exception as import_error:  # pragma: no cover
    raise RuntimeError(
        "Required packages not found. Please install: torch, transformers, pillow, faiss-cpu, huggingface_hub"
    ) from import_error


class EndToEndRAG:
    """
    End-to-end multimodal RAG system using local CLIP + FAISS retrieval and remote generation via the Inference API.
    """

    def __init__(
        self,
        clip_model_name: str = "aaalaaa/multimodal-face-clip",
        generator_model_name: Optional[str] = "google/gemma-2b",
        index_path: Optional[str] = None,
        doc_embeddings_path: Optional[str] = None,
        doc_metadata_path: Optional[str] = None,
        device: Optional[str] = None,
        text_weight: float = 0.7,
        image_weight: float = 0.3,
        top_k: int = 1,
        max_new_tokens: int = 10,
        temperature: float = 0.1,
    ) -> None:
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.text_weight = float(text_weight)
        self.image_weight = float(image_weight)
        self.top_k = int(top_k)
        self.max_new_tokens = int(max_new_tokens)
        self.temperature = float(temperature)
        if not math.isclose(self.text_weight + self.image_weight, 1.0, rel_tol=1e-6):
            raise ValueError("text_weight + image_weight must equal 1.0")

        # Models: CLIP
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
        self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(self.device)
        self.clip_model.eval()

        # Inference client for generation (remote)
        self.inference_client: Optional[InferenceClient] = None
        if generator_model_name:
            hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
            model_name = os.environ.get("HF_INFERENCE_MODEL", generator_model_name)
            self.inference_client = InferenceClient(model=model_name, token=hf_token)

        # Two-index stores
        self.text_index: Optional[faiss.Index] = None
        self.image_index: Optional[faiss.Index] = None
        self.metadata: List[Dict[str, Any]] = []
        self.id_to_original: Dict[str, Dict[str, Any]] = {}

        # Single-index store
        self.index: Optional[faiss.Index] = None
        self.doc_embeddings: Optional[np.ndarray] = None
        self.doc_metadata: List[Dict[str, Any]] = []

        # Load local single-index mode if provided
        self._load_index(index_path, doc_embeddings_path, doc_metadata_path)

    @classmethod
    def default(
        cls,
        hf_token: Optional[str] = None,
        text_weight: float = 0.7,
        image_weight: float = 0.3,
        top_k: int = 1,
        max_new_tokens: int = 10,
        temperature: float = 0.1,
        device: Optional[str] = None,
    ) -> "EndToEndRAG":
        instance = cls(
            clip_model_name="aaalaaa/multimodal-face-clip",
            generator_model_name="google/gemma-2b",
            index_path=None,
            doc_embeddings_path=None,
            doc_metadata_path=None,
            device=device,
            text_weight=text_weight,
            image_weight=image_weight,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        # Download indices and metadata via the HF Hub
        token = hf_token or os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
        text_index_path = hf_hub_download(
            repo_id="aaalaaa/multimodal-face-clip", filename="embeddings/text_index.faiss", token=token
        )
        image_index_path = hf_hub_download(
            repo_id="aaalaaa/multimodal-face-clip", filename="embeddings/image_index.faiss", token=token
        )
        metadata_path = hf_hub_download(
            repo_id="aaalaaa/multimodal-face-clip", filename="embeddings/metadata.json", token=token
        )
        original_path = hf_hub_download(
            repo_id="aaalaaa/multimodal-face-clip", filename="saved_data.json", token=token
        )
        instance.text_index = faiss.read_index(text_index_path)
        instance.image_index = faiss.read_index(image_index_path)
        with open(metadata_path, "r", encoding="utf-8") as f:
            instance.metadata = json.load(f)
        with open(original_path, "r", encoding="utf-8") as f:
            original_data = json.load(f)
        instance.id_to_original = {str(item.get("id")): item for item in original_data}
        return instance

    def query(self, text: Optional[str], image_url: Optional[str], options: Optional[List[str]] = None) -> str:
        if (text is None or text.strip() == "") and (image_url is None or image_url.strip() == ""):
            # Persian: "No valid input was provided. Please send a question text or an image."
            return "ورودی معتبری ارائه نشده است. لطفاً متن پرسش یا تصویر را ارسال کنید."
        retrieved = self._retrieve(text=text, image_url=image_url, top_k=self.top_k)
        prompt = self._build_prompt(text=text, image_url=image_url, retrieved=retrieved, options=options)
        answer = self._generate(prompt, is_mcq=bool(options), options=options)
        return answer

    def _load_index(
        self,
        index_path: Optional[str],
        doc_embeddings_path: Optional[str],
        doc_metadata_path: Optional[str],
    ) -> None:
        if index_path and os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
        if doc_embeddings_path and os.path.exists(doc_embeddings_path):
            self.doc_embeddings = np.load(doc_embeddings_path)
        if doc_metadata_path and os.path.exists(doc_metadata_path):
            with open(doc_metadata_path, "r", encoding="utf-8") as f:
                self.doc_metadata = json.load(f)
        if self.index is None and self.doc_embeddings is not None:
            # Build an inner-product index over the L2-normalized embeddings (cosine similarity)
            self._normalize_inplace(self.doc_embeddings)
            dim = int(self.doc_embeddings.shape[1])
            self.index = faiss.IndexFlatIP(dim)
            self.index.add(self.doc_embeddings.astype(np.float32))
        if self.index is None:
            # No usable single-index data was provided; keep the single-index store empty
            self.doc_embeddings = None
            self.doc_metadata = []

    def _encode_text(self, text: str) -> np.ndarray:
        inputs = self.clip_processor(text=[text], images=None, return_tensors="pt", padding=True).to(self.device)
        text_features = self.clip_model.get_text_features(**{k: v for k, v in inputs.items() if k.startswith("input_")})
        text_features = torch.nn.functional.normalize(text_features, p=2, dim=-1)
        return text_features.detach().cpu().numpy()[0]

    def _encode_image(self, image: Image.Image) -> np.ndarray:
        inputs = self.clip_processor(text=None, images=image, return_tensors="pt").to(self.device)
        image_features = self.clip_model.get_image_features(**{k: v for k, v in inputs.items() if k.startswith("pixel_")})
        image_features = torch.nn.functional.normalize(image_features, p=2, dim=-1)
        return image_features.detach().cpu().numpy()[0]
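
    # Note (added comment): the retrieval step below fuses the L2-normalized text and image
    # query embeddings using text_weight/image_weight, then searches FAISS by inner product,
    # which equals cosine similarity since all stored vectors are normalized.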
    def _retrieve(
        self,
        text: Optional[str],
        image_url: Optional[str],
        top_k: int,
    ) -> List[Dict[str, Any]]:
        has_two_indices = self.text_index is not None and self.image_index is not None and len(self.metadata) > 0
        query_vectors: List[np.ndarray] = []
        weights: List[float] = []
        if text and text.strip():
            query_vectors.append(self._encode_text(text.strip()))
            weights.append(self.text_weight)
        if image_url and image_url.strip():
            image = self._load_image(image_url.strip())
            if image is not None:
                query_vectors.append(self._encode_image(image))
                weights.append(self.image_weight)
        if not query_vectors:
            return []

        if has_two_indices:
            stacked = np.stack(query_vectors).astype(np.float32)
            weights_arr = np.array(weights, dtype=np.float32).reshape(-1, 1)
            combined = (stacked * weights_arr).sum(axis=0)
            combined = self._normalize(combined).reshape(1, -1).astype(np.float32)
            text_scores, text_indices = self.text_index.search(combined, max(top_k * 3, top_k))
            image_scores, image_indices = self.image_index.search(combined, max(top_k * 3, top_k))
            results: Dict[str, Dict[str, Any]] = {}
            for score, idx in zip(text_scores[0], text_indices[0]):
                if idx < 0 or idx >= len(self.metadata):
                    continue
                meta = self.metadata[idx]
                if meta.get("type") != "text":
                    continue
                pid = str(meta.get("id"))
                entry = results.setdefault(
                    pid,
                    {"id": pid, "text_similarity": 0.0, "image_similarity": 0.0, "combined_similarity": 0.0},
                )
                entry["text_similarity"] = float(score)
                entry["combined_similarity"] += float(score) * self.text_weight
            for score, idx in zip(image_scores[0], image_indices[0]):
                if idx < 0 or idx >= len(self.metadata):
                    continue
                meta = self.metadata[idx]
                if meta.get("type") != "image":
                    continue
                pid = str(meta.get("id"))
                entry = results.setdefault(
                    pid,
                    {"id": pid, "text_similarity": 0.0, "image_similarity": 0.0, "combined_similarity": 0.0},
                )
                entry["image_similarity"] = float(score)
                entry["combined_similarity"] += float(score) * self.image_weight
            ranked = sorted(results.values(), key=lambda x: x["combined_similarity"], reverse=True)
            final: List[Dict[str, Any]] = []
            for rank, res in enumerate(ranked[:top_k], start=1):
                original = self.id_to_original.get(res["id"], {})
                final.append(
                    {
                        "id": res["id"],
                        "rank": rank,
                        "text_similarity": res["text_similarity"],
                        "image_similarity": res["image_similarity"],
                        "combined_similarity": res["combined_similarity"],
                        "biography": original.get("cleaned_bio", ""),
                        "image_urls": original.get("images", []),
                    }
                )
            return final

        if self.index is None or self.doc_embeddings is None or len(self.doc_metadata) == 0:
            return []
        stacked = np.stack(query_vectors).astype(np.float32)
        weights_arr = np.array(weights, dtype=np.float32).reshape(-1, 1)
        weighted = (stacked * weights_arr).sum(axis=0)
        weighted = self._normalize(weighted)
        query = weighted.reshape(1, -1).astype(np.float32)
        scores, indices = self.index.search(query, top_k)
        scores = scores[0]
        indices = indices[0]
        results = []
        for rank, (idx, score) in enumerate(zip(indices, scores)):
            if idx < 0 or idx >= len(self.doc_metadata):
                continue
            meta = self.doc_metadata[idx]
            results.append(
                {
                    "id": meta.get("id", str(idx)),
                    "rank": int(rank + 1),
                    "score": float(score),
                    "title": meta.get("title", ""),
                    "text": meta.get("text", ""),
                    "image_path": meta.get("image_path"),
                    "metadata": meta,
                }
            )
        return results

    def _build_prompt(
        self,
        text: Optional[str],
        image_url: Optional[str],
        retrieved: List[Dict[str, Any]],
        options: Optional[List[str]] = None,
    ) -> str:
        # Notebook-style context formatting
        parts: List[str] = []
        for i, item in enumerate(retrieved, start=1):
            parts.append(f"Person {i}:")
            bio = item.get("biography") or item.get("text") or ""
            parts.append(f"Biography: {bio}")
            imgs = item.get("image_urls") or []
            if imgs:
                parts.append(f"Image URLs: {', '.join(imgs)}")
            score = item.get("combined_similarity")
            if score is not None:
                parts.append(f"Relevance Score: {float(score):.3f}")
            parts.append("---")
        context = "\n".join(parts) if parts else "(no retrieved content)"
        user_q = text.strip() if text else ""
        if options:
            options_text = "\n".join([f"{i}: {opt}" for i, opt in enumerate(options)])
            prompt = (
                f"Retrieved Information:\n{context}\n\n"
                f"Question: {user_q}\n\n"
                f"Options:\n{options_text}\n\n"
                "Output ONLY the chosen option number in the format \"Choice: [number]\". Do not include any other text.\n"
                "Choice:"
            )
            return prompt
        # Free-form answer
        prompt = (
            f"Retrieved Information:\n{context}\n\n"
            f"Question: {user_q}\n\n"
            "Answer in concise Persian:"
        )
        return prompt

    def _generate(self, prompt: str, is_mcq: bool, options: Optional[List[str]]) -> str:
        if self.inference_client is None:
            # Persian: "The text-generation service is not configured. Please set a model via the Inference API or enable local generation."
            return (
                "سرویس تولید متن تنظیم نشده است. لطفاً یک مدل از طریق Inference API تنظیم کنید یا تولید محلی را فعال کنید."
            )
        max_new = 10 if is_mcq else self.max_new_tokens
        temp = 0.1 if is_mcq else self.temperature
        # Prefer chat completion
        try:
            chat = self.inference_client.chat_completion(
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=max_new,
                temperature=temp,
                stream=False,
            )
            if chat and getattr(chat, "choices", None):
                content = getattr(chat.choices[0].message, "content", "")
                if isinstance(content, str) and content.strip():
                    return content.strip()
        except Exception:
            pass
        # Fall back to plain text generation
        try:
            out = self.inference_client.text_generation(
                prompt,
                max_new_tokens=max_new,
                temperature=temp,
                do_sample=temp > 0,
                return_full_text=False,
                details=False,
                stream=False,
            )
            if isinstance(out, str) and out.strip():
                return out.strip()
            gen = getattr(out, "generated_text", None)
            if isinstance(gen, str) and gen.strip():
                return gen.strip()
            return ""
        except Exception as e:
            # Persian: "Error generating answer: ..."
            return f"خطا در تولید پاسخ: {type(e).__name__}: {e}"

    @staticmethod
    def _normalize(v: np.ndarray) -> np.ndarray:
        denom = np.linalg.norm(v) + 1e-12
        return (v / denom).astype(np.float32)

    @staticmethod
    def _normalize_inplace(mat: np.ndarray) -> None:
        norms = np.linalg.norm(mat, axis=1, keepdims=True) + 1e-12
        mat /= norms

    @staticmethod
    def _load_image(image_url: str) -> Optional[Image.Image]:
        try:
            if image_url.startswith("http://") or image_url.startswith("https://"):
                with urllib.request.urlopen(image_url, timeout=10) as resp:
                    data = resp.read()
                return Image.open(BytesIO(data)).convert("RGB")
            if os.path.exists(image_url):
                return Image.open(image_url).convert("RGB")
        except Exception:
            return None
        return None
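

# Minimal usage sketch (illustrative only, not the Space's actual entry point): builds the
# default instance from the Hub artifacts above and runs a single query. The question text
# and image URL below are placeholder assumptions; an HF token is read from the environment.
if __name__ == "__main__":
    rag = EndToEndRAG.default(hf_token=os.environ.get("HF_TOKEN"))
    answer = rag.query(
        text="این شخص کیست؟",  # Persian: "Who is this person?" (example question)
        image_url="https://example.com/face.jpg",  # placeholder image URL
    )
    print(answer)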