from PIL import Image, ImageOps
import requests
import json
import uuid, os, time
from io import BytesIO
import numpy as np
import torch
from rvm.model import MattingNetwork
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from s3utils import *


def process(fg, bg, should_harmonize=True):
    # make sure the FG is > 512 px on its smallest side
    if min(fg.size) < 513:
        ratio = 513 / min(fg.size)
        new_size = (int(fg.size[0] * ratio), int(fg.size[1] * ratio))
        fg = fg.resize(new_size, Image.Resampling.LANCZOS)

    # create FG2, a copy of the FG scaled down so its largest side is ~512 (for the matting model)
    ratio = 513 / max(fg.size)
    new_size = (int(fg.size[0] * ratio), int(fg.size[1] * ratio))
    fg2 = fg.resize(new_size, Image.Resampling.LANCZOS)

    # if needed, crop, resize and center the BG on the FG
    if bg.size != fg.size:
        print("BG image is not the same size as FG, adjusting...")
        # resize the BG to match the FG, maintaining its aspect ratio
        bg = ImageOps.fit(bg, fg.size, Image.Resampling.LANCZOS)
        # If the BG is still smaller than the FG (can happen due to aspect ratio differences),
        # create a new image of the FG size and paste the BG into the center of it.
        if bg.size != fg.size:
            new_bg = Image.new("RGBA", fg.size)
            new_bg.paste(bg, ((fg.size[0] - bg.size[0]) // 2, (fg.size[1] - bg.size[1]) // 2))
            bg = new_bg

    # create an ID for CSAI and upload the FG so the service can fetch it
    unique_id = str(uuid.uuid4())
    save_image_to_s3(fg, 'JPEG', unique_id + '/fg.jpg')

    # process the knockout (KO) with CSAI
    ko = csai(unique_id)
    if ko is None:
        print(f'KO failed for {unique_id}')
        return None

    # run RVM on the downscaled FG, then resize the alpha matte to match the FG and BG
    alpha_rvm = rvm(fg2)
    alpha_rvm = alpha_rvm.resize(fg.size, Image.Resampling.LANCZOS)
    alpha_np = np.array(alpha_rvm) / 255.0

    # convert the 'ko' image to a numpy array and select its alpha channel as floats between 0 and 1
    ko_np = np.array(ko)
    ko_mask = ko_np[:, :, 3] / 255.0

    # limit the KO mask by the RVM alpha matte, in high res
    final_mask = alpha_np * ko_mask

    # convert the final mask back to an 8-bit image
    final_mask_img = Image.fromarray(np.uint8(final_mask * 255))

    # use the final_mask_img when pasting the FG over the BG
    bg.paste(fg, (0, 0), final_mask_img)

    if should_harmonize:
        # now run the harmonizer to make sure the foreground matches the background
        print("Running harmonizer...")
        harmonized = harmonizer(bg, final_mask_img)
        if harmonized is not None:
            return harmonized
        else:
            return bg
    else:
        return bg


def harmonizer(comp, mask):
    try:
        import torchvision.transforms.functional as tf
        from harmonizer.src import model

        harmonizer = model.Harmonizer()
        harmonizer = harmonizer.cuda()
        harmonizer.load_state_dict(torch.load("harmonizer/pretrained/harmonizer.pth"), strict=True)
        harmonizer.eval()

        comp = tf.to_tensor(comp)[None, ...]
        mask = tf.to_tensor(mask)[None, ...]
        comp = comp.cuda()
        mask = mask.cuda()

        with torch.no_grad():
            arguments = harmonizer.predict_arguments(comp, mask)
            harmonized = harmonizer.restore_image(comp, mask, arguments)[-1]

        harmonized = np.transpose(harmonized[0].cpu().numpy(), (1, 2, 0)) * 255
        harmonized = Image.fromarray(harmonized.astype(np.uint8))
        return harmonized
    except Exception:
        return None


def rvm(fg):
    model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
    model.load_state_dict(torch.load('rvm/rvm_mobilenetv3.pth'))

    transform = transforms.ToTensor()
    fg = transform(fg).to("cuda", dtype=torch.float32, non_blocking=True)

    rec = [None] * 4  # initial recurrent states
    downsample_ratio = 1
    with torch.no_grad():
        fg = fg.unsqueeze(0)  # add a batch dimension -> [B, C, H, W]
        fgr, pha, *rec = model(fg, *rec, downsample_ratio)
        # return fgr[0], pha[0]
        return to_pil_image(pha[0])


def csai(unique_id):
    headers = {
        'x-api-key': 'BB7gqZPZ8d18gYIPoYsgcaORqINBFQpK9IFTby0a',
        'Content-Type': 'application/json'
    }
    ko_url = AWS_UPLOAD_URL + unique_id + "/fg.jpg"
    payload = json.dumps({
        "url": ko_url,
        "async": True
    })

    print("Calling KO on " + ko_url)

    ko = None

    # start the KO job
    response = requests.request("POST", CSAI_LAMBDA_URL + "ko", headers=headers, data=payload)
    if response.status_code != 200:
        print("KO failed with status code: " + str(response.status_code) + " and message: " + response.text)
    else:
        started = time.time()
        while ko is None:
            # poll for the KO result
            response = requests.request("POST", CSAI_LAMBDA_URL + "koresult", headers=headers, data=payload)

            # we might get a 303 redirection if the process is done
            if response.status_code == 303:
                # get the location from the header and fetch the image from that URL
                location = response.headers['Location']
                response = requests.get(location)
                ko = Image.open(BytesIO(response.content))
            # but requests follows that redirection by default, and in that case we get the image directly
            elif response.status_code == 200:
                # open the image from the response buffer
                buffer = BytesIO(response.content)
                ko = Image.open(buffer)
            elif response.status_code == 404:
                if time.time() - started > 60 * 3:
                    print("KO timed out")
                    break
                else:
                    print("KO not ready yet, waiting..." + response.text)
                    # note from Marc-Andre regarding throttling/rate limits: staging is currently
                    # configured with "No throttling" and "No quota", and the account default is
                    # 10000 requests per second with a burst of 5000 requests
                    time.sleep(0.5)
            else:
                print("KO failed with status code: " + str(response.status_code) + " and message: " + response.text)
                break

    return ko
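

# Illustrative usage sketch (not part of the original module): one way process() might be
# driven end to end, assuming s3utils is configured with valid AWS/CSAI credentials.
# The input file names below are hypothetical placeholders.
if __name__ == "__main__":
    fg_img = Image.open("person.jpg").convert("RGB")  # hypothetical foreground photo
    bg_img = Image.open("beach.jpg").convert("RGB")   # hypothetical background photo

    result = process(fg_img, bg_img, should_harmonize=True)
    if result is not None:
        result.save("composite.png")
    else:
        print("Compositing failed (KO did not return a cutout)")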