from PIL import Image
import requests
from io import BytesIO
import torch
from transformers import AutoModel, AutoProcessor, AutoConfig, AutoModelForVision2Seq

# from granite_cola import ColGraniteVisionConfig, ColGraniteVision, ColGraniteVisionProcessor

# --- 1) Register your custom classes so AutoModel/AutoProcessor work out-of-the-box
# AutoConfig.register("colgranitevision", ColGraniteVisionConfig)
# AutoModel.register(ColGraniteVisionConfig, ColGraniteVision)
# AutoProcessor.register(ColGraniteVisionConfig, ColGraniteVisionProcessor)
# ─────────────────────────────────────────────
# 2) Load model & processor
# ─────────────────────────────────────────────
model_dir = "."
model = AutoModelForVision2Seq.from_pretrained(
    model_dir,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)
# Optional: load a PEFT adapter on top of the base model
# model = PeftModel.from_pretrained(model, peft_path).eval()
processor = AutoProcessor.from_pretrained(
    model_dir,
    trust_remote_code=True,
    use_fast=True,
)

# Set patch_size explicitly if the processor config does not define it
if hasattr(processor, 'patch_size') and processor.patch_size is None:
    processor.patch_size = 14  # Default patch size for many vision transformers

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
# ─────────────────────────────────────────────
# 3) Download a sample image
# ─────────────────────────────────────────────
image_url = "https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg"
resp = requests.get(image_url)
resp.raise_for_status()
image = Image.open(BytesIO(resp.content)).convert("RGB")
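# If the example URL is unreachable, any local RGB image works the same way
# (hypothetical path shown, adjust as needed):
# image = Image.open("path/to/local_image.jpg").convert("RGB")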
# ─────────────────────────────────────────────
# 4) Process image and text
# ─────────────────────────────────────────────
# Process image
image_inputs = processor.process_images([image])
image_inputs = {k: v.to(device) for k, v in image_inputs.items()}

# Process text
text = "A photo of a tiger"
text_inputs = processor.process_queries([text])
text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
# ─────────────────────────────────────────────
# 5) Get embeddings and score
# ─────────────────────────────────────────────
with torch.no_grad():
    # Get image embedding
    image_embedding = model(**image_inputs)
    # Get text embedding
    text_embedding = model(**text_inputs)

    # Calculate similarity score
    score = torch.matmul(text_embedding, image_embedding.T).item()

print(f"Similarity score between text and image: {score:.4f}")
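
# ─────────────────────────────────────────────
# 6) Optional: rank several candidate texts against the image
# ─────────────────────────────────────────────
# Minimal sketch, reusing the calls above. It assumes the model returns one
# embedding vector per input (shape [batch, dim]), as the matmul scoring in
# step 5 implies; the candidate strings below are illustrative only.
candidate_texts = ["A photo of a tiger", "A photo of a beach", "A diagram of a circuit"]
with torch.no_grad():
    cand_inputs = processor.process_queries(candidate_texts)
    cand_inputs = {k: v.to(device) for k, v in cand_inputs.items()}
    cand_embeddings = model(**cand_inputs)  # [num_texts, dim] under the assumption above
    scores = torch.matmul(cand_embeddings, image_embedding.T).squeeze(-1)  # one score per text

for t, s in zip(candidate_texts, scores.tolist()):
    print(f"{s:.4f}  {t}")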