# Author: Silverwing123
# Commit: 2b95373 ("initial commit", verified)
import gradio as gr
from PIL import Image
# import torchvision
import os
from os.path import join
import utils
import warnings
# Suppress every warning process-wide. Note this also hides warnings that may
# be worth seeing while debugging.
warnings.filterwarnings("ignore")

# Detection model, loaded once at import time and reused by show_preds_image
# for every request.
model = utils.load_model("faster_rcnn_fold_3_150.pth")

# Sample inputs offered as clickable examples in the UI. Paths are relative to
# the current working directory, not to this file.
example_path = "example"
path = [join(example_path, file_name) for file_name in ("1.jpg", "2.jpg")]
def show_preds_image(image_path):
    """Run the detector on one uploaded image and return the rendered result.

    Args:
        image_path: Filesystem path of the uploaded image, as supplied by the
            ``Image(type="filepath")`` input component. May be ``None`` when
            the user submits without uploading anything.

    Returns:
        The value produced by ``utils.create_img_pred``. It is displayed by a
        ``type="numpy"`` output component, so presumably an image array —
        confirm against utils.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of an internal PIL error.
    """
    if image_path is None:
        raise gr.Error("Please upload an image first.")

    # Open inside a context manager so the file handle is released right away;
    # convert() returns an independent, fully loaded copy that outlives it.
    with Image.open(image_path) as source:
        im = source.convert("RGB")

    # Preprocess for the model; `scale` is forwarded so predict_image can
    # presumably map boxes back to the original image size — confirm in utils.
    img_transformed, scale = utils.transform_img(im)

    # Per the original author's note, the prediction looks roughly like
    # {boxes: [[int, int, int, int], ...], labels: [str, ...], scores: [...]}.
    # NOTE(review): exact key spelling (the note said "lables") and the score
    # type are not verifiable here — check utils.predict_image.
    pred_img = utils.predict_image(model, img_transformed, scale)

    # Build the output image from the original file path plus the predictions.
    output_img = utils.create_img_pred(image_path, pred_img)
    return output_img
# UI wiring: a single image input delivered to the handler as a file path, and
# a single image output rendered from a NumPy array.
inputs_image = [
    gr.components.Image(type="filepath", label="Input Image")
]
outputs_image = [
    gr.components.Image(type="numpy", label="Output Image")
]

# Construct the Interface first and launch it as a separate statement, so that
# `interface_image` refers to the Interface object itself rather than to
# whatever launch() returns.
interface_image = gr.Interface(
    fn=show_preds_image,
    inputs=inputs_image,
    outputs=outputs_image,
    title="weld Defect Detection",
    examples=path,
    # Examples are run on demand when clicked instead of being precomputed.
    cache_examples=False,
)
interface_image.launch()