import gradio as gr
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import pipeline
captioner = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-base",
    prompt="The main subject of this picture is a",
)
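
# Illustrative only: with the prompt above, BLIP continues the sentence, so a
# photo of a dog might come back as something like "the main subject of this
# picture is a dog sitting on a couch"; the exact wording depends on the model
# version and is not guaranteed.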

def removeBackground(image):
    image = Image.fromarray(image).convert("RGB")
    caption = (
        captioner(image, max_new_tokens=20)[0]["generated_text"]
        .lower()
        .replace("the main subject of this picture is a", "")
        .strip()
    )

    # One-time setup: fetch the IS-Net (DIS) code and the pretrained weights.
    if not os.path.exists("saved_models"):
        os.mkdir("saved_models")
        os.mkdir("git")
        os.system("git clone https://github.com/xuebinqin/DIS git/xuebinqin/DIS")
        hf_hub_download(
            repo_id="NimaBoscarino/IS-Net_DIS-general-use",
            filename="isnet-general-use.pth",
            local_dir="saved_models",
        )
        os.system("rm -rf git/xuebinqin/DIS/IS-Net/__pycache__")
        os.system("mv git/xuebinqin/DIS/IS-Net/* .")

    def build_model(hypar, device):
        net = hypar["model"]  # GOSNETINC(3,1)
        # convert to half precision, keeping BatchNorm layers in float32
        # for numerical stability
        if hypar["model_digit"] == "half":
            net.half()
            for layer in net.modules():
                if isinstance(layer, nn.BatchNorm2d):
                    layer.float()
        net.to(device)
        if hypar["restore_model"] != "":
            net.load_state_dict(
                torch.load(
                    hypar["model_path"] + "/" + hypar["restore_model"],
                    map_location=device,
                )
            )
            net.to(device)
        net.eval()
        return net

    import data_loader_cache
    import models

    device = "cpu"
    ISNetDIS = models.ISNetDIS
    normalize = data_loader_cache.normalize
    im_preprocess = data_loader_cache.im_preprocess

    # Set parameters
    hypar = {}  # parameters for inference
    # load trained weights from this path
    hypar["model_path"] = "./saved_models"
    # name of the weights file to load
    hypar["restore_model"] = "isnet-general-use.pth"
    # whether to activate intermediate feature supervision
    hypar["interm_sup"] = False
    # floating-point precision: "half" or "full"
    hypar["model_digit"] = "full"
    hypar["seed"] = 0
    # cached input spatial resolution; can be configured to a different size
    hypar["cache_size"] = [1024, 1024]
    # data augmentation parameters ---
    # model input spatial size, usually the same as hypar["cache_size"],
    # which means we don't further resize the images
    hypar["input_size"] = [1024, 1024]
    # random crop size from the input; usually set smaller than
    # hypar["cache_size"], e.g. [920, 920], for data augmentation
    hypar["crop_size"] = [1024, 1024]
    hypar["model"] = ISNetDIS()

    # Build Model
    net = build_model(hypar, device)
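
    # Note: the setup above runs on every request, so the model is rebuilt and
    # the weights are reloaded per call; hoisting it to module scope would
    # avoid that cost.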

    def predict(net, inputs_val, shapes_val, hypar, device):
        """Given an image tensor, predict the foreground mask."""
        net.eval()
        if hypar["model_digit"] == "full":
            inputs_val = inputs_val.type(torch.FloatTensor)
        else:
            inputs_val = inputs_val.type(torch.HalfTensor)
        inputs_val_v = inputs_val.to(device)
        ds_val = net(inputs_val_v)[0]  # list of 6 side outputs
        # B x 1 x H x W; the first one is the most accurate prediction
        pred_val = ds_val[0][0, :, :, :]
        # recover the prediction to the original image size
        pred_val = torch.squeeze(
            F.interpolate(
                torch.unsqueeze(pred_val, 0),
                (shapes_val[0][0], shapes_val[0][1]),
                mode="bilinear",
            )
        )
        # min-max normalize the prediction to [0, 1]
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        pred_val = (pred_val - mi) / (ma - mi)
        if device == "cuda":
            torch.cuda.empty_cache()
        # scale to 8-bit grayscale; this is the mask we need
        return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)

    def load_image(im_pil, hypar):
        im = np.array(im_pil)
        im, im_shp = im_preprocess(im, hypar["cache_size"])
        im = torch.divide(im, 255.0)
        shape = torch.from_numpy(np.array(im_shp))
        # normalize with mean 0.5 and std 1.0, then batch the (image, shape) pair
        aa = normalize(im, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return aa.unsqueeze(0), shape.unsqueeze(0)
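
    # im_preprocess resizes the image to hypar["cache_size"] and returns the
    # original (height, width), which predict() uses to scale the mask back.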

    image_tensor, orig_size = load_image(image, hypar)
    mask = predict(net, image_tensor, orig_size, hypar, "cpu")
    mask = Image.fromarray(mask).convert("L")

    # composite: the original RGB image with the predicted mask as alpha channel
    im_rgb = image.convert("RGB")
    cropped = im_rgb.copy()
    cropped.putalpha(mask)
    return cropped, caption

inputs = gr.Image()
outputs = [gr.Image(type="pil"), "text"]
interface = gr.Interface(
    fn=removeBackground,
    inputs=inputs,
    outputs=outputs,
    title="Remove Background",
    description="This app removes the background from an image.",
    examples=[
        "examples/input/1.jpeg",
        "examples/input/2.jpeg",
        "examples/input/3.jpeg",
    ],
    cache_examples=True,
)
interface.queue().launch()
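
# Quick local smoke test (illustrative; assumes an image exists at
# examples/input/1.jpeg and should be run instead of interface.launch()
# above, since launch() blocks):
#
#     arr = np.array(Image.open("examples/input/1.jpeg").convert("RGB"))
#     cutout, caption = removeBackground(arr)
#     cutout.save("cutout.png")
#     print(caption)
#
# Dependencies (unpinned; an assumption, not specified by the original):
#     pip install gradio transformers torch huggingface_hub pillow numpy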