---
library_name: birefnet
tags:
- background-removal
- mask-generation
- Dichotomous Image Segmentation
- Camouflaged Object Detection
- Salient Object Detection
- pytorch_model_hub_mixin
- model_hub_mixin
repo_url: https://github.com/ZhengPeng7/BiRefNet
pipeline_tag: image-segmentation
---
Bilateral Reference for High-Resolution Dichotomous Image Segmentation
Figures: DIS-Sample_1, DIS-Sample_2
This repo is the official implementation of "Bilateral Reference for High-Resolution Dichotomous Image Segmentation" (arXiv 2024).
Visit our GitHub repo at https://github.com/ZhengPeng7/BiRefNet for more details, including code, docs, and the model zoo!
How to use
1. Load BiRefNet:
Use code + weights from Hugging Face
Use both the code and the weights on Hugging Face. Pro: no need to download the BiRefNet code manually. Con: the code on Hugging Face might not be the latest version (I'll try to keep it up to date).
# Load BiRefNet with weights
from transformers import AutoModelForImageSegmentation
birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/birefnet', trust_remote_code=True)
Use code from GitHub + weights from Hugging Face
Only use the weights on Hugging Face. Pro: the code is always the latest. Con: you need to clone the BiRefNet repo from my GitHub.
# Download the code
git clone https://github.com/ZhengPeng7/BiRefNet.git
cd BiRefNet
# Load weights
from models.birefnet import BiRefNet
# Option-1: From Hugging Face Models
birefnet = BiRefNet.from_pretrained('zhengpeng7/birefnet')
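# Option-2: From local disk
import torch
from utils import check_state_dict
birefnet = BiRefNet(bb_pretrained=False)  # no separate backbone weights needed; the full checkpoint is loaded below
state_dict = torch.load(PATH_TO_WEIGHT, map_location='cpu')  # PATH_TO_WEIGHT: path to your downloaded BiRefNet checkpoint
state_dict = check_state_dict(state_dict)  # clean up the checkpoint keys before loading
birefnet.load_state_dict(state_dict)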
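`check_state_dict` comes from the BiRefNet repo's `utils` module. A common reason such a cleanup step is needed is that checkpoints saved from a `DataParallel`- or `torch.compile`-wrapped model carry key prefixes such as `module.` or `_orig_mod.`, which a plain model's `load_state_dict` rejects. The sketch below is a hypothetical stand-in (`strip_prefixes` is not part of BiRefNet) that assumes this is the kind of cleanup involved; it works on a plain dict so it runs without PyTorch.

```python
def strip_prefixes(state_dict, prefixes=("module.", "_orig_mod.")):
    """Hypothetical helper: return a copy of state_dict with wrapper prefixes removed from its keys.

    Assumption: the checkpoint may have been saved from a DataParallel ("module.")
    or torch.compile ("_orig_mod.") wrapped model. This is NOT BiRefNet's own code.
    """
    cleaned = {}
    for key, value in state_dict.items():
        new_key = key
        # Keep stripping while the key starts with any known prefix,
        # so stacked prefixes like "module._orig_mod." are fully removed.
        stripped = True
        while stripped:
            stripped = False
            for prefix in prefixes:
                if new_key.startswith(prefix):
                    new_key = new_key[len(prefix):]
                    stripped = True
        cleaned[new_key] = value
    return cleaned


# Usage with dummy integers standing in for tensors
raw = {"module._orig_mod.bb.conv.weight": 1, "decoder.out.bias": 2}
print(strip_prefixes(raw))  # {'bb.conv.weight': 1, 'decoder.out.bias': 2}
```

Returning a new dict (instead of renaming keys in place) keeps the original checkpoint untouched, which makes it easier to compare key names when `load_state_dict` reports missing or unexpected keys.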
Use the loaded BiRefNet for inference
# Imports
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from models.birefnet import BiRefNet  # only needed if you load BiRefNet with the GitHub code

birefnet = ...  # -- BiRefNet should be loaded with the code above, either way.
# 'high' lets float32 matmuls use faster reduced-precision paths (e.g., TF32) where supported; 'highest' keeps full float32 precision.
torch.set_float32_matmul_precision(['high', 'highest'][0])
device = 'cuda' if torch.cuda.is_available() else 'cpu'
birefnet.to(device)
birefnet.eval()

def extract_object(birefnet, imagepath):
    # Data settings
    image_size = (1024, 1024)
    transform_image = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Convert to RGB so grayscale/RGBA inputs match the 3-channel normalization
    image = Image.open(imagepath).convert('RGB')
    input_images = transform_image(image).unsqueeze(0).to(device)

    # Prediction: the last output holds the final mask logits
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Resize the predicted mask back to the original image size and use it as the alpha channel
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image.size)
    image.putalpha(mask)
    return image, mask

# Visualization
plt.axis("off")
plt.imshow(extract_object(birefnet, imagepath='PATH-TO-YOUR_IMAGE.jpg')[0])
plt.show()
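In `extract_object`, `transforms.ToTensor()` scales 8-bit pixels to [0, 1] and `transforms.Normalize` then applies `(x - mean) / std` per channel with the ImageNet statistics listed in the code. On the way out, `sigmoid()` turns the model's last output into probabilities, and `ToPILImage` turns a float map into an 8-bit mask by multiplying by 255 and casting to bytes, which truncates. The pure-Python sketch below reproduces that per-pixel arithmetic for single values so the numbers can be checked by hand; `normalize_pixel` and `logit_to_alpha` are hypothetical names, and this is an illustration, not a replacement for the torchvision pipeline.

```python
import math

# ImageNet statistics used by transforms.Normalize in the inference code above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Map one 8-bit RGB pixel the way ToTensor() followed by Normalize(MEAN, STD) would."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, MEAN, STD))


def logit_to_alpha(logit):
    """Apply a sigmoid to a raw model output, then scale to 0..255 with truncation."""
    # Numerically stable sigmoid: avoid math.exp overflow for large |logit|
    if logit >= 0:
        prob = 1.0 / (1.0 + math.exp(-logit))
    else:
        z = math.exp(logit)
        prob = z / (1.0 + z)
    # Multiply by 255 and truncate, mirroring ToPILImage on a float tensor
    return int(prob * 255)


print(normalize_pixel((255, 255, 255)))  # a white pixel is positive in every channel
print(logit_to_alpha(0.0))     # sigmoid(0) = 0.5 -> 127
print(logit_to_alpha(10.0))    # confidently foreground -> 254
print(logit_to_alpha(-10.0))   # confidently background -> 0
```

Because of the truncation, even a very confident foreground prediction usually maps to 254 rather than 255; this rarely matters visually, but it is worth knowing if you threshold the saved mask at exactly 255.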
This BiRefNet model for standard dichotomous image segmentation (DIS) is trained on DIS-TR and validated on DIS-TEs and DIS-VD.
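Several metrics are commonly reported on the DIS validation and test sets; one of the simplest is mean absolute error (MAE) between the predicted mask and the ground truth, both scaled to [0, 1] and compared at the ground-truth resolution. The sketch below is a minimal, hypothetical illustration of MAE on tiny hand-made masks given as nested lists; it is not the evaluation code from the BiRefNet repo.

```python
def mae(pred, gt):
    """Mean absolute error between two equally sized masks with values in [0, 1].

    pred, gt: nested lists (rows of floats), e.g. a probability map and a binary ground truth.
    """
    if len(pred) != len(gt) or any(len(p) != len(g) for p, g in zip(pred, gt)):
        raise ValueError("pred and gt must have the same shape")
    total = 0.0
    count = 0
    for pred_row, gt_row in zip(pred, gt):
        for p, g in zip(pred_row, gt_row):
            total += abs(p - g)
            count += 1
    if count == 0:
        raise ValueError("masks must not be empty")
    return total / count


# 2x2 toy example: one pixel is off by 0.5, another by 0.25
pred = [[1.0, 0.5],
        [0.0, 0.25]]
gt = [[1.0, 1.0],
      [0.0, 0.0]]
print(mae(pred, gt))  # (0 + 0.5 + 0 + 0.25) / 4 = 0.1875
```

Lower is better; a perfect prediction gives 0.0. In practice you would compute this on the full-resolution arrays (e.g., the `mask` returned by `extract_object`, divided by 255) and average over all images in a split.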
This repo holds the official model weights of "Bilateral Reference for High-Resolution Dichotomous Image Segmentation" (arXiv 2024).
This repo contains the weights of BiRefNet proposed in our paper, which achieved state-of-the-art (SOTA) performance on three tasks (DIS, HRSOD, and COD).
Go to my GitHub page for the BiRefNet code and the latest updates: https://github.com/ZhengPeng7/BiRefNet :)
Try our online demos for inference:
- Online Single Image Inference on Colab:
- Online Inference with GUI on Hugging Face with adjustable resolutions:
- Inference and evaluation of your given weights:
Acknowledgement:
- Many thanks to @fal for their generous support with GPU resources for training better BiRefNet models.
- Many thanks to @not-lain for his help with better deployment of our BiRefNet model on Hugging Face.
Citation
@article{zheng2024birefnet,
title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation},
author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu},
journal={arXiv},
year={2024}
}