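# NOTE: the lines below this header were scraped from a Hugging Face Spaces page.
# The page header read "Spaces:" and showed the status "Runtime error" (twice).
# It is kept here as a comment so the file parses.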
import pickle | |
import pandas as p | |
import numpy as np | |
import requests | |
import io | |
import os | |
import cv2 | |
import gdown | |
import tempfile | |
from PIL import Image, ImageDraw, ImageFont | |
import PIL | |
from transparent_background import Remover | |
import torch | |
import torch.nn.functional as F | |
import time | |
import gradio as gr | |
from PIL import Image | |
import requests | |
from io import BytesIO | |
from torchvision import datasets, models, transforms | |
class BackgroundRemover(Remover):
    """Background remover built on ``transparent_background.Remover``.

    The model weights are supplied as raw bytes. They are written to a
    temporary ``.pth`` file that is passed to the parent as ``ckpt``.
    Predictions are composited onto a white background.
    """

    def __init__(self, model_bytes, device=None):
        """Create a remover from in-memory model weights.

        Args:
            model_bytes: model weights as bytes. They were originally downloaded from
                "https://drive.google.com/file/d/13oBl5MTVcWER3YU4fSxW3ATlVfueFQPY/view?usp=share_link".
            device: device used for computation (default: cuda:0 if available).
        """
        # Set first so __del__ can always read the attribute, even if the
        # temporary file cannot be created.
        self.model_path = None
        # The parent takes a checkpoint *path* (ckpt=...), so the bytes are
        # written to a temp file. delete=False keeps it on disk after the
        # ``with`` block; __del__ removes it.
        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as tmp_file:
            tmp_file.write(model_bytes)
            self.model_path = tmp_file.name
        # The font is resolved relative to the current working directory.
        # Resolving it next to this script (os.path.abspath(__file__)) was
        # deliberately disabled in the original code.
        self.font_path = "arial.ttf"
        super().__init__(fast=False, jit=False, device=device, ckpt=self.model_path)

    def __del__(self):
        """Delete the temporary checkpoint file written by ``__init__``."""
        # getattr: instances created without __init__ have no model_path.
        model_path = getattr(self, "model_path", None)
        if model_path is not None and os.path.exists(model_path):
            os.remove(model_path)

    def download(self):
        """Intentionally a no-op.

        NOTE(review): this presumably overrides a download hook in ``Remover``
        so the parent does not fetch its own weights. Verify this against the
        installed transparent_background version.
        """
        pass

    def predict(self, image, comparison=False, extra=""):
        """Remove the background of ``image``.

        Args:
            image: input PIL image.
            comparison: if True, return a side-by-side image (input | result)
                with a header showing the input size, model time and ``extra``.
            extra: additional text for the comparison header.

        Returns:
            A PIL RGB image with the background replaced by white. When
            ``comparison`` is True, the annotated comparison image instead.
        """
        start = time.perf_counter()
        prediction = self.raw_predict(image)
        elapsed = time.perf_counter() - start
        if comparison:
            return self.compare(image, prediction, elapsed, extra)
        return prediction

    def raw_predict(self, image, empty_cache_after_prediction=False):
        """Run the model and paste the result onto a white background.

        Args:
            image: input PIL image.
            empty_cache_after_prediction: if True and the device is CUDA,
                call ``torch.cuda.empty_cache()`` after predicting.

        Returns:
            A PIL RGB image the size of the model output.
        """
        out = self.process(image)
        # The original code fed process() output to Image.fromarray, i.e. it
        # assumed an ndarray. Accept a PIL image too, in case the installed
        # transparent_background returns one.
        prediction = out if isinstance(out, Image.Image) else Image.fromarray(out)
        # Use the 4th band (alpha) as the paste mask so transparent pixels end
        # up white. Assumes RGBA output; TODO confirm for non-default types.
        on_white = Image.new("RGB", prediction.size, (255, 255, 255))
        on_white.paste(prediction, mask=prediction.split()[3])
        # str() so the check works for a torch.device as well as a string.
        if empty_cache_after_prediction and "cuda" in str(self.device):
            torch.cuda.empty_cache()
        return on_white

    def compare(self, image1, image2, prediction_time, extra_info=""):
        """Build a side-by-side image ``image1 | image2`` below an 80 px text header.

        The header shows the size of ``image1``, ``prediction_time`` in seconds,
        and ``extra_info``.
        """
        header = 80
        canvas = Image.new(
            "RGB",
            # Use the taller of the two images so neither is clipped.
            (image1.width + image2.width, max(image1.height, image2.height) + header),
            (255, 255, 255),
        )
        canvas.paste(image1, (0, header))
        canvas.paste(image2, (image1.width, header))
        try:
            font = ImageFont.truetype(self.font_path, 20)
        except OSError:
            # arial.ttf is looked up relative to the working directory. If it
            # is missing, fall back to Pillow's built-in font instead of failing.
            font = ImageFont.load_default()
        ImageDraw.Draw(canvas).text(
            (20, 0),
            f"size:{image1.size}\nmodel time:{prediction_time:.2f}s\n{extra_info}",
            fill=(0, 0, 0),
            font=font,
        )
        return canvas

    def read_image_from_url(self, url, timeout=30):
        """Download ``url`` and return it as an RGB PIL image.

        Raises:
            requests.HTTPError: if the server returns an error status.
        """
        # A timeout prevents a stalled server from hanging the caller forever.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")

    def read_image_from_file(self, file_name):
        """Open ``file_name`` and return it as an RGB PIL image."""
        return Image.open(file_name).convert("RGB")

    def read_image_form_bytes(self, image_bytes):
        """Decode encoded image bytes into a PIL image.

        Unlike the other readers, this keeps the image's original mode
        (no RGB conversion).
        """
        return Image.open(io.BytesIO(image_bytes))

    # Correctly spelled alias; the original (typo'd) name is kept for existing callers.
    read_image_from_bytes = read_image_form_bytes

    def image_to_bytes(self, image, format="JPEG"):
        """Encode ``image`` as bytes in ``format`` (JPEG by default).

        The image is converted to RGB first, so alpha or palette images can
        be saved as JPEG.
        """
        buffer = io.BytesIO()
        image.convert("RGB").save(buffer, format=format)
        return buffer.getvalue()

    @classmethod
    def create_instance_from_model_url(cls, url, device=None):
        """Download weights from ``url`` (via gdown) and build an instance."""
        model_bytes = cls.download_model_from_url(url)
        return cls(model_bytes, device)

    @classmethod
    def create_instance_from_model_file(cls, file_path, device=None):
        """Read weights from ``file_path`` and build an instance."""
        with open(file_path, "rb") as f:
            model_bytes = f.read()
        return cls(model_bytes, device)

    @classmethod
    def download_model_from_url(cls, url):
        """Download ``url`` with gdown into memory and return its bytes.

        ``fuzzy=True`` lets gdown accept Google Drive share links.
        """
        with io.BytesIO() as buffer:
            gdown.download(url, buffer, quiet=False, fuzzy=True)
            # getvalue() returns the whole buffer regardless of the position.
            return buffer.getvalue()
def show_image(url: str, timeout: float = 30):
    """Download the image at ``url`` and return it as a PIL image.

    Args:
        url: HTTP(S) URL of the image.
        timeout: seconds to wait for the server before giving up.

    Returns:
        The decoded PIL image, in whatever mode the file has.

    Raises:
        requests.HTTPError: if the server returns an error status.
    """
    # Without a timeout, requests can wait indefinitely on a stalled server.
    response = requests.get(url, timeout=timeout)
    # Report the HTTP failure directly rather than a confusing
    # "cannot identify image file" error from Pillow on an error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content))
# Weights file read on first use by _get_background_remover().
_MODEL_WEIGHTS_PATH = "model_weights.pth"
# Process-wide model instance, created lazily.
_background_remover = None


def _get_background_remover():
    """Return the shared BackgroundRemover, loading it on first call.

    Loading reads the whole weights file and builds the model. Doing it once
    per process avoids repeating that work for every request.
    """
    global _background_remover
    if _background_remover is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        with open(_MODEL_WEIGHTS_PATH, "rb") as f:
            model_bytes = f.read()
        _background_remover = BackgroundRemover(model_bytes, device=device)
    return _background_remover


def do_predictions(url):
    """Fetch the image at ``url`` and remove its background.

    Args:
        url: HTTP(S) URL of the input image (the Gradio text input).

    Returns:
        ``(input_image, result_image)``: the downloaded image converted to
        RGB, and the background-removed result on a white background.

    Raises:
        requests.HTTPError: if the image download fails.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # Normalise to 3-channel RGB, as BackgroundRemover.read_image_from_url
    # does. Otherwise WebP/PNG inputs with alpha or a palette would reach the
    # model unchanged; presumably it expects RGB (TODO confirm).
    img = Image.open(BytesIO(response.content)).convert("RGB")
    out = _get_background_remover().predict(img, comparison=False)
    return img, out
# Example inputs for the interface. Each inner list is one example row for
# the single text input.
_EXAMPLE_URLS = [
    "https://http2.mlstatic.com/D_NQ_NP_2X_823376-MLU29226703936_012019-F.webp",
    "https://http2.mlstatic.com/D_781350-MLA53584851929_022023-F.jpg",
]

# Input: one text box (an image URL). Output: two images, the downloaded
# input and the background-removed result, in the order do_predictions
# returns them.
iface = gr.Interface(
    fn=do_predictions,
    inputs="text",
    outputs=["image", "image"],
    examples=[[example_url] for example_url in _EXAMPLE_URLS],
)
iface.launch()