background / app.py
Mrmusculo's picture
Update app.py
7844023
raw
history blame
No virus
6.2 kB
import pickle
import pandas as p
import numpy as np
import requests
import io
import os
import cv2
import gdown
import tempfile
from PIL import Image, ImageDraw, ImageFont
import PIL
from transparent_background import Remover
import torch
import torch.nn.functional as F
import time
import gradio as gr
from PIL import Image
import requests
from io import BytesIO
from torchvision import datasets, models, transforms
class BackgroundRemover(Remover):
    """Background remover built on ``transparent_background.Remover``.

    The model weights are supplied as raw bytes, written to a temporary
    ``.pth`` file that the parent class loads as its checkpoint, and removed
    again when the instance is finalized. Besides prediction, the class offers
    helpers to build a side-by-side comparison image and to read/write images
    from URLs, files and bytes.

    NOTE(review): checkpoint files are deserialized by the parent class
    (presumably via ``torch.load``, i.e. pickle) — only load weights from
    trusted sources.
    """

    def __init__(self, model_bytes, device=None):
        """Create a remover from in-memory model weights.

        Args:
            model_bytes: model weights as bytes (originally downloaded from
                "https://drive.google.com/file/d/13oBl5MTVcWER3YU4fSxW3ATlVfueFQPY/view?usp=share_link").
            device: device used for computation; ``None`` keeps the parent
                class default (cuda:0 if available).
        """
        self.model_path = None
        # The parent class loads its checkpoint from a path, so persist the
        # bytes to disk. delete=False keeps the file after the `with` closes
        # it; __del__ is responsible for removing it.
        with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as tmp_file:
            tmp_file.write(model_bytes)
            self.model_path = tmp_file.name
        # The font is resolved relative to the current working directory. The
        # script-relative lookup (os.path.abspath(__file__)) was disabled,
        # presumably so the class also works where __file__ is undefined
        # (e.g. notebooks) — confirm.
        self.font_path = "arial.ttf"
        super().__init__(fast=False, jit=False, device=device, ckpt=self.model_path)

    def __del__(self):
        """Remove the temporary checkpoint file written by ``__init__``."""
        # getattr: __init__ may not have run (e.g. instance built via __new__).
        model_path = getattr(self, "model_path", None)
        if model_path is not None and os.path.exists(model_path):
            try:
                os.remove(model_path)
            except OSError:
                # Best-effort cleanup: a finalizer must never raise.
                pass

    def download(self):
        """No-op override; the weights are provided to ``__init__`` as bytes.

        NOTE(review): presumably this stops the parent ``Remover`` from fetching
        its own default checkpoint — confirm against transparent_background.
        """
        pass

    def predict(self, image, comparison=False, extra=""):
        """Remove the background of ``image``.

        Args:
            image: input PIL image.
            comparison: when True, return a side-by-side image (input | result)
                captioned with the input size and model time instead of only
                the result.
            extra: additional caption text used when ``comparison`` is True.

        Returns:
            PIL.Image.Image: RGB result with the background painted white, or
            the comparison image when ``comparison`` is True.
        """
        start = time.perf_counter()  # monotonic clock, suited to durations
        prediction = self.raw_predict(image)
        elapsed = time.perf_counter() - start
        if not comparison:
            return prediction
        return self.compare(image, prediction, elapsed, extra)

    def raw_predict(self, image, empty_cache_after_prediction=False):
        """Run the model and flatten its transparent output onto white.

        Args:
            image: input PIL image passed to ``Remover.process``.
            empty_cache_after_prediction: when True and the model runs on a
                CUDA device, release cached GPU memory afterwards.

        Returns:
            PIL.Image.Image: RGB image the size of the model output, with
            transparent pixels replaced by white.
        """
        out = self.process(image)
        # Depending on the library version, process() may return an array or
        # already a PIL image.
        prediction = out if isinstance(out, Image.Image) else Image.fromarray(out)
        if prediction.mode != "RGBA":
            # Guarantee an alpha channel to use as the paste mask below.
            prediction = prediction.convert("RGBA")
        # White RGB canvas the same size as the prediction; pasting through the
        # alpha mask replaces transparent pixels with white.
        flattened = Image.new("RGB", prediction.size, (255, 255, 255))
        flattened.paste(prediction, mask=prediction.getchannel("A"))
        # str(): self.device may be a torch.device rather than a string.
        if empty_cache_after_prediction and "cuda" in str(self.device):
            torch.cuda.empty_cache()
        return flattened

    def compare(self, image1, image2, prediction_time, extra_info=""):
        """Place two images side by side under a text caption.

        Args:
            image1: left image (the input); its size is printed in the caption.
            image2: right image (the prediction).
            prediction_time: model time in seconds, shown with 2 decimals.
            extra_info: extra caption line.

        Returns:
            PIL.Image.Image: RGB image on a white background.
        """
        header_height = 80  # pixels reserved above the images for the caption
        canvas = Image.new(
            "RGB",
            (image1.width + image2.width, max(image1.height, image2.height) + header_height),
            (255, 255, 255),
        )
        canvas.paste(image1, (0, header_height))
        canvas.paste(image2, (image1.width, header_height))
        draw = ImageDraw.Draw(canvas)
        try:
            font = ImageFont.truetype(self.font_path, 20)
        except OSError:
            # arial.ttf is looked up relative to the CWD and may be missing.
            font = ImageFont.load_default()
        draw.text(
            (20, 0),
            f"size:{image1.size}\nmodel time:{prediction_time:.2f}s\n{extra_info}",
            fill=(0, 0, 0),
            font=font,
        )
        return canvas

    def read_image_from_url(self, url, timeout=None):
        """Download an image and return it converted to RGB.

        Args:
            url: image address.
            timeout: optional ``requests`` timeout in seconds (None = wait
                indefinitely, the previous behaviour).

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")

    def read_image_from_file(self, file_name):
        """Open an image file and return it converted to RGB."""
        return Image.open(file_name).convert("RGB")

    def read_image_form_bytes(self, image_bytes):
        """Decode encoded image bytes into a PIL image.

        Unlike the other readers, the image mode is preserved (no RGB
        conversion).
        """
        return Image.open(io.BytesIO(image_bytes))

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    read_image_from_bytes = read_image_form_bytes

    def image_to_bytes(self, image, format="JPEG"):
        """Encode ``image`` (converted to RGB) with Pillow's ``format`` encoder."""
        buffer = io.BytesIO()
        image.convert("RGB").save(buffer, format=format)
        return buffer.getvalue()

    @classmethod
    def create_instance_from_model_url(cls, url, device=None):
        """Download weights from a (Google Drive) URL and build an instance."""
        model_bytes = cls.download_model_from_url(url)
        return cls(model_bytes, device)

    @classmethod
    def create_instance_from_model_file(cls, file_path, device=None):
        """Read weights from a local file and build an instance."""
        with open(file_path, "rb") as f:
            model_bytes = f.read()
        return cls(model_bytes, device)

    @classmethod
    def download_model_from_url(cls, url):
        """Download a file with ``gdown`` into memory and return its bytes."""
        with io.BytesIO() as buffer:
            # fuzzy=True lets gdown extract the file id from share links.
            gdown.download(url, buffer, quiet=False, fuzzy=True)
            return buffer.getvalue()
