SparseAGS / process.py
qitaoz's picture
init commit
26ce2a9 verified
import os
import glob
import sys
import cv2
import argparse
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import rembg
# Set the OMP_NUM_THREADS environment variable to "10" for this process.
# NOTE(review): this assignment runs after numpy / torch / cv2 have already
# been imported above; presumably it is meant to cap OpenMP worker threads --
# confirm it still takes effect when set after those imports.
os.environ["OMP_NUM_THREADS"] = "10"
class BLIP2():
    """Callable wrapper around the ``Salesforce/blip2-opt-2.7b`` checkpoint.

    Constructing an instance loads the processor and the half-precision
    model onto ``device``; calling the instance with an image returns the
    text generated for it.
    """

    def __init__(self, device='cuda'):
        """Load the processor and the float16 model and move the model to ``device``."""
        self.device = device

        # Imported here rather than at module level, so `transformers` is
        # only required when a BLIP2 instance is actually created.
        from transformers import AutoProcessor, Blip2ForConditionalGeneration

        self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
        half_precision_model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            torch_dtype=torch.float16,
        )
        self.model = half_precision_model.to(device)

    @torch.no_grad()
    def __call__(self, image):
        """Generate text for ``image`` and return it stripped of whitespace.

        ``image`` is handed to ``PIL.Image.fromarray``.
        (assumes an H x W x C uint8 array -- TODO confirm against callers)
        At most 20 new tokens are generated.
        """
        pil_image = Image.fromarray(image)

        # Pre-process, then move the tensors to the model's device / dtype.
        batch = self.processor(pil_image, return_tensors="pt")
        batch = batch.to(self.device, torch.float16)

        token_ids = self.model.generate(**batch, max_new_tokens=20)
        decoded = self.processor.batch_decode(token_ids, skip_special_tokens=True)

        # Only a single image is passed in, so only the first entry is used.
        return decoded[0].strip()
