Spaces:
Runtime error
Runtime error
| from PIL import Image, ImageOps | |
| import requests | |
| import json | |
| import uuid, os, time | |
| from io import BytesIO | |
| import numpy as np | |
| import torch | |
| from rvm.model import MattingNetwork | |
| from torchvision import transforms | |
| from torchvision.transforms.functional import to_pil_image | |
| from s3utils import * | |
def _scale_size(size, ratio):
    """Scale a (width, height) pair by ``ratio``, truncating each side to an int >= 1."""
    # Clamp to 1 so an extreme aspect ratio can never yield a zero-sized image.
    return tuple(max(1, int(side * ratio)) for side in size)


def process(fg, bg, should_harmonize=True):
    """Cut the subject out of ``fg`` and composite it over ``bg``.

    Two mattes are combined: the alpha channel of the knock-out ("KO") image
    returned by ``csai`` and the alpha predicted by ``rvm``. Their product is
    used as the paste mask, so a pixel is kept only where both agree.

    Args:
        fg: PIL image containing the subject. It is upscaled when its
            smallest side is below 513 px.
        bg: PIL image used as the backdrop. It is cropped/resized to the
            (possibly upscaled) foreground size. The caller's image is never
            modified.
        should_harmonize: When true, run ``harmonizer`` on the composite and
            return its output if it succeeds.

    Returns:
        The composited PIL image (harmonized when requested and successful),
        or ``None`` when the KO step fails.
    """
    resample = Image.Resampling.LANCZOS
    target = 513

    # Upscale the foreground so its smallest side reaches the target.
    if min(fg.size) < target:
        fg = fg.resize(_scale_size(fg.size, target / min(fg.size)), resample)

    # Low-res copy for RVM: largest side scaled to the target.
    fg2 = fg.resize(_scale_size(fg.size, target / max(fg.size)), resample)

    if bg.size != fg.size:
        print("BG image is not the same size as FG, adjusting...")
        # ImageOps.fit crops to the FG aspect ratio and resizes to exactly
        # fg.size, so no further padding step is required afterwards.
        bg = ImageOps.fit(bg, fg.size, resample)
    else:
        # paste() below works in place; copy so the caller's image is untouched.
        bg = bg.copy()

    # Upload the foreground under a fresh id so the KO service can fetch it.
    unique_id = str(uuid.uuid4())
    save_image_to_s3(fg, 'JPEG', unique_id + '/fg.jpg')

    ko = csai(unique_id)
    if ko is None:
        print(f'KO failed for {unique_id}')
        return None

    # The mask maths below reads channel 3 and multiplies pixel-for-pixel with
    # the RVM alpha, so the KO image must be RGBA and exactly fg-sized.
    # A KO image without transparency becomes fully opaque, leaving RVM's
    # alpha as the effective mask.
    if ko.mode != "RGBA":
        ko = ko.convert("RGBA")
    if ko.size != fg.size:
        ko = ko.resize(fg.size, resample)

    # RVM ran on the low-res copy; bring its alpha back up to full resolution.
    alpha_rvm = rvm(fg2).resize(fg.size, resample)
    alpha_np = np.array(alpha_rvm) / 255.0

    # KO alpha channel as floats in [0, 1].
    ko_mask = np.array(ko)[:, :, 3] / 255.0

    # Combine both mattes and convert back to an 8-bit mask image.
    final_mask = alpha_np * ko_mask
    final_mask_img = Image.fromarray(np.uint8(final_mask * 255))

    bg.paste(fg, (0, 0), final_mask_img)

    if not should_harmonize:
        return bg

    # Harmonization is best-effort: fall back to the plain composite on failure.
    print("Running harmonizer...")
    harmonized = harmonizer(bg, final_mask_img)
    return harmonized if harmonized is not None else bg
def harmonizer(comp, mask):
    """Run the Harmonizer network on a composite; best-effort.

    Args:
        comp: Composite image, converted with torchvision's ``to_tensor``.
        mask: Foreground mask image, converted the same way.

    Returns:
        The harmonized result as a PIL image, or ``None`` if anything goes
        wrong (missing dependency, no CUDA device, missing weights, model
        error). The failure reason is printed so it is not lost.
    """
    try:
        # Imports stay inside the try so a missing dependency degrades to
        # "not harmonized" instead of crashing the caller. The project import
        # comes first so that case fails fast.
        from harmonizer.src import model
        import torchvision.transforms.functional as tf

        net = model.Harmonizer()
        net = net.cuda()
        net.load_state_dict(torch.load("harmonizer/pretrained/harmonizer.pth"), strict=True)
        net.eval()

        # Add a leading batch dimension and move both inputs to the GPU.
        comp = tf.to_tensor(comp)[None, ...].cuda()
        mask = tf.to_tensor(mask)[None, ...].cuda()

        with torch.no_grad():
            arguments = net.predict_arguments(comp, mask)
            # restore_image returns a sequence; the last entry is used as the result.
            harmonized = net.restore_image(comp, mask, arguments)[-1]

        # CHW tensor -> HWC array scaled by 255, then cast to 8-bit for PIL.
        harmonized = np.transpose(harmonized[0].cpu().numpy(), (1, 2, 0)) * 255
        return Image.fromarray(harmonized.astype(np.uint8))
    except Exception as exc:
        # Deliberately broad: harmonization is optional. Unlike a bare
        # ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
        print(f"Harmonizer failed: {exc!r}")
        return None