def show_image(url: str):
    """Fetch the resource at ``url`` and open it as a PIL image.

    The image mode is left unchanged.
    """
    payload = requests.get(url).content
    return Image.open(BytesIO(payload))
# Seconds to wait for the image server before giving up.
_IMAGE_REQUEST_TIMEOUT = 30

# Model shared by every request; created lazily by _get_transform_model().
_transform_model = None


def _get_transform_model():
    """Return the background-removal model, loading the weights on first use only."""
    global _transform_model
    if _transform_model is None:
        # device=None keeps the BackgroundRemover default device.
        _transform_model = BackgroundRemover.create_instance_from_model_file("model_weights.pth")
    return _transform_model


def do_predictions(url):
    """Gradio handler: fetch the image at ``url`` and remove its background.

    Args:
        url: address of the input image.

    Returns:
        tuple: ``(input_image, result_image)`` — the downloaded image converted
        to RGB, and the RGB result with the background painted white.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within
            ``_IMAGE_REQUEST_TIMEOUT`` seconds.
    """
    response = requests.get(url, timeout=_IMAGE_REQUEST_TIMEOUT)
    response.raise_for_status()
    # Convert to RGB like BackgroundRemover.read_image_from_url does; presumably
    # the model expects 3-channel input, and RGBA/palette images (PNG, WebP)
    # would otherwise reach it with a different channel count.
    img = Image.open(BytesIO(response.content)).convert("RGB")
    out = _get_transform_model().predict(img, comparison=False)
    return img, out
# Gradio UI: a text box for an image URL, showing the input and the
# background-removed result side by side.
iface = gr.Interface(
    fn=do_predictions,
    inputs="text",
    examples=[
        ["https://http2.mlstatic.com/D_NQ_NP_2X_823376-MLU29226703936_012019-F.webp"],
        ["https://http2.mlstatic.com/D_781350-MLA53584851929_022023-F.jpg"],
    ],
    outputs=["image", "image"],
)

# Launch only when run as a script, so importing this module (e.g. for tests)
# does not start a blocking web server.
if __name__ == "__main__":
    iface.launch()