LayBraid
add requirements.txt
21884ee
raw
history blame
1.47 kB
import gradio as gr
import os
import torch
from torchvision import transforms
from PIL import Image
from torchvision.datasets import CIFAR100
from transformers import CLIPProcessor, CLIPModel
# Pretrained CLIP ViT-B/32 model and its matching processor (tokenizer + image
# pre-processing), loaded from the Hugging Face Hub / local cache at import time.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# CIFAR-100 test split, stored under ~/.cache (downloaded on first run).
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Side length in pixels used by the manual preprocessing helpers. The previous
# form `32 if torch.cuda.is_available() else 32` chose 32 on both branches, so
# the CUDA check had no effect and has been dropped.
IMG_SIZE = 32

# Resize so the shorter edge is IMG_SIZE, then convert a PIL image into a
# float tensor laid out as (C, H, W) with values in [0, 1].
COMPOSED_TRANSFORMERS = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
])

# Per-channel normalization using the ImageNet mean/std.
# NOTE(review): CLIPProcessor normalizes with CLIP's own statistics
# (processor.image_processor.image_mean / image_std), not these values. If
# tensors produced with this transform are meant to be fed to CLIP, these
# numbers probably need to match CLIP's instead — confirm before using.
NORMALIZE_TENSOR = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
def np_array_to_tensor_image(img, width=IMG_SIZE, height=IMG_SIZE, device='cpu'):
    """Convert an image array into a batched float tensor of a given size.

    Args:
        img: array accepted by ``PIL.Image.fromarray`` (e.g. the array a
            Gradio ``"image"`` input provides). It is converted to RGB.
        width: output width in pixels.
        height: output height in pixels.
        device: torch device the returned tensor is moved to.

    Returns:
        A float32 tensor of shape ``(1, 3, height, width)`` with values in
        ``[0, 1]``, located on *device*.
    """
    image = Image.fromarray(img).convert('RGB').resize((width, height))
    # Convert with ToTensor alone. COMPOSED_TRANSFORMERS also applies
    # Resize(IMG_SIZE), which rescales the shorter edge back to IMG_SIZE and
    # would silently override a caller-supplied width/height. For the default
    # IMG_SIZE x IMG_SIZE case that extra Resize was a no-op anyway.
    tensor = transforms.ToTensor()(image).unsqueeze(0)  # add batch dimension
    return tensor.to(device, torch.float)
def normalize_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Apply the module-level channel-wise normalization to *tensor*.

    Delegates to ``NORMALIZE_TENSOR`` (ImageNet mean/std), which computes
    ``(x - mean) / std`` per channel and returns a new tensor of the same
    shape. The input is presumably a float image tensor laid out as
    ``(..., 3, H, W)``, such as the output of ``np_array_to_tensor_image``.
    TODO: confirm this against the callers.

    Args:
        tensor: float image tensor with 3 channels in the channel dimension.

    Returns:
        The normalized tensor.
    """
    return NORMALIZE_TENSOR(tensor)
def send_inputs(img, *, labels=None, prompt_template="a photo of a {}"):
    """Classify *img* zero-shot with CLIP against a set of text labels.

    CLIP scores images against text prompts, so ``CLIPModel`` needs both
    ``pixel_values`` and ``input_ids``. Calling it with images alone raises
    ``ValueError`` from the text tower. Each label is therefore turned into a
    prompt, and the image-to-text logits are softmaxed over the labels.

    Args:
        img: the input image, in any form ``CLIPProcessor`` accepts (e.g. the
            array Gradio's ``"image"`` component passes in).
        labels: class names to score against. Defaults to the class names of
            the module-level ``cifar100`` dataset.
        prompt_template: format string with one ``{}`` slot for the label.

    Returns:
        A probability tensor of shape ``(num_images, len(labels))``; columns
        follow the order of *labels*. A single image gives a single row.

    Raises:
        ValueError: if *labels* is empty.
    """
    if labels is None:
        labels = cifar100.classes
    if not labels:
        raise ValueError("labels must contain at least one class name")

    prompts = [prompt_template.format(label) for label in labels]
    # padding=True pads the tokenized prompts to a common length so they batch.
    inputs = processor(text=prompts, images=img, return_tensors="pt", padding=True)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image is (num_images, num_prompts); normalize over prompts.
    return outputs.logits_per_image.softmax(dim=1)
if __name__ == "__main__":
    # Web UI: one image upload goes in, and send_inputs' return value is
    # rendered as text.
    demo = gr.Interface(fn=send_inputs, inputs=["image"], outputs="text")
    demo.launch()