def str2bool(value):
    """Parse a command-line boolean such as ``True`` / ``false`` / ``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays falsy, exactly as ``bool('')`` was before.
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def to_three_channels(image):
    """Return ``image`` as a contiguous H x W x 3 array.

    ``cv2.imread(..., cv2.IMREAD_UNCHANGED)`` may yield a 2-D grayscale
    array or an array with an extra alpha channel; neither can be pasted
    into the 3-channel padding canvas used when recentering.
    """
    if image.ndim == 2:
        return np.ascontiguousarray(np.stack([image] * 3, axis=-1))
    channels = image.shape[2]
    if channels == 3:
        return image
    if channels > 3:
        # Drop alpha (and anything after it); keep the three colour planes.
        return np.ascontiguousarray(image[..., :3])
    # 1 or 2 channels: replicate the first (intensity) plane.
    return np.ascontiguousarray(np.repeat(image[..., :1], 3, axis=-1))


def foreground_bbox(mask):
    """Return ``(row_min, row_max, col_min, col_max)`` of the True pixels.

    The maxima are *exclusive*, so ``mask[row_min:row_max, col_min:col_max]``
    covers every foreground pixel and both extents are at least 1.
    Returns ``None`` when the mask contains no foreground at all.
    """
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        return None
    return (int(rows.min()), int(rows.max()) + 1,
            int(cols.min()), int(cols.max()) + 1)


def recenter(image, carved_image, mask, size, border_ratio):
    """Center the masked object in a ``size`` x ``size`` frame.

    Args:
        image: 3-channel source image (H x W x 3).
        carved_image: 4-channel background-removed image (H x W x 4).
        mask: boolean H x W foreground mask.
        size: side length of both outputs in pixels.
        border_ratio: fraction of the frame left as border, in ``[0, 1)``.

    Returns:
        ``(final_rgba, final_rgb)`` or ``None`` if the mask is empty.
    """
    bbox = foreground_bbox(mask)
    if bbox is None:
        return None
    row_min, row_max, col_min, col_max = bbox
    h = row_max - row_min
    w = col_max - col_min

    # --- RGBA output: scale the object crop and paste it in the middle ---
    desired_size = int(size * (1 - border_ratio))
    scale = desired_size / max(h, w)
    # Clamp to 1 so very thin objects never produce a zero-sized resize.
    h2 = max(1, int(h * scale))
    w2 = max(1, int(w * scale))
    top = (size - h2) // 2
    left = (size - w2) // 2
    final_rgba = np.zeros((size, size, 4), dtype=np.uint8)
    final_rgba[top:top + h2, left:left + w2] = cv2.resize(
        carved_image[row_min:row_max, col_min:col_max],
        (w2, h2),
        interpolation=cv2.INTER_AREA,
    )

    # --- RGB output: square crop of the source around the same object ---
    row_c = (row_min + row_max) // 2
    col_c = (col_min + col_max) // 2
    half = max(1, int(max(h, w) / (1 - border_ratio)) // 2)
    r0, r1 = row_c - half, row_c + half
    c0, c1 = col_c - half, col_c + half

    H, W = image.shape[:2]
    # Pad the image in case the square crop reaches outside its boundary.
    pad_top = -min(0, r0)
    pad_left = -min(0, c0)
    canvas = np.zeros((max(H, r1) + pad_top, max(W, c1) + pad_left, 3), dtype=image.dtype)
    canvas[pad_top:pad_top + H, pad_left:pad_left + W] = image
    roi = canvas[pad_top + r0:pad_top + r1, pad_left + c0:pad_left + c1]
    final_rgb = cv2.resize(roi, (size, size), interpolation=cv2.INTER_AREA)

    return final_rgba, final_rgb


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str, help="path to image (png, jpeg, etc.)")
    parser.add_argument('--model', default='u2net', type=str, help="rembg model, see https://github.com/danielgatis/rembg#models")
    parser.add_argument('--size', default=256, type=int, help="output resolution")
    parser.add_argument('--border_ratio', default=0.2, type=float, help="output border ratio")
    # type=bool would turn "--recenter False" into True; parse the text instead.
    parser.add_argument('--recenter', type=str2bool, default=True, help="recenter, potentially not helpful for multiview zero123")
    opt = parser.parse_args()

    # Reject values that would otherwise fail later with a division by zero
    # or an array shape mismatch.
    if opt.size <= 0:
        parser.error('--size must be a positive integer')
    if not 0 <= opt.border_ratio < 1:
        parser.error('--border_ratio must be in the range [0, 1)')

    session = rembg.new_session(model_name=opt.model)

    if os.path.isdir(opt.path):
        print(f'[INFO] processing directory {opt.path}...')
        # Keep regular files only: a previous run leaves the 'processed' and
        # 'source' sub-directories inside the input directory.
        candidates = glob.glob(os.path.join(glob.escape(opt.path), '*'))
        files = sorted(f for f in candidates if os.path.isfile(f))
        out_dir = opt.path
    else:  # isfile
        files = [opt.path]
        out_dir = os.path.dirname(opt.path)

    os.makedirs(os.path.join(out_dir, 'processed'), exist_ok=True)
    os.makedirs(os.path.join(out_dir, 'source'), exist_ok=True)

    for file in files:
        out_base = os.path.basename(file).split('.')[0]
        out_rgba = os.path.join(out_dir, 'processed', out_base + '_rgba.png')
        out_rgb = os.path.join(out_dir, 'source', out_base + '.png')

        # load image
        print(f'[INFO] loading image {file}...')
        image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            print(f'[WARN] could not read {file} as an image, skipping...')
            continue

        # carve background
        print(f'[INFO] background removal...')
        carved_image = rembg.remove(image, session=session)  # [H, W, 4]
        mask = carved_image[..., -1] > 0

        color_image = to_three_channels(image)

        if opt.recenter:
            print(f'[INFO] recenter...')
            centered = recenter(color_image, carved_image, mask, opt.size, opt.border_ratio)
            if centered is None:
                print(f'[WARN] no foreground found in {file}, skipping...')
                continue
            final_rgba, final_rgb = centered
        else:
            final_rgba = carved_image
            # Without recentering the source image is written through as-is.
            final_rgb = color_image

        # write image
        cv2.imwrite(out_rgba, final_rgba)
        cv2.imwrite(out_rgb, final_rgb)