import functools
import pickle
import pandas as p
import numpy as np
import requests
import io
import os
import cv2
import gdown
import tempfile
from PIL import Image, ImageDraw, ImageFont
import PIL
from transparent_background import Remover
import torch
import torch.nn.functional as F
import time
import gradio as gr
from io import BytesIO
from torchvision import datasets, models, transforms

# NOTE(review): pickle, pandas (p), numpy (np), cv2, PIL, F, datasets, models
# and transforms are not referenced anywhere in this script. They are kept
# only so import-time behaviour is unchanged; consider removing them.

# Seconds to wait for an HTTP response before giving up, so a stalled image
# host cannot block the caller (or a Gradio worker) indefinitely.
REQUEST_TIMEOUT = 30


class BackgroundRemover(Remover):
    """transparent_background ``Remover`` that is built from in-memory weights
    and returns the cut-out subject composited onto a white background."""

    def __init__(self, model_bytes, device=None):
        """
        model_bytes: model weights as bytes (downloaded from "https://drive.google.com/file/d/13oBl5MTVcWER3YU4fSxW3ATlVfueFQPY/view?usp=share_link")
        device     : (default cuda:0 if available) specifying device for computation;
                     None lets the parent Remover choose.
        """
        self.model_path = None
        # Remover expects a checkpoint *path*, so the bytes are spilled to a
        # temporary file that __del__ removes again.
        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as tmp_file:
            tmp_file.write(model_bytes)
            self.model_path = tmp_file.name

        # Resolving the font next to this script (os.path.abspath(__file__)) is
        # disabled, so "arial.ttf" is looked up relative to the working directory.
        script_path = ""
        self.font_path = os.path.join(os.path.dirname(script_path), "arial.ttf")

        super().__init__(fast=False, jit=False, device=device, ckpt=self.model_path)

    def __del__(self):
        """Delete the temporary checkpoint file written by ``__init__``."""
        # getattr: __del__ can run even if __init__ failed before model_path was set.
        path = getattr(self, "model_path", None)
        if path is not None and os.path.exists(path):
            os.remove(path)

    def download(self):
        """No-op override of ``Remover.download``.

        NOTE(review): presumably this stops the parent class from fetching its
        own checkpoint, since the weights are supplied via ``model_bytes`` —
        confirm against the transparent_background version in use.
        """
        pass

    def predict(self, image, comparison=False, extra=""):
        """Run background removal on ``image``.

        image      : PIL image to process.
        comparison : if True, return a side-by-side image (input | prediction)
                     annotated with the size, the model time and ``extra``.
        extra      : additional caption text used only when ``comparison`` is True.
        Returns the white-background RGB prediction, or the comparison image.
        """
        start = time.perf_counter()
        prediction = self.raw_predict(image)
        elapsed = time.perf_counter() - start
        if not comparison:
            return prediction
        return self.compare(image, prediction, elapsed, extra)

    def raw_predict(self, image, empty_cache_after_prediction=False):
        """Remove the background of ``image`` and place the result on white.

        image                        : PIL image handed to ``Remover.process``.
        empty_cache_after_prediction : if True and running on CUDA, release
                                       cached GPU memory afterwards.
        Returns an RGB PIL image with the same size as the prediction.
        """
        out = self.process(image)
        prediction = Image.fromarray(out)
        # New white RGB canvas with the same size as the prediction.
        new_image = Image.new("RGB", prediction.size, (255, 255, 255))
        # Paste the prediction using its 4th band (alpha) as the mask, so
        # transparent pixels are replaced by white. Requires a 4-band result.
        new_image.paste(prediction, mask=prediction.split()[3])
        # str(): self.device may be a torch.device, on which ``"cuda" in ...``
        # raises TypeError.
        if empty_cache_after_prediction and "cuda" in str(self.device):
            torch.cuda.empty_cache()
        return new_image

    def compare(self, image1, image2, prediction_time, extra_info=""):
        """Return ``image1`` and ``image2`` side by side under an 80px caption.

        The caption shows image1's size, ``prediction_time`` (seconds) and
        ``extra_info``.
        """
        header_height = 80
        canvas_height = max(image1.height, image2.height) + header_height
        canvas = Image.new('RGB', (image1.width + image2.width, canvas_height), (255, 255, 255))
        canvas.paste(image1, (0, header_height))
        canvas.paste(image2, (image1.width, header_height))
        draw = ImageDraw.Draw(canvas)
        try:
            font = ImageFont.truetype(self.font_path, 20)
        except OSError:
            # arial.ttf is not always present; fall back to Pillow's built-in
            # font rather than failing the whole comparison.
            font = ImageFont.load_default()
        draw.text((20, 0), f"size:{image1.size}\nmodel time:{prediction_time:.2f}s\n{extra_info}", fill=(0, 0, 0), font=font)
        return canvas

    def read_image_from_url(self, url, timeout=REQUEST_TIMEOUT):
        """Download ``url`` and return it as an RGB PIL image.

        Raises ``requests.HTTPError`` on a non-2xx response.
        """
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")

    def read_image_from_file(self, file_name):
        """Open ``file_name`` and return it as an RGB PIL image."""
        return Image.open(file_name).convert("RGB")

    def read_image_form_bytes(self, image_bytes):
        """Decode encoded image bytes into a PIL image.

        Unlike the other readers, the image keeps its original mode (no RGB
        conversion). The misspelled name is kept for existing callers.
        """
        return Image.open(io.BytesIO(image_bytes))

    def image_to_bytes(self, image, format="JPEG"):
        """Encode ``image`` (converted to RGB first) in ``format`` and return the bytes."""
        buffer = io.BytesIO()
        image.convert('RGB').save(buffer, format=format)
        return buffer.getvalue()

    @classmethod
    def create_instance_from_model_url(cls, url, device=None):
        """Download weights from ``url`` (via gdown) and build an instance."""
        model_bytes = cls.download_model_from_url(url)
        return cls(model_bytes, device)

    @classmethod
    def create_instance_from_model_file(cls, file_path, device=None):
        """Read weights from ``file_path`` and build an instance."""
        with open(file_path, 'rb') as f:
            model_bytes = f.read()
        return cls(model_bytes, device)

    @classmethod
    def download_model_from_url(cls, url):
        """Download the file at ``url`` with gdown into memory and return its bytes."""
        with io.BytesIO() as file:
            gdown.download(url, file, quiet=False, fuzzy=True)
            return file.getvalue()


def show_image(url: str):
    """Download ``url`` and return it as a PIL image in its original mode.

    Raises ``requests.HTTPError`` on a non-2xx response.
    """
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


@functools.lru_cache(maxsize=1)
def _get_model():
    """Load the background remover from model_weights.pth once and reuse it.

    Exceptions are not cached, so a failed load is retried on the next call.
    """
    return BackgroundRemover.create_instance_from_model_file("model_weights.pth")


def do_predictions(url):
    """Gradio handler: fetch the image at ``url`` and remove its background.

    Returns ``(original_image, white_background_prediction)``.
    """
    img = show_image(url)
    # The model is fed RGB input, matching the class's own image readers.
    out = _get_model().predict(img.convert("RGB"), comparison=False)
    return img, out


iface = gr.Interface(
    fn=do_predictions,
    inputs="text",
    examples=[
        ["https://http2.mlstatic.com/D_NQ_NP_2X_823376-MLU29226703936_012019-F.webp"],
        ["https://http2.mlstatic.com/D_781350-MLA53584851929_022023-F.jpg"],
    ],
    outputs=["image", "image"],
)
iface.launch()