# Source: RAGExplo1234 / inference.py
# Last commit: 4e2f70b "Update inference.py" by supib4132 (verified)
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import spacy
import time
# Load models and resources once, at import time. The functions below read
# these module-level objects directly.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# NOTE(review): text_encoder is not used by any function visible in this file;
# confirm whether it is needed before paying its load cost.
text_encoder = SentenceTransformer('all-MiniLM-L6-v2')
tokenizer = T5Tokenizer.from_pretrained("t5-small")
generator = T5ForConditionalGeneration.from_pretrained("t5-small")

# spaCy pipeline used for named-entity extraction. If the model package is not
# installed, spacy.load raises OSError: download it once and retry. Catching
# only OSError (instead of a bare except) lets KeyboardInterrupt, SystemExit
# and unrelated errors propagate instead of triggering a download.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    import spacy.cli
    spacy.cli.download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")

# Load FAISS index and captions. Paths are relative to the current working
# directory. retrieve_similar_captions maps FAISS result ids to positions in
# `captions`, so the two files must be built together.
faiss_index = faiss.read_index("./faiss_index.idx")
with open("./captions.json", "r", encoding="utf-8") as f:
    captions = json.load(f)
def extract_image_features(image):
    """
    Extract an image embedding with the CLIP model.

    Input: a NumPy array (cast to uint8 before conversion), an image file
    path (str), or a PIL Image (anything with a ``convert`` method).
    Output: L2-normalized image embedding as a float32 numpy array, or None
    if any step fails. The error is printed rather than raised, and callers
    check for None.
    """
    try:
        if isinstance(image, np.ndarray):
            # NOTE(review): astype("uint8") assumes pixel values are already
            # in 0-255; float arrays in [0, 1] would become black. Confirm
            # against the caller.
            image = Image.fromarray(image.astype("uint8")).convert("RGB")
        elif isinstance(image, str):
            # Context manager releases the file handle deterministically;
            # convert() returns a new, fully loaded image that outlives it.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        else:
            image = image.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
        features = torch.nn.functional.normalize(features, p=2, dim=-1)
        # Drop the batch dimension of the single-image batch.
        return features.squeeze(0).cpu().numpy().astype("float32")
    except Exception as e:
        # Deliberate best-effort boundary: report and signal failure with None.
        print(f"Error extracting features: {e}")
        return None
def retrieve_similar_captions(image_embedding, k=5):
    """
    Retrieve the captions whose index vectors are nearest to an embedding.

    Input: image embedding (numpy array), either a single vector or a 2-D
    array of query rows. Only the results for the first row are returned.
    Output: list of at most k captions, in FAISS result order (nearest
    first). The list is shorter than k when the index returns fewer than k
    valid hits.
    """
    # FAISS requires a 2-D, C-contiguous float32 query matrix.
    query = np.ascontiguousarray(image_embedding, dtype="float32")
    if query.ndim == 1:
        query = query.reshape(1, -1)
    _, indices = faiss_index.search(query, k)
    # FAISS pads missing results with -1. Without this filter, captions[-1]
    # (the last caption) would be returned as a bogus match.
    return [captions[int(i)] for i in indices[0] if i >= 0]
def extract_location_names(texts):
    """
    Extract location-like named entities from captions using spaCy.

    Input: iterable of caption strings.
    Output: list of unique entity texts labelled GPE, LOC or FAC, in order of
    first appearance. The order is deterministic, so a prompt built from it
    is stable across runs.
    """
    location_labels = {"GPE", "LOC", "FAC"}
    names = []
    for text in texts:
        for ent in nlp(text).ents:
            if ent.label_ in location_labels:
                names.append(ent.text)
    # dict.fromkeys de-duplicates while keeping first-seen order. set() order
    # for strings varies between interpreter runs (hash randomization).
    return list(dict.fromkeys(names))
def generate_caption_from_retrieved(retrieved_captions):
    """
    Turn a list of retrieved captions into one caption with T5.

    The prompt is assembled from three parts:
    - an optional "The place might be: ..." hint listing GPE/LOC/FAC entities
      found in the captions;
    - the captions joined with spaces;
    - a trailing instruction asking for a caption with the landmark name.

    The output is decoded with 5-beam search, max_length=300.

    Input: list of caption strings.
    Output: generated caption (str) with special tokens stripped.
    """
    place_names = extract_location_names(retrieved_captions)
    hint = ""
    if place_names:
        hint = f"The place might be: {', '.join(place_names)}. "
    prompt = "".join(
        [hint, " ".join(retrieved_captions), " Generate a caption with the landmark name:"]
    )

    # NOTE(review): with the tokenizer's default right-side truncation, a
    # prompt longer than the model's max length loses its tail first, i.e.
    # the instruction. Consider moving it to the front if captions get long.
    encoded = tokenizer(prompt, return_tensors="pt", truncation=True)
    beam_output = generator.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=300,
        num_beams=5,
        early_stopping=True,
    )
    best_sequence = beam_output[0]
    return tokenizer.decode(best_sequence, skip_special_tokens=True)
def generate_rag_caption(image):
    """
    Generate a retrieval-augmented caption for an image.

    Pipeline: CLIP embedding -> FAISS caption retrieval -> T5 generation.
    The duration of each stage is printed.

    Input: PIL Image, image file path (str), or NumPy image array, i.e. any
    input extract_image_features handles.
    Output: caption (str). Errors (Exception subclasses) are printed, and a
    fixed message string is returned instead of raising.
    """
    try:
        # perf_counter is monotonic and high-resolution. time.time() follows
        # the wall clock, which can jump and is not meant for durations.
        start = time.perf_counter()
        embedding = extract_image_features(image)
        print("⏱️ CLIP feature extraction:", round(time.perf_counter() - start, 2), "s")
        if embedding is None:
            return "Failed to process image."

        start = time.perf_counter()
        retrieved = retrieve_similar_captions(embedding, k=5)
        print("⏱️ Caption retrieval:", round(time.perf_counter() - start, 2), "s")
        if not retrieved:
            return "No similar captions found."

        start = time.perf_counter()
        caption = generate_caption_from_retrieved(retrieved)
        print("⏱️ Caption generation:", round(time.perf_counter() - start, 2), "s")
        return caption
    except Exception as e:
        # Top-level boundary: report the error and return a user-facing message.
        print(f"Error in RAG captioning: {e}")
        return "Something went wrong during caption generation."