def rvm(fg, downsample_ratio=1):
    """Predict an alpha matte for a single image with Robust Video Matting.

    Args:
        fg: Image to matte. PIL images that are not RGB are converted to RGB
            first, because RVM's documented input is a 3-channel RGB tensor.
        downsample_ratio: Passed straight to the network; defaults to 1
            (the value previously hard-coded here).

    Returns:
        The predicted alpha for the image as a PIL image.
    """
    model = MattingNetwork('mobilenetv3').eval().cuda()  # or "resnet50"
    model.load_state_dict(torch.load('rvm/rvm_mobilenetv3.pth'))

    # RGBA / greyscale / palette images would give the wrong channel count.
    if isinstance(fg, Image.Image) and fg.mode != "RGB":
        fg = fg.convert("RGB")

    src = transforms.ToTensor()(fg).to("cuda", dtype=torch.float32, non_blocking=True)
    src = src.unsqueeze(0)  # add the batch dimension: [B, C, H, W]

    rec = [None] * 4  # Initial recurrent states.
    with torch.no_grad():
        # The foreground prediction and updated recurrent states are unused
        # for a single still image; only the alpha is kept.
        _fgr, pha, *_ = model(src, *rec, downsample_ratio)
    return to_pil_image(pha[0])
def csai(unique_id):
    """Request a knock-out ("KO") of the uploaded foreground and wait for it.

    Submits an asynchronous KO job for ``<AWS_UPLOAD_URL><unique_id>/fg.jpg``
    and then polls the ``koresult`` endpoint until the image is available.

    Args:
        unique_id: Identifier under which the foreground was uploaded.

    Returns:
        The KO result as a PIL image, or ``None`` when the job could not be
        submitted, the service answered with an unexpected status, or no
        result arrived within three minutes of polling.

    Raises:
        requests.RequestException: Network failures, including a single HTTP
            call exceeding the per-request timeout, are not caught here.
    """
    request_timeout = 30   # seconds allowed for each individual HTTP call
    poll_deadline = 60 * 3  # seconds to keep polling while the result is 404
    poll_interval = 0.5     # seconds between polls

    headers = {
        # SECURITY: this key is committed to source control and should be
        # rotated and removed. It remains only as a fallback so existing
        # deployments keep working; set CSAI_API_KEY to override it.
        'x-api-key': os.environ.get('CSAI_API_KEY', 'BB7gqZPZ8d18gYIPoYsgcaORqINBFQpK9IFTby0a'),
        'Content-Type': 'application/json'
    }
    ko_url = AWS_UPLOAD_URL + unique_id + "/fg.jpg"
    payload = json.dumps({
        "url": ko_url,
        "async": True
    })
    print("Calling KO on " + ko_url)

    # Submit the KO job.
    response = requests.post(CSAI_LAMBDA_URL + "ko", headers=headers, data=payload,
                             timeout=request_timeout)
    if response.status_code != 200:
        print("KO failed with status code: " + str(response.status_code) + " and message: " + response.text)
        return None

    started = time.time()
    ko = None
    while ko is None:
        # Ask for the result of the KO job.
        response = requests.post(CSAI_LAMBDA_URL + "koresult", headers=headers, data=payload,
                                 timeout=request_timeout)
        if response.status_code == 303:
            # Finished: the Location header points at the resulting image.
            location = response.headers['Location']
            response = requests.get(location, timeout=request_timeout)
            ko = Image.open(BytesIO(response.content))
        elif response.status_code == 200:
            # requests follows redirects by default, so a finished job normally
            # arrives here with the image bytes as the response body.
            ko = Image.open(BytesIO(response.content))
        elif response.status_code == 404:
            if time.time() - started > poll_deadline:
                print("KO timed out")
                break
            print("KO not ready yet, waiting..." + response.text)
            # Note from Marc-Andre regarding throttling/rate limits: staging is
            # currently configured with "No throttling" and "No quota", and the
            # account default is 10000 requests per second with a burst of
            # 5000 requests.
            time.sleep(poll_interval)
        else:
            print("KO failed with status code: " + str(response.status_code) + " and message: " + response.text)
            break
    return ko