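"""Gradio demo for TreeFormer: tree counting from a single RGB image.

The script loads a pretrained TreeFormer model, runs sliding-window
inference over the input image, stitches the per-crop density maps back
together, and serves the predicted tree count plus a density-map overlay
through a Gradio interface.
"""
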
import sys
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from network import pvt_cls as TCN
import gradio as gr
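
# `network.pvt_cls` is the local TreeFormer model definition shipped with
# this repository; TCN.pvt_treeformer builds the PVT-based counting network.
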
def demo(img_path):
    # config
    batch_size = 8
    crop_size = 256
    model_path = '/users/k21163430/workspace/TreeFormer/models/best_model.pth'
    # fall back to CPU when no GPU is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # prepare model
    model = TCN.pvt_treeformer(pretrained=False)
    model.to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    # preprocess: if the short side is smaller than the crop size, upscale
    # the image so that at least one full crop fits
    img = Image.open(img_path).convert('RGB')
    wd, ht = img.size
    st_size = 1.0 * min(wd, ht)
    if st_size < crop_size:
        rr = 1.0 * crop_size / st_size
        wd = round(wd * rr)
        ht = round(ht * rr)
        st_size = 1.0 * min(wd, ht)
        img = img.resize((wd, ht), Image.BICUBIC)
    # keep a copy of the (possibly resized) image for the overlay, so its
    # shape matches the predicted density map
    show_img = np.array(img)

    transform = transforms.Compose([
        transforms.ToTensor(),
        # ImageNet channel means and standard deviations
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    img = transform(img)
    img = img.unsqueeze(0)
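
    # img is now a normalized float tensor of shape [1, 3, ht, wd]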
    # model forward: tile the image into crop_size x crop_size windows.
    # Edge windows are shifted back inside the image, so neighbouring
    # crops can overlap; the masks record which pixels each crop covers.
    with torch.no_grad():
        inputs = img.to(device)
        crop_imgs, crop_masks = [], []
        b, c, h, w = inputs.size()
        rh, rw = crop_size, crop_size
        for i in range(0, h, rh):
            gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
            for j in range(0, w, rw):
                gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
                crop_imgs.append(inputs[:, :, gis:gie, gjs:gje])
                mask = torch.zeros([b, 1, h, w]).to(device)
                mask[:, :, gis:gie, gjs:gje].fill_(1.0)
                crop_masks.append(mask)
        crop_imgs, crop_masks = map(
            lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks))
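
        # Tiling arithmetic, with illustrative numbers: for h = 600 and
        # rh = 256 the row windows are (0, 256), (256, 512) and (344, 600);
        # the last start index is clamped to h - rh so every crop is a full
        # 256 px and the whole image stays covered.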
        # run the crops through the model in batches of batch_size
        crop_preds = []
        nz, bz = crop_imgs.size(0), batch_size
        for i in range(0, nz, bz):
            gs, gt = i, min(nz, i + bz)
            crop_pred, _ = model(crop_imgs[gs:gt])
            crop_pred = crop_pred[0]
            # the network predicts at 1/4 resolution; upsample by 4 and
            # divide by 16 so the density sum (the count) is preserved
            _, _, h1, w1 = crop_pred.size()
            crop_pred = F.interpolate(
                crop_pred, size=(h1 * 4, w1 * 4),
                mode='bilinear', align_corners=True) / 16
            crop_preds.append(crop_pred)
        crop_preds = torch.cat(crop_preds, dim=0)
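
        # crop_preds now holds one density map per crop, each back at the
        # full crop resolution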
        # splice the crop predictions back to the original size
        idx = 0
        pred_map = torch.zeros([b, 1, h, w]).to(device)
        for i in range(0, h, rh):
            gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
            for j in range(0, w, rw):
                gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
                pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx]
                idx += 1

        # in overlapping areas several crops contribute, so divide by the
        # number of crops covering each pixel to get the average
        mask = crop_masks.sum(dim=0).unsqueeze(0)
        outputs = pred_map / mask

    # the predicted count is the integral of the density map
    model_output = round(torch.sum(outputs).item())
    print("{}: {}".format(img_path, model_output))

    # min-max normalize the density map and blend it over the input image
    outputs = outputs.squeeze().cpu().numpy()
    outputs = (outputs - np.min(outputs)) / (np.max(outputs) - np.min(outputs))
    show_img = show_img / 255.0
    show_img = show_img * 0.2 + outputs[:, :, None] * 0.8
    return model_output, show_img
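
# The two values returned by demo() map onto the two Interface outputs below:
# the integer count feeds gr.Number, and the H x W x 3 float array in [0, 1]
# feeds gr.Image.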
if __name__ == "__main__":
    # CLI test (uncomment to run without the web UI):
    # img_path = sys.argv[1]
    # demo(img_path)

    # launch a gr.Interface
    gr_demo = gr.Interface(
        fn=demo,
        inputs=gr.Image(source="upload",
                        type="filepath",
                        label="Input Image",
                        width=768,
                        height=768),
        outputs=[
            gr.Number(label="Predicted Tree Count"),
            gr.Image(label="Density Map",
                     width=768,
                     height=768),
        ],
        title="TreeFormer",
        description="TreeFormer is a semi-supervised transformer-based framework for tree counting from a single high-resolution image. Upload an image and TreeFormer will predict the number of trees in the image and generate a density map of the trees.",
        article="This work has been developed as part of the ReSET project, which has received funding from the European Union's Horizon 2020 FET Proactive Programme under grant agreement No 101017857. The contents of this publication are the sole responsibility of the ReSET consortium and do not necessarily reflect the opinion of the European Union.",
        examples=[
            ["./examples/IMG_101.jpg"],
            ["./examples/IMG_125.jpg"],
            ["./examples/IMG_138.jpg"],
            ["./examples/IMG_180.jpg"],
            ["./examples/IMG_18.jpg"],
            ["./examples/IMG_206.jpg"],
            ["./examples/IMG_223.jpg"],
            ["./examples/IMG_247.jpg"],
            ["./examples/IMG_270.jpg"],
            ["./examples/IMG_306.jpg"],
        ],
        # cache_examples=True,
        examples_per_page=10,
        allow_flagging="never",
        theme=gr.themes.Default(),
    )
    gr_demo.launch(share=True, server_port=7861,
                   favicon_path="./assets/reset.png")
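
    # share=True serves the app locally on port 7861 and also prints a
    # temporary public *.gradio.live URL; start the demo with `python app.py`.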