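"""Gradio Space: compute MetaCLIP (facebook/metaclip-b32-400m) embeddings for an image or a text prompt."""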
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
MODEL_NAME = "facebook/metaclip-b32-400m"

# Use the Space's local cache directory for model weights when it exists.
cache_path = Path("/app/cache")
if not cache_path.exists():
    cache_path = None
def get_clip_model_and_processor(model_name: str, cache_path: Optional[Path] = None):
    """Load a CLIP-style model and its processor, optionally caching weights locally."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if cache_path:
        model = CLIPModel.from_pretrained(model_name, cache_dir=str(cache_path)).to(device)
        processor = CLIPProcessor.from_pretrained(model_name, cache_dir=str(cache_path))
    else:
        model = CLIPModel.from_pretrained(model_name).to(device)
        processor = CLIPProcessor.from_pretrained(model_name)
    return model.eval(), processor
def image_to_embedding(img: Optional[np.ndarray] = None, txt: Optional[str] = None) -> np.ndarray:
    """Embed the image if one is given, otherwise embed the text prompt."""
    if img is None and not txt:
        return np.empty(0)
    with torch.no_grad():
        if img is not None:
            inputs = CLIP_PROCESSOR(images=[Image.fromarray(img)], return_tensors="pt").to(
                CLIP_MODEL.device
            )
            embedding = CLIP_MODEL.get_image_features(**inputs)
        else:
            inputs = CLIP_PROCESSOR(text=[txt], return_tensors="pt", padding=True).to(
                CLIP_MODEL.device
            )
            embedding = CLIP_MODEL.get_text_features(**inputs)
    return embedding.cpu().numpy()
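# Quick sanity check (a sketch; the (1, 512) shape assumes the ViT-B/32
# projection width of this checkpoint):
#
#   vec = image_to_embedding(txt="a photo of a dog")
#   print(vec.shape)  # expected: (1, 512)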
# Loaded once at import time so every request reuses the same model.
CLIP_MODEL, CLIP_PROCESSOR = get_clip_model_and_processor(MODEL_NAME, cache_path=cache_path)
demo = gr.Interface(fn=image_to_embedding, inputs=["image", "textbox"], outputs="textbox")
demo.launch(server_name="0.0.0.0")
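# Example request from another process (a sketch, assuming a recent
# gradio_client where file inputs are wrapped with handle_file, the app is
# reachable on Gradio's default port 7860, and a local "cat.png" exists --
# both the host URL and the filename here are hypothetical):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7860")
#   result = client.predict(handle_file("cat.png"), "", api_name="/predict")
#   print(result)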