# Spaces: Runtime error (status text captured from the hosting page; not part of the program)
from typing import List

import cv2
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image, ImageDraw
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
from transformers import SamModel, SamProcessor
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CLIPSeg model and its processor.
clip_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clip_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)

# Load the SAM model and its processor.
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Holds [pixel_values, image_embedding] for the most recent image passed to
# foward_pass, or None when nothing is cached (see reset_data).
cache_data = None

# Text prompts used to segment the damaged area and the car, in that order;
# clipseg_prediction relies on index 0 being damage and index 1 the vehicle.
prompts = ['damaged', 'car']

# Minimum per-pixel probability for a pixel to count as damage / vehicle.
damage_threshold = 0.3
vehicle_threshold = 0.5
def bbox_normalization(bbox, width, height, model_size=352):
    """Rescale a box from a square segmentation grid to image pixel coordinates.

    Args:
        bbox: ``[x_min, y_min, x_max, y_max]`` measured on a square grid whose
            side is ``model_size``.
        width: Width of the target image in pixels.
        height: Height of the target image in pixels.
        model_size: Side length of the grid ``bbox`` was measured on. Defaults
            to 352, the value this function previously hard-coded (presumably
            the CLIPSeg output resolution -- confirm against the model).

    Returns:
        ``[[x_min, y_min], [x_max, y_max]]`` with x scaled by
        ``width / model_size`` and y scaled by ``height / model_size``.

    Raises:
        ValueError: If ``model_size`` is not positive.
    """
    if model_size <= 0:
        raise ValueError("model_size must be positive")
    height_coeff = height / model_size
    width_coeff = width / model_size
    normalized_bbox = [
        [bbox[0] * width_coeff, bbox[1] * height_coeff],
        [bbox[2] * width_coeff, bbox[3] * height_coeff],
    ]
    print(f'Normalized-bbox:: {normalized_bbox}')
    return normalized_bbox
def bbox_area(bbox):
    """Return the area of an axis-aligned box ``[x_min, y_min, x_max, y_max]``.

    The result is ``(x_max - x_min) * (y_max - y_min)``, so a degenerate box
    (for example the all-zero "nothing found" box) has area 0.
    """
    box_width = bbox[2] - bbox[0]
    box_height = bbox[3] - bbox[1]
    return box_width * box_height
def segment_to_bbox(segment_indexs):
    """Return the tight bounding box around every cell equal to 1.

    Args:
        segment_indexs: 2-D grid (nested lists, NumPy array or torch tensor)
            indexed as ``[y][x]``; cells equal to 1 belong to the segment.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as integers, or
        ``[0.0, 0.0, 0.0, 0.0]`` when no cell equals 1.
    """
    if hasattr(segment_indexs, "detach"):
        # Torch tensors must be moved to host memory before NumPy can read them.
        segment_indexs = segment_indexs.detach().cpu().numpy()
    # atleast_2d keeps an empty input unpackable into (rows, cols) below.
    mask = np.atleast_2d(np.asarray(segment_indexs))
    # One vectorised pass instead of a Python loop over every pixel.
    y_points, x_points = np.nonzero(mask == 1)
    if x_points.size == 0:
        return [0.0, 0.0, 0.0, 0.0]
    return [
        int(x_points.min()),
        int(y_points.min()),
        int(x_points.max()),
        int(y_points.max()),
    ]
def clipseg_prediction(image):
    """Run CLIPSeg on ``image`` and locate the damaged region of a vehicle.

    The image is segmented once per entry in ``prompts`` (damage first, then
    vehicle). Each probability map is thresholded and reduced to a bounding
    box; the image is accepted only when a damage region exists and the
    vehicle box is larger than the damage box.

    Args:
        image: Image as a NumPy array laid out (height, width[, channels]).

    Returns:
        ``(True, [box])`` where ``box`` is the damage box rescaled to image
        pixels as ``[[x_min, y_min], [x_max, y_max]]``, or ``(False, [[]])``
        when no vehicle/damage is detected.
    """
    print('Clip-Segmentation-started------->')
    # NumPy image arrays are (height, width, ...): rows come first.
    img_h, img_w = image.shape[:2]
    inputs = clip_processor(
        text=prompts,
        images=[image] * len(prompts),
        padding="max_length",
        return_tensors="pt",
    ).to(device)  # the model lives on `device`, so its inputs must too

    # predict
    with torch.no_grad():
        outputs = clip_model(**inputs)

    # One probability map per prompt, brought back to the CPU for NumPy use.
    probs = torch.sigmoid(outputs.logits).cpu()

    # A pixel belongs to a class when its probability beats that class's
    # threshold; integer masks keep the ``== 1`` test in segment_to_bbox valid.
    damage_inds = (probs[0] > damage_threshold).long()
    vehicle_inds = (probs[1] > vehicle_threshold).long()

    # bbox creation
    damage_bbox = segment_to_bbox(damage_inds)
    vehicle_bbox = segment_to_bbox(vehicle_inds)

    # Vehicle checking: require an actual damage region, otherwise an empty
    # (all-zero) damage box would be forwarded as if damage had been found.
    damage_area = bbox_area(damage_bbox)
    if damage_area > 0 and bbox_area(vehicle_bbox) > damage_area:
        return True, [bbox_normalization(damage_bbox, img_w, img_h)]
    return False, [[]]
