#!pip install torch torchvision requests gdown matplotlib opencv-python Pillow==8.0.0 gradio cvzone
import base64
import io
import os
import warnings

import gdown
import gradio as gr
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from openai import OpenAI

warnings.filterwarnings("ignore")

# --- Stability AI configuration ---
engine_id = "stable-inpainting-512-v2-0"
api_host = os.getenv('API_HOST', 'https://api.stability.ai')
api_key = os.environ.get('Stability_Key')  # .get() so the None check below can actually fire
if api_key is None:
    raise Exception("Missing Stability API key.")

# --- Fetch the DIS (IS-Net) segmentation code and weights ---
os.system("git clone https://github.com/xuebinqin/DIS")
os.system("mv DIS/IS-Net/* .")

from data_loader_cache import normalize, im_reader, im_preprocess
from models import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'

if not os.path.exists("saved_models"):
    os.mkdir("saved_models")
MODEL_PATH_URL = "https://drive.google.com/uc?id=1xRaOTDapXoicYGOk-fQk5eSKOkjZ9Cni"
gdown.download(MODEL_PATH_URL, "saved_models/isnet.pth", use_cookies=False)


class GOSNormalize(object):
    '''Normalize the image using torchvision-style mean/std normalization.'''

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        return normalize(image, self.mean, self.std)


transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])


def load_image(im_path, hypar):
    '''Read an image, preprocess it to cache_size, and return a normalized tensor plus the original shape.'''
    im = im_reader(im_path)
    im, im_shp = im_preprocess(im, hypar["cache_size"])
    im = torch.divide(im, 255.0)
    shape = torch.from_numpy(np.array(im_shp))
    return transform(im).unsqueeze(0), shape.unsqueeze(0)


def build_model(hypar, device):
    net = hypar["model"]
    if hypar["model_digit"] == "half":
        net.half()
        # Keep BatchNorm in fp32 for numerical stability
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()
    net.to(device)
    if hypar["restore_model"] != "":
        net.load_state_dict(torch.load(hypar["model_path"] + "/" + hypar["restore_model"],
                                       map_location=device))
        net.to(device)
    net.eval()
    return net


def predict(net, inputs_val, shapes_val, hypar, device):
    '''Given an image tensor, predict the foreground mask.'''
    net.eval()
    if hypar["model_digit"] == "full":
        inputs_val = inputs_val.type(torch.FloatTensor)
    else:
        inputs_val = inputs_val.type(torch.HalfTensor)
    inputs_val_v = inputs_val.to(device)  # plain tensor; Variable is deprecated
    ds_val = net(inputs_val_v)[0]
    pred_val = ds_val[0][0, :, :, :]
    # Resize the prediction back to the original image size
    # (F.interpolate replaces the deprecated F.upsample)
    pred_val = torch.squeeze(F.interpolate(torch.unsqueeze(pred_val, 0),
                                           (shapes_val[0][0], shapes_val[0][1]),
                                           mode='bilinear'))
    # Min-max normalize to [0, 1]
    ma = torch.max(pred_val)
    mi = torch.min(pred_val)
    pred_val = (pred_val - mi) / (ma - mi)
    if device == 'cuda':
        torch.cuda.empty_cache()
    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)


hypar = {}
hypar["model_path"] = "./saved_models"
hypar["restore_model"] = "isnet.pth"
hypar["interm_sup"] = False
hypar["model_digit"] = "full"
hypar["seed"] = 0
hypar["cache_size"] = [1024, 1024]
hypar["input_size"] = [1024, 1024]
hypar["crop_size"] = [1024, 1024]
hypar["model"] = ISNetDIS()

net = build_model(hypar, device)


def round_to_multiple(number, multiple, limit):
    '''Round down to the nearest multiple, clamped to [multiple, limit].'''
    return max(multiple, min(limit, number - number % multiple))
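
# Illustrative sketch (not called by the app): the mask stage above can be
# exercised on its own. "example.jpg" is a hypothetical local file, assumed
# to exist; the app instead routes everything through inference2 below.
def demo_extract_mask(im_path="example.jpg"):
    image_tensor, orig_size = load_image(im_path, hypar)         # normalized 1024x1024 tensor
    mask = predict(net, image_tensor, orig_size, hypar, device)  # uint8 mask at original size
    return Image.fromarray(mask).convert("L")                    # grayscale PIL mask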

def getPrompt(desc):
    '''Turn a free-form user request into a Stable Diffusion prompt via GPT-3.5.'''
    client = OpenAI(api_key=os.environ['OpenApi_Key'])
    prompt = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": '''You are a bot that converts user inputs into stable diffusion prompts for a background replacer.
For example:
User: Add trees and flowers in the background.
Prompt Generated: Trees, Flowers, High quality, Ultra HD
Keep the prompt within 20 words'''},
            {"role": "user", "content": desc},
        ]
    ).choices[0].message.content
    print(prompt)
    return prompt


def inference2(image: str, prompt):
    '''Replace the background of an image (given as a file path) in two inpainting passes.'''
    # 1. Predict the foreground mask with IS-Net
    image_tensor, orig_size = load_image(image, hypar)
    mask = predict(net, image_tensor, orig_size, hypar, device)
    pil_mask = Image.fromarray(mask).convert("L")

    # 2. Resize to SD-friendly dimensions (multiples of 64) and pad vertically,
    #    pasting the subject 20% down so the model has headroom to fill
    ogimg = Image.open(image)
    new_width1 = round_to_multiple(ogimg.width, 64, 704)
    new_height1 = round_to_multiple(ogimg.height, 64, 704)
    init_img = ogimg.resize((new_width1, new_height1))
    new_height = round_to_multiple(int(init_img.height * 1.45), 64, 1024)
    new_width = init_img.width
    new_img = Image.new('RGB', (new_width, new_height), (1, 1, 1))
    new_img.paste(init_img, (0, int(new_height * 0.2)))
    output_buffer = io.BytesIO()
    new_img.save(output_buffer, format="PNG")
    resized_image_bytes = output_buffer.getvalue()

    # Pad the mask the same way
    mask_img = pil_mask.resize((new_width1, new_height1))
    new_img = Image.new('RGB', (new_width, new_height), (1, 1, 1))
    new_img.paste(mask_img, (0, int(new_height * 0.2)))
    output_buffer2 = io.BytesIO()
    new_img.save(output_buffer2, format="PNG")
    resized_image_bytes2 = output_buffer2.getvalue()

    # Generate the SD prompt once and reuse it for both passes
    sd_prompt = getPrompt(prompt)

    # 3. First inpainting pass: fill the background around the masked subject
    response = requests.post(
        f"{api_host}/v1/generation/{engine_id}/image-to-image/masking",
        headers={
            "Accept": 'application/json',
            "Authorization": f"Bearer {api_key}"
        },
        files={
            'init_image': resized_image_bytes,
            'mask_image': resized_image_bytes2,
        },
        data={
            "mask_source": "MASK_IMAGE_BLACK",
            "text_prompts[0][text]": sd_prompt,
            "text_prompts[0][weight]": 1,
            "text_prompts[1][text]": "text, distorted text, distortion, bottles, blurriness",
            "text_prompts[1][weight]": -1,
            "cfg_scale": 9,
            "clip_guidance_preset": "FAST_GREEN",
            "samples": 2,
            "steps": 40,
            "style_preset": "photographic",
        }
    )
    if response.status_code != 200:
        raise Exception("Non-200 response: " + str(response.text))

    data = response.json()
    image_data1 = data["artifacts"][0]["base64"]
    image_bytes = base64.b64decode(image_data1)
    generated_image = Image.open(io.BytesIO(image_bytes))

    # 4. Shrink the first result and paste it back onto itself, then run a
    #    second pass to fill the border this leaves around the image
    generated_imaget = generated_image.resize((int(generated_image.width * 0.8),
                                               int(generated_image.height * 0.8)))
    generated_image.paste(generated_imaget, (int(new_height * 0.08), int(new_height * 0.1)))
    output_buffer = io.BytesIO()
    generated_image.save(output_buffer, format="PNG")
    resized_image_bytes = output_buffer.getvalue()

    response = requests.post(
        f"{api_host}/v1/generation/{engine_id}/image-to-image/masking",
        headers={
            "Accept": 'application/json',
            "Authorization": f"Bearer {api_key}"
        },
        files={
            'init_image': resized_image_bytes,
            'mask_image': resized_image_bytes2,
        },
        data={
            "mask_source": "MASK_IMAGE_BLACK",
            "text_prompts[0][text]": sd_prompt,
            "text_prompts[0][weight]": 1,
            "text_prompts[1][text]": "white, patches, patch, text, distorted text, distortion, blurriness",
            "text_prompts[1][weight]": -1,
            "cfg_scale": 9,
            "clip_guidance_preset": "FAST_GREEN",
            "samples": 2,
            "steps": 50,
            "style_preset": "photographic",
            #"sampler": "DDPM"
        }
    )
    if response.status_code != 200:
        raise Exception("Non-200 response: " + str(response.text))

    # 5. Composite the original (masked) subject back over both returned samples
    data = response.json()
    image_data1 = data["artifacts"][0]["base64"]
    image_bytes = base64.b64decode(image_data1)
    generated_image = Image.open(io.BytesIO(image_bytes))
    im_rgb = Image.open(image).convert("RGB")
    im_rgba = im_rgb.copy()
    im_rgba.putalpha(pil_mask)
    im_rgba = im_rgba.resize((new_width1, new_height1), Image.LANCZOS)  # ANTIALIAS is deprecated
    generated_image.paste(im_rgba, (0, int(new_height * 0.2)), im_rgba)

    image_data2 = data["artifacts"][1]["base64"]
    image_bytes = base64.b64decode(image_data2)
    generated_image2 = Image.open(io.BytesIO(image_bytes))
    im_rgb = Image.open(image).convert("RGB")
    im_rgba = im_rgb.copy()
    im_rgba.putalpha(pil_mask)
    im_rgba = im_rgba.resize((new_width1, new_height1), Image.LANCZOS)
    generated_image2.paste(im_rgba, (0, int(new_height * 0.2)), im_rgba)

    return generated_image, generated_image2
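
# Illustrative usage sketch (commented out so the script still just serves the
# UI): calling the pipeline directly, without Gradio. "product.png" and the
# output filenames are hypothetical.
# out_a, out_b = inference2("product.png", "place the product on a marble countertop")
# out_a.save("result_a.png")
# out_b.save("result_b.png")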

# --- Gradio UI ---
interface = gr.Interface(
    fn=inference2,
    inputs=[gr.Image(type='filepath'), gr.Textbox(label='prompt')],
    outputs=["image", "image"],
    title="For best results, use square images and make sure the subject is in the centre",
    allow_flagging='never',
    theme="default",
    cache_examples=False,
).launch(debug=True, share=False)