# autocomp-demo / csai.py
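"""Composite a foreground image onto a background.

The pipeline uploads the foreground to S3, requests a knockout (KO) cutout from the
CSAI lambda endpoint, refines it with an RVM (MattingNetwork) alpha matte, pastes the
foreground onto the background using the combined mask, and optionally runs the
Harmonizer model so the foreground colors match the background.
"""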
import json
import time
import uuid
from io import BytesIO

import numpy as np
import requests
import torch
from PIL import Image, ImageOps
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image

from rvm.model import MattingNetwork
from s3utils import save_image_to_s3, AWS_UPLOAD_URL, CSAI_LAMBDA_URL

def process(fg, bg, should_harmonize=True):
    # make sure the FG is > 512 on its smallest side
    if min(fg.size) < 513:
        ratio = 513 / min(fg.size)
        new_size = (int(fg.size[0] * ratio), int(fg.size[1] * ratio))
        fg = fg.resize(new_size, Image.Resampling.LANCZOS)
    # build FG2, a copy of the FG that is 513 on its largest side (low-res input for RVM)
    ratio = 513 / max(fg.size)
    new_size = (int(fg.size[0] * ratio), int(fg.size[1] * ratio))
    fg2 = fg.resize(new_size, Image.Resampling.LANCZOS)
    # if needed, crop, resize and center the BG on the FG
    if bg.size != fg.size:
        print("BG image is not the same size as FG, adjusting...")
        # Resize the BG to match the FG, maintaining its aspect ratio
        bg = ImageOps.fit(bg, fg.size, Image.Resampling.LANCZOS)
        # If the BG is still smaller than the FG (can happen due to aspect ratio differences),
        # then create a new image of the FG size and paste the BG into the center of it.
        if bg.size != fg.size:
            new_bg = Image.new("RGBA", fg.size)
            new_bg.paste(bg, ((fg.size[0] - bg.size[0]) // 2,
                              (fg.size[1] - bg.size[1]) // 2))
            bg = new_bg
    # create an ID for CSAI
    unique_id = str(uuid.uuid4())
    save_image_to_s3(fg, 'JPEG', unique_id + '/fg.jpg')
    # process the KO with CSAI
    ko = csai(unique_id)
    if ko is None:
        print(f'KO failed for {unique_id}')
        return None
    alpha_rvm = rvm(fg2)
    # resize alpha to match fg and bg
    alpha_rvm = alpha_rvm.resize(fg.size, Image.Resampling.LANCZOS)
    alpha_np = np.array(alpha_rvm) / 255.0
    # Convert the 'ko' image to a numpy array and select its alpha channel to build a float mask
    ko_np = np.array(ko)
    ko_mask = ko_np[:, :, 3] / 255.0  # alpha channel as floats in [0, 1]
    # limit the "ko" mask by the RVM alpha matte, at full resolution
    final_mask = alpha_np * ko_mask
    # Convert the final mask back to an 8-bit image
    final_mask_img = Image.fromarray(np.uint8(final_mask * 255))
    # Use final_mask_img when pasting
    bg.paste(fg, (0, 0), final_mask_img)
    if should_harmonize:
        # now run the harmonizer to make sure the foreground matches the background
        print("Running harmonizer...")
        harmonized = harmonizer(bg, final_mask_img)
        if harmonized is not None:
            return harmonized
        else:
            return bg
    else:
        return bg

def harmonizer(comp, mask):
    try:
        import torchvision.transforms.functional as tf
        from harmonizer.src import model
        harmonizer = model.Harmonizer()
        harmonizer = harmonizer.cuda()
        harmonizer.load_state_dict(torch.load("harmonizer/pretrained/harmonizer.pth"), strict=True)
        harmonizer.eval()
        comp = tf.to_tensor(comp)[None, ...]
        mask = tf.to_tensor(mask)[None, ...]
        comp = comp.cuda()
        mask = mask.cuda()
        with torch.no_grad():
            arguments = harmonizer.predict_arguments(comp, mask)
            harmonized = harmonizer.restore_image(comp, mask, arguments)[-1]
        harmonized = np.transpose(harmonized[0].cpu().numpy(), (1, 2, 0)) * 255
        harmonized = Image.fromarray(harmonized.astype(np.uint8))
        return harmonized
    except Exception as e:
        print(f"Harmonizer failed: {e}")
        return None

def rvm(fg):
    model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
    model.load_state_dict(torch.load('rvm/rvm_mobilenetv3.pth'))
    transform = transforms.ToTensor()
    fg = transform(fg).to("cuda", dtype=torch.float32, non_blocking=True)
    rec = [None] * 4  # Initial recurrent states.
    downsample_ratio = 1
    with torch.no_grad():
        fg = fg.unsqueeze(0)  # add a batch dimension: [B, C, H, W]
        fgr, pha, *rec = model(fg, *rec, downsample_ratio)
    # return fgr[0], pha[0]
    return to_pil_image(pha[0])

def csai(unique_id):
    headers = {
        'x-api-key': 'BB7gqZPZ8d18gYIPoYsgcaORqINBFQpK9IFTby0a',
        'Content-Type': 'application/json'
    }
    ko_url = AWS_UPLOAD_URL + unique_id + "/fg.jpg"
    payload = json.dumps({
        "url": ko_url,
        "async": True
    })
    print("Calling KO on " + ko_url)
    ko = None
    # call the KO
    response = requests.request("POST", CSAI_LAMBDA_URL + "ko", headers=headers, data=payload)
    if response.status_code != 200:
        print("KO failed with status code: " + str(response.status_code) + " and message: " + response.text)
    else:
        started = time.time()
        while ko is None:
            # poll for the KO result
            response = requests.request("POST", CSAI_LAMBDA_URL + "koresult", headers=headers, data=payload)
            # we might get a 303 redirection if the process is done
            if response.status_code == 303:
                # get the location from the header
                location = response.headers['Location']
                # get the image from the URL
                response = requests.get(location)
                ko = Image.open(BytesIO(response.content))
            # but requests follows that redirection by default, and in that case we get the image directly
            elif response.status_code == 200:
                # get the buffer from the response
                buffer = BytesIO(response.content)
                # open the image from the buffer
                ko = Image.open(buffer)
            elif response.status_code == 404:
                if time.time() - started > 60 * 3:
                    print("KO timed out")
                    break
                else:
                    print("KO not ready yet, waiting..." + response.text)
                    # note from Marc-Andre regarding throttling/rate limits: staging is currently configured with "No throttling" and "No quota", and the account default is 10000 requests per second with a burst of 5000 requests
                    time.sleep(0.5)
            else:
                print("KO failed with status code: " + str(response.status_code) + " and message: " + response.text)
                break
    return ko
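

# A minimal usage sketch, assuming a CUDA-capable GPU, S3/CSAI credentials configured
# in s3utils, and two local test images; the file names below are illustrative only.
if __name__ == "__main__":
    fg_img = Image.open("person.jpg").convert("RGB")
    bg_img = Image.open("beach.jpg").convert("RGB")
    result = process(fg_img, bg_img, should_harmonize=True)
    if result is not None:
        result.save("composite.png")
    else:
        print("Compositing failed")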