pmf_with_gis / app.py
hushell's picture
fix _curses.error: setupterm: could not find terminal
e337a53
raw history blame
No virus
5.48 kB
import os
import numpy as np
import time
import random
import torch
import torchvision.transforms as transforms
#import requests
import gradio as gr
import matplotlib.pyplot as plt
from models import get_model
from dotmap import DotMap
from PIL import Image
os.environ['TERM'] = 'linux'
os.environ['TERMINFO'] = '/etc/terminfo'
# args
args = DotMap()
args.deploy = 'vanilla'
args.arch = 'dino_small_patch16'
#args.resume = '/fast_scratch/hushell/fluidstack/FS125_few-shot-transformer/outputs/dinosmall_1e-4/best_converted.pth'
args.resume = 'https://huggingface.co/hushell/pmf_dinosmall_lr1e-4/resolve/main/best_converted.pth'
args.api_key = 'AIzaSyAFkOGnXhy-2ZB0imDvNNqf2rHb98vR_qY'
args.cx = '06d75168141bc47f1'
# model
device = 'cpu' #torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model(args)
model.to(device)
#checkpoint = torch.load(args.resume, map_location='cpu')
checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=True)
# image transforms
def test_transform():
def _convert_image_to_rgb(im):
return im.convert('RGB')
return transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
_convert_image_to_rgb,
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
preprocess = test_transform()
@torch.no_grad()
def denormalize(x, mean, std):
# 3, H, W
t = x.clone()
t.mul_(std).add_(mean)
return torch.clamp(t, 0, 1)
# Google image search
from google_images_search import GoogleImagesSearch
# define search params
# option for commonly used search param are shown below for easy reference.
# For param marked with '##':
# - Multiselect is currently not feasible. Choose ONE option only
# - This param can also be omitted from _search_params if you do not wish to define any value
_search_params = {
'q': '...',
'num': 10,
'fileType': 'png', #'jpg|gif|png',
'rights': 'cc_publicdomain', #'cc_publicdomain|cc_attribute|cc_sharealike|cc_noncommercial|cc_nonderived',
#'safe': 'active|high|medium|off|safeUndefined', ##
'imgType': 'photo', #'clipart|face|lineart|stock|photo|animated|imgTypeUndefined', ##
#'imgSize': 'huge|icon|large|medium|small|xlarge|xxlarge|imgSizeUndefined', ##
#'imgDominantColor': 'black|blue|brown|gray|green|orange|pink|purple|red|teal|white|yellow|imgDominantColorUndefined', ##
'imgColorType': 'color', #'color|gray|mono|trans|imgColorTypeUndefined' ##
}
# Gradio UI
def inference(query, labels, n_supp=10):
'''
query: PIL image
labels: list of class names
'''
labels = labels.split(',')
n_supp = int(n_supp)
#print(f'#rows={len(labels)}, #cols={n_supp}')
fig, axs = plt.subplots(len(labels), n_supp, figsize=(n_supp*4, len(labels)*4))
with torch.no_grad():
# query image
query = preprocess(query).unsqueeze(0).unsqueeze(0).to(device) # (1, 1, 3, H, W)
supp_x = []
supp_y = []
# search support images
for idx, y in enumerate(labels):
with GoogleImagesSearch(args.api_key, args.cx) as gis:
_search_params['q'] = y
_search_params['num'] = n_supp
gis.search(search_params=_search_params, custom_image_name='my_image')
gis._custom_image_name = 'my_image'
for j, x in enumerate(gis.results()):
#url = x.url
#x_im = Image.open(requests.get(url, stream=True).raw)
x.download('./')
x_im = Image.open(x.path)
# vis
axs[idx, j].imshow(x_im)
axs[idx, j].set_title(f'{y}{j}')
axs[idx, j].axis('off')
x_im = preprocess(x_im) # (3, H, W)
supp_x.append(x_im)
supp_y.append(idx)
print('Searching for support images is done.')
supp_x = torch.stack(supp_x, dim=0).unsqueeze(0).to(device) # (1, n_supp*n_labels, 3, H, W)
supp_y = torch.tensor(supp_y).long().unsqueeze(0).to(device) # (1, n_supp*n_labels)
with torch.cuda.amp.autocast(True):
output = model(supp_x, supp_y, query) # (1, 1, n_labels)
probs = output.softmax(dim=-1).detach().cpu().numpy()
return {k: float(v) for k, v in zip(labels, probs[0, 0])}, fig
# DEBUG
#query = Image.open('../labrador-puppy.jpg')
##labels = 'dog, cat'
#labels = 'girl, boy'
#output = inference(query, labels, n_supp=2)
#print(output)
gr.Interface(fn=inference,
inputs=[
gr.inputs.Image(label="Image to classify", type="pil"),
gr.inputs.Textbox(lines=1, label="Class hypotheses:", placeholder="Enter class names separated by ','",),
#gr.inputs.Number(default=1, label="Number of support examples from Google")
gr.inputs.Slider(minimum=2, maximum=10, step=1, label="Number of support examples from Google")
],
theme="grass",
outputs=[
gr.outputs.Label(label="Predicted class probabilities"),
gr.outputs.Image(type='plot', label="Support examples from Google image search"),
],
description="PMF few-shot learning with Google image search").launch()