# clip-embedding/app.py
# Gradio app that embeds images or text with Facebook's MetaCLIP model.
import gradio as gr
import numpy as np
import torch
from PIL import Image
from pathlib import Path
from typing import Optional
from transformers import CLIPProcessor, CLIPModel

MODEL_NAME = "facebook/metaclip-b32-400m"

# Use the container's cache directory if it exists; otherwise fall back to
# the default Hugging Face cache location.
cache_path: Optional[Path] = Path('/app/cache')
if not cache_path.exists():
    cache_path = None
def get_clip_model_and_processor(model_name: str, cache_path: Optional[Path] = None):
    """Load the CLIP model and processor, optionally from a local cache directory."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if cache_path:
        model = CLIPModel.from_pretrained(model_name, cache_dir=str(cache_path)).to(device)
        processor = CLIPProcessor.from_pretrained(model_name, cache_dir=str(cache_path))
    else:
        model = CLIPModel.from_pretrained(model_name).to(device)
        processor = CLIPProcessor.from_pretrained(model_name)
    return model.eval(), processor
def image_to_embedding(img: Optional[np.ndarray] = None, txt: Optional[str] = None) -> np.ndarray:
    """Embed an image (preferred when both inputs are given) or a text string with CLIP."""
    if img is None and not txt:
        return np.empty(0)
    with torch.no_grad():
        if img is not None:
            # Image path: preprocess the array and run it through the vision tower.
            embedding = CLIP_MODEL.get_image_features(
                **CLIP_PROCESSOR(images=[Image.fromarray(img)], return_tensors="pt", padding=True).to(
                    CLIP_MODEL.device
                )
            )
        else:
            # Text path: tokenize the string and run it through the text tower.
            embedding = CLIP_MODEL.get_text_features(
                **CLIP_PROCESSOR(text=[txt], return_tensors="pt", padding=True).to(
                    CLIP_MODEL.device
                )
            )
    return embedding.detach().cpu().numpy()
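# A minimal sketch of how the returned vectors could be compared downstream,
# e.g. scoring an image embedding against a caption embedding. This helper is
# an illustration, not part of the original app, and is never called here.
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two embedding vectors."""
    a, b = a.ravel(), b.ravel()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))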
# Load the model and processor once at startup so every request reuses them.
CLIP_MODEL, CLIP_PROCESSOR = get_clip_model_and_processor(MODEL_NAME, cache_path=cache_path)

# No examples are defined, so example caching is omitted.
demo = gr.Interface(fn=image_to_embedding, inputs=["image", "textbox"], outputs="textbox")
demo.launch(server_name="0.0.0.0")
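# For reference, a hypothetical client-side call against this app once it is
# running (assumes the `gradio_client` package and Gradio's default port 7860;
# kept as a comment so importing this module never triggers a request):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   result = client.predict(None, "a photo of a dog", api_name="/predict")
#   print(result)  # stringified embedding from the textbox output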