from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import spacy
import time
# Load models and resources
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
tokenizer = T5Tokenizer.from_pretrained("t5-small")
generator = T5ForConditionalGeneration.from_pretrained("t5-small")
# Load the spaCy English model, downloading it first if it is not installed
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    import spacy.cli
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
# Load FAISS index and captions
faiss_index = faiss.read_index("./faiss_index.idx")
with open("./captions.json", "r", encoding="utf-8") as f:
    captions = json.load(f)
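
# The helper below is a minimal, illustrative sketch of how faiss_index.idx and
# captions.json are assumed to fit together: row i of the index holds the
# normalized CLIP embedding of a reference image, and captions[i] is its caption.
# The function name (build_reference_index), the image paths, and the choice of
# an inner-product flat index are assumptions for illustration, not this app's
# actual build pipeline. It is defined here only for reference and never called.
def build_reference_index(image_paths, image_captions,
                          index_path="./faiss_index.idx",
                          captions_path="./captions.json"):
    embeddings = []
    for path in image_paths:
        emb = extract_image_features(path)  # normalized CLIP embedding (float32)
        if emb is not None:
            embeddings.append(emb)
    matrix = np.stack(embeddings)
    # Inner product on L2-normalized vectors is equivalent to cosine similarity
    index = faiss.IndexFlatIP(matrix.shape[1])
    index.add(matrix)
    faiss.write_index(index, index_path)
    with open(captions_path, "w", encoding="utf-8") as out:
        json.dump(image_captions, out, ensure_ascii=False)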
def extract_image_features(image):
    """
    Extract image features using the CLIP model.
    Input: PIL Image, NumPy array, or image path (str).
    Output: Normalized image embedding (numpy array), or None on failure.
    """
    try:
        # Convert NumPy array to PIL if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8")).convert("RGB")
        elif isinstance(image, str):
            image = Image.open(image).convert("RGB")
        else:
            image = image.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
        features = torch.nn.functional.normalize(features, p=2, dim=-1)
        return features.squeeze(0).cpu().numpy().astype("float32")
    except Exception as e:
        print(f"Error extracting features: {e}")
        return None
def retrieve_similar_captions(image_embedding, k=5):
    """
    Retrieve k most similar captions using FAISS index.
    Input: Image embedding (numpy array).
    Output: List of captions.
    """
    if image_embedding.ndim == 1:
        image_embedding = image_embedding.reshape(1, -1)
    D, I = faiss_index.search(image_embedding, k)
    # FAISS returns -1 for empty slots when fewer than k vectors are indexed
    return [captions[i] for i in I[0] if i != -1]
def extract_location_names(texts):
    """
    Extract location names from captions using spaCy.
    Input: List of captions.
    Output: List of unique location names, in order of first appearance.
    """
    names = []
    for text in texts:
        doc = nlp(text)
        for ent in doc.ents:
            if ent.label_ in ["GPE", "LOC", "FAC"]:
                names.append(ent.text)
    # dict.fromkeys deduplicates while keeping first-seen order, so the prompt is deterministic
    return list(dict.fromkeys(names))
def generate_caption_from_retrieved(retrieved_captions):
    """
    Generate a caption from retrieved captions using T5.
    Input: List of retrieved captions.
    Output: Generated caption (str).
    """
    locations = extract_location_names(retrieved_captions)
    location_hint = f"The place might be: {', '.join(locations)}. " if locations else ""
    prompt = location_hint + " ".join(retrieved_captions) + " Generate a caption with the landmark name:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = generator.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=300,
        num_beams=5,
        early_stopping=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def generate_rag_caption(image):
    """
    Generate a RAG-based caption for an image.
    Input: PIL Image, NumPy array, or image path (str).
    Output: Caption (str).
    """
    try:
        start = time.time()
        embedding = extract_image_features(image)
        print("⏱️ CLIP feature extraction:", round(time.time() - start, 2), "s")
        if embedding is None:
            return "Failed to process image."

        start = time.time()
        retrieved = retrieve_similar_captions(embedding, k=5)
        print("⏱️ Caption retrieval:", round(time.time() - start, 2), "s")
        if not retrieved:
            return "No similar captions found."

        start = time.time()
        caption = generate_caption_from_retrieved(retrieved)
        print("⏱️ Caption generation:", round(time.time() - start, 2), "s")
        return caption
    except Exception as e:
        print(f"Error in RAG captioning: {e}")
        return "Something went wrong during caption generation."