import io
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from models.birefnet import BiRefNet


# Run inference on the GPU when CUDA is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Используем устройство:", device)

# Registry of selectable checkpoints. Each key is the name offered in the UI
# dropdown; "repo" is the identifier handed to BiRefNet.from_pretrained() and
# "image_size" is the size the input image is resized to before inference.
MODEL_CONFIG = {
    "BiRefNet_HR": {
        "repo": "ZhengPeng7/BiRefNet_HR",
        "image_size": (2048, 2048),
    },
    "BiRefNet": {
        "repo": "ZhengPeng7/BiRefNet",
        "image_size": (1024, 1024),
    },
    "BiRefNet-matting": {
        "repo": "ZhengPeng7/BiRefNet-matting",
        "image_size": (1024, 1024),
    },
    "BiRefNet-portrait": {
        "repo": "ZhengPeng7/BiRefNet-portrait",
        "image_size": (1024, 1024),
    },
    "BiRefNet-HRSOD": {
        "repo": "ZhengPeng7/BiRefNet-HRSOD",
        "image_size": (1024, 1024),
    },
}

# Cache of already-initialised models keyed by model name; filled lazily by
# load_model() so each checkpoint is loaded at most once per process.
loaded_models = {}


def load_model(model_name):
    """Return the model registered under *model_name*, loading it on first use.

    Loaded models are stored in the module-level ``loaded_models`` dict, so
    repeated calls with the same name return the same object.

    Args:
        model_name: Key of ``MODEL_CONFIG`` selecting which checkpoint to load.

    Returns:
        The model, moved to ``device`` and switched to eval mode. When
        ``device`` is ``'cuda'`` the model is additionally converted to half
        precision.

    Raises:
        KeyError: If *model_name* is neither cached nor present in
            ``MODEL_CONFIG``.
    """
    if model_name in loaded_models:
        return loaded_models[model_name]

    if model_name not in MODEL_CONFIG:
        # Same exception type a plain dict lookup would raise, but with a
        # message that tells the caller which names are valid.
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of {sorted(MODEL_CONFIG)}"
        )

    print(f"Загрузка модели {model_name}...")
    model = BiRefNet.from_pretrained(MODEL_CONFIG[model_name]["repo"])
    model.to(device).eval()
    if device == 'cuda':
        # Half precision on the GPU; extract_object() casts its input tensor
        # to half as well so the dtypes match.
        model.half()

    loaded_models[model_name] = model
    return model


def extract_object(image, model_name):
    """Cut the foreground out of *image* and save it as a transparent PNG.

    The image is resized to the selected model's configured input size, run
    through the model, and the predicted mask (resized back to the original
    image size) is applied as the alpha channel.

    Args:
        image: Input ``PIL.Image.Image``. Any mode is accepted; it is
            converted to RGB before inference.
        model_name: Key of ``MODEL_CONFIG`` selecting the checkpoint to use.

    Returns:
        Path of a temporary ``.png`` file containing the RGBA result. The
        file is created with ``delete=False`` and is not removed by this
        function.

    Raises:
        gr.Error: If *image* is ``None`` (no image was supplied).
        KeyError: If *model_name* is not a known model.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Пожалуйста, загрузите изображение.")

    model = load_model(model_name)
    config = MODEL_CONFIG[model_name]

    # Normalize() below is given three-channel statistics, so RGBA, palette
    # or grayscale uploads must be converted first or the transform fails.
    # convert() returns a new image, so the caller's object is never mutated.
    rgb_image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(config["image_size"]),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean / std.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    input_tensor = transform(rgb_image).unsqueeze(0).to(device)
    if device == 'cuda':
        # load_model() converts the model to half precision on CUDA.
        input_tensor = input_tensor.half()

    with torch.no_grad():
        # The last element of the model output is used as the prediction.
        preds = model(input_tensor)[-1].sigmoid().cpu()

    # Cast to float32 so the PIL conversion does not depend on half-precision
    # arithmetic being available on the CPU.
    mask = transforms.ToPILImage()(preds[0].squeeze().float())
    mask = mask.resize(image.size)

    # rgb_image is already a private copy; attach the mask as its alpha band.
    rgb_image.putalpha(mask)

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
        rgb_image.save(tmp_file, format="PNG")
        temp_filepath = tmp_file.name

    return temp_filepath


# "name (WxH)" list built from MODEL_CONFIG so the description shown in the UI
# always matches the dropdown choices. image_size is passed to
# transforms.Resize, which interprets a pair as (height, width).
_model_summary = ", ".join(
    f"{name} ({cfg['image_size'][1]}x{cfg['image_size'][0]})"
    for name, cfg in MODEL_CONFIG.items()
)

# Gradio UI: an image upload plus a model selector feeding extract_object(),
# whose returned file path is displayed as the result image.
iface = gr.Interface(
    fn=extract_object,
    inputs=[
        gr.Image(type="pil", label="Входное изображение"),
        gr.Dropdown(
            choices=list(MODEL_CONFIG.keys()),
            value="BiRefNet_HR",
            label="Выбор модели",
        ),
    ],
    outputs=gr.Image(type="filepath", label="Результат"),
    title="BiRefNet - Интерактивная сегментация",
    description=(
        "Выберите модель и загрузите изображение для сегментации. "
        f"Доступные модели: {_model_summary}"
    ),
    allow_flagging="never",
)


if __name__ == "__main__":
    # Start the web UI only when run as a script; share=True additionally
    # requests a public Gradio share link.
    iface.launch(share=True)