def foward_pass(image_input: np.ndarray, points: List[List[List[float]]]) -> np.ndarray:
    """Segment ``image_input`` with SAM, prompted by the given box.

    The image embedding is cached in the module-level ``cache_data`` so that
    repeated calls with the same pixels skip the image encoder.

    Args:
        image_input: Image as a NumPy array accepted by ``Image.fromarray``.
        points: Box prompt passed to the processor as ``input_boxes``.

    Returns:
        The post-processed masks for the first image and first box as a NumPy
        array. The caller indexes ``[0]`` to obtain a 2-D mask, so this is
        presumably (num_masks, height, width) -- confirm against the SAM
        processor documentation.
    """
    print('SAM-Segmentation-started------->')
    global cache_data
    image_input = Image.fromarray(image_input)
    inputs = processor(image_input, input_boxes=points, return_tensors="pt").to(device)

    # The pixels are only needed to compute / validate the cached embedding;
    # the mask decoder is fed the embedding instead.
    pixels = inputs["pixel_values"]
    del inputs["pixel_values"]

    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        if cache_data is None or not torch.equal(pixels, cache_data[0]):
            embedding = model.get_image_embeddings(pixels)
            cache_data = [pixels, embedding]
        # Call the module itself (not .forward) so module hooks still run.
        outputs = model(image_embeddings=cache_data[1], **inputs)

    # Keep the size tensors on the CPU alongside the CPU masks.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    masks = masks[0][0].squeeze(0).numpy()
    return masks
def main_func(inputs):
    """Detect vehicle damage in an image and paint it red.

    Args:
        inputs: Image as an (height, width, 3) uint8 NumPy array, or ``None``
            when no image has been supplied.

    Returns:
        ``(message, image)``: a status string and the PIL image to display.
        The image is the input with the damage mask overlaid in red, the
        untouched input when nothing is detected, or ``None`` when no input
        image was given.
    """
    if inputs is None:
        # Button pressed without an image: report it instead of crashing.
        return 'Prediction:: No image provided', None

    image_input = inputs
    classification, points = clipseg_prediction(image_input)
    if not classification:
        print('Prediction:: No vehicle found in the image')
        return 'Prediction:: No vehicle or damage found in the image', Image.fromarray(image_input)

    masks = foward_pass(image_input, points)
    final_mask = masks[0].astype(bool)
    mask_colors = np.zeros((final_mask.shape[0], final_mask.shape[1], 3), dtype=np.uint8)
    # 255 is the largest uint8 value; 256 would wrap to 0 and paint nothing.
    mask_colors[final_mask, :] = np.array([255, 0, 0], dtype=np.uint8)
    # Add in a wider integer type and clip, so bright pixels saturate at 255
    # instead of wrapping around as uint8 + uint8 would.
    blended = np.clip(image_input.astype(np.int16) + mask_colors, 0, 255).astype('uint8')
    return 'Prediction: Vehicle damage prediction is given.', Image.fromarray(blended, 'RGB')
def reset_data():
    """Clear the cached SAM pixel values and image embedding.

    Sets the module-level ``cache_data`` back to ``None`` so the next call to
    ``foward_pass`` recomputes the image embedding.
    """
    global cache_data
    cache_data = None
# Build the web UI: input and result images side by side, example images and
# the textual verdict below them, and a button that runs the pipeline.
with gr.Blocks() as demo:
    gr.Markdown("# Vehicle damage detection")
    gr.Markdown("""This app uses the SAM model and clipseg model to get a vehicle damage area from image.""")
    with gr.Row():
        image_input = gr.Image(label='Input Image')
        image_output = gr.Image(label='Damage Detection')
    with gr.Row():
        examples = gr.Examples(examples="./examples", inputs=image_input)
        # Components are exposed directly on the gradio package (gr.Textbox).
        prediction_op = gr.Textbox(label='Prediction')
    image_button = gr.Button("Segment Image", variant='primary')
    image_button.click(main_func, inputs=image_input, outputs=[prediction_op, image_output])
    # A newly uploaded image invalidates the cached SAM embedding.
    image_input.upload(reset_data)

demo.launch()