youl's picture
application
4655008
raw
history blame
3.29 kB
import gradio as gr
import albumentations as A
from functions import *
# Suppress all Python warnings for the demo process.
warnings.filterwarnings('ignore')

# Preprocessing applied to every image (or tile) before inference:
# resize to 1024x1024, normalize, then convert to a torch tensor.
# ``p=1.0`` makes each step unconditional; it replaces the deprecated
# ``always_apply=True`` flag with identical behaviour.
test_transforms = A.Compose([
    A.Resize(height=1024, width=1024, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2(p=1.0),
])

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the whole pickled model. Mapping to the CPU first lets a checkpoint
# saved on a GPU machine load on a CPU-only host; it is then moved to `device`.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. PyTorch >= 2.6 defaults to weights_only=True, which rejects full
# pickled models; pass weights_only=False if the runtime is upgraded.
model = torch.load('pickel.pth', map_location=torch.device('cpu'))
model = model.to(device)
# The app only runs inference, so switch to eval mode once at load time.
model.eval()
def _remove_tiling_artifacts():
    """Delete the files/folders used by the tiled-inference path, if present."""
    try:
        os.remove("locations.npy")
    except FileNotFoundError:
        pass
    shutil.rmtree("images", ignore_errors=True)
    shutil.rmtree("labels", ignore_errors=True)


def predict(img):
    """Detect palm trees in an image; return an annotated image and a count message.

    Images with fewer than 2*1024*1024 pixels are run through the model in a
    single pass. Larger images go through ``crop`` + ``inference`` (tiled
    path), whose on-disk artefacts (``locations.npy``, ``images/``,
    ``labels/``) are removed afterwards even if processing fails.

    NOTE(review): the tiled path uses fixed file names in the working
    directory, so concurrent requests could interfere with each other --
    confirm the Gradio queue runs one request at a time.

    Args:
        img: the value from the Gradio ``Image`` input; converted with
            ``np.array`` (assumed H x W x C -- TODO confirm channel count).

    Returns:
        ``(pred, word)``: the annotated image and the string
        ``"Number of palm trees detected : N"``.
    """
    image = np.array(img)
    h, w = image.shape[:2]
    # Idempotent; guarantees eval mode for both the single-pass and tiled paths.
    model.eval()

    if h * w < 2 * 1024 * 1024:
        # Single pass: transform, add a batch dimension, move to the device.
        image_transformed = test_transforms(image=image)["image"]
        image_transformed = image_transformed.unsqueeze(0).to(device)
        with torch.no_grad():
            predictions = model(image_transformed)[0]
        # Non-maximum suppression (IoU threshold 0.1) before drawing/counting.
        nms_prediction = apply_nms(predictions, iou_thresh=0.1)
        pred = plot_img_bbox(image, nms_prediction)
        word = "Number of palm trees detected : " + str(len(nms_prediction["boxes"]))
        return pred, word

    try:
        # crop() presumably writes the tiles and their offsets
        # (locations.npy) to disk -- they are read back just below.
        crop(image)
        locations = np.load("locations.npy")
        n = inference(image, locations, model, test_transforms, device)
        empty_image = np.zeros(image.shape)
        # Release the full-resolution input before building the output image.
        del image
        gc.collect()
        # NOTE(review): the reason for this 1 s pause is not evident here; kept as-is.
        sleep(1)
        word = "Number of palm trees detected : " + str(n)
        pred = create_new_ortho(locations, empty_image)
    finally:
        # Always clean up, so a failed request does not leave stale tiles behind.
        _remove_tiling_artifacts()
    return pred, word
# Gradio components: one image input; an annotated-image output and a text label.
image = gr.components.Image()
out_im = gr.components.Image()
out_lab = gr.components.Label()

### 4. Gradio app ###
# Title, description and footer text shown on the page.
title = "🌴Palm trees detection🌴"
description = "Faster r-cnn model to detect oil palm trees in drones images."
article = "Created by data354."

# Build the examples list from the "examples/" directory. Sorting gives a
# stable order regardless of filesystem; hidden files (e.g. .gitkeep,
# .DS_Store) are skipped because they are not example images.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
]

# Create the Gradio demo: predict(img) -> (annotated image, count message).
demo = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=[out_im, out_lab],
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the demo!
demo.launch(debug=False)