Mrmusculo committed on
Commit
3a0bfe5
1 Parent(s): 08c4f37

Create app.py

Files changed (1)
  1. app.py +220 -0
app.py ADDED
@@ -0,0 +1,220 @@
+ import requests
+ import io
+ import os
+ import gdown
+ import tempfile
+ import time
+
+ from io import BytesIO
+ from PIL import Image, ImageDraw, ImageFont
+
+ from transparent_background import Remover
+
+ import torch
+ import torch.nn as nn  # classifier head for the white-background detector
+ from torchvision import models, transforms  # ResNet-50 architecture and preprocessing
+
+ import gradio as gr
+
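+ # Local files this app expects alongside app.py (all referenced below):
+ #   arial.ttf                                                   -- font for the comparison overlay
+ #   model_weights.pth                                           -- transparent_background Remover checkpoint
+ #   white_background_detection/resnet50_finetuned_weights.pth   -- white-background detector weights
+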
+ class BackgroundRemover(Remover):
+     def __init__(self, model_bytes, device=None):
+         """
+         model_bytes: model weights as bytes (downloaded from "https://drive.google.com/file/d/13oBl5MTVcWER3YU4fSxW3ATlVfueFQPY/view?usp=share_link")
+         device: device used for computation (default: cuda:0 if available)
+         """
+
+         # Write the weights to a temporary .pth file so the parent Remover can load them as a checkpoint
+         self.model_path = None
+         with tempfile.NamedTemporaryFile(suffix=".pth", delete=False) as tmp_file:
+             tmp_file.write(model_bytes)
+             self.model_path = tmp_file.name
+
+         # get the path of the script that defines this class
+         script_path = os.path.abspath(__file__)
+
+         # construct the path to the arial.ttf file relative to the script location
+         font_path = os.path.join(os.path.dirname(script_path), "arial.ttf")
+         self.font_path = font_path
+
+         super().__init__(fast=False, jit=False, device=device, ckpt=self.model_path)
+
+     def __del__(self):
+         # Delete the temporary checkpoint file when the instance is garbage-collected
+         if self.model_path is not None and os.path.exists(self.model_path):
+             os.remove(self.model_path)
+
+     def download(self):
+         # Skip the parent class download step: the weights are supplied directly as bytes
+         pass
+
+     def predict(self, image, comparison=False, extra=""):
+         s = time.time()
+         prediction = self.raw_predict(image)
+         e = time.time()
+         #print(f"predict time {e-s:.4f}")
+
+         if not comparison:
+             return prediction
+         else:
+             # Return the input and the prediction side by side with timing info
+             return self.compare(image, prediction, e - s, extra)
+
+     def raw_predict(self, image, empty_cache_after_prediction=False):
+         t1 = time.time()
+         out = self.process(image)
+         t2 = time.time()
+
+         prediction = Image.fromarray(out)
+
+         # Create a new RGB image with a white background, the same size as the original
+         new_image = Image.new("RGB", prediction.size, (255, 255, 255))
+
+         # Combine the two images, replacing transparent pixels with white
+         new_image.paste(prediction, mask=prediction.split()[3])
+
+         t3 = time.time()
+
+         if empty_cache_after_prediction and "cuda" in self.device:
+             torch.cuda.empty_cache()
+
+         t4 = time.time()
+
+         #print(f"{(t2-t1)*1000:.4f} {(t3-t2)*1000:.4f} {(t4-t3)*1000:.4f}")
+
+         return new_image
+
+     def compare(self, image1, image2, prediction_time, extra_info=""):
+         # Extra vertical space at the top for the text overlay
+         extra = 80
+
+         concatenated_image = Image.new('RGB', (image1.width + image2.width, image1.height + extra), (255, 255, 255))
+         concatenated_image.paste(image1, (0, extra))
+         concatenated_image.paste(image2, (image1.width, extra))
+
+         draw = ImageDraw.Draw(concatenated_image)
+
+         font = ImageFont.truetype(self.font_path, 20)
+         draw.text((20, 0), f"size:{image1.size}\nmodel time:{prediction_time:.2f}s\n{extra_info}", fill=(0, 0, 0), font=font)
+
+         return concatenated_image
+
+     def read_image_from_url(self, url):
+         response = requests.get(url)
+         image = Image.open(io.BytesIO(response.content)).convert("RGB")
+         return image
+
+     def read_image_from_file(self, file_name):
+         image = Image.open(file_name).convert("RGB")
+         return image
+
+     def read_image_from_bytes(self, image_bytes):
+         # Convert the raw bytes into a PIL image
+         image = Image.open(io.BytesIO(image_bytes))
+         return image
+
+     def image_to_bytes(self, image, format="JPEG"):
+         image_bytes = io.BytesIO()
+         image_rgb = image.convert('RGB')
+         image_rgb.save(image_bytes, format=format)
+         return image_bytes.getvalue()
+
+     @classmethod
+     def create_instance_from_model_url(cls, url):
+         model_bytes = cls.download_model_from_url(url)
+         return cls(model_bytes)
+
+     @classmethod
+     def create_instance_from_model_file(cls, file_path, device=None):
+         with open(file_path, 'rb') as f:
+             model_bytes = f.read()
+         return cls(model_bytes, device)
+
+     @classmethod
+     def download_model_from_url(cls, url):
+         with io.BytesIO() as file:
+             gdown.download(url, file, quiet=False, fuzzy=True)
+
+             # Get the contents of the file as bytes
+             file.seek(0)
+             model_bytes = file.read()
+
+         return model_bytes
+
+
+ def show_image(url: str):
+     # Fetch an image from a URL and return it as a PIL image
+     response = requests.get(url)
+     img = Image.open(BytesIO(response.content))
+     return img
+
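+ # NOTE: predict_single_image_detection is called in do_predictions below but was not defined in
+ # this file. The implementation that follows is a minimal sketch under stated assumptions (a
+ # two-class ResNet-50 detector scored with softmax, class index order ["real", "white"]), not the
+ # original training code; adjust the label mapping to match how the detector was actually trained.
+ def predict_single_image_detection(image, model, transform, device):
+     class_names = ["real", "white"]  # assumed label order (e.g., an alphabetically ordered ImageFolder)
+
+     start = time.time()
+     input_tensor = transform(image.convert("RGB")).unsqueeze(0).to(device)  # add a batch dimension
+     with torch.no_grad():
+         logits = model(input_tensor)
+         probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
+     inference_time = time.time() - start
+
+     predicted_index = int(torch.argmax(probabilities).item())
+     prediction = class_names[predicted_index]
+     predicted_probability = float(probabilities[predicted_index].item())
+
+     return prediction, predicted_probability, inference_time
+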
+ def do_predictions(url):
+     response = requests.get(url)
+     img = Image.open(BytesIO(response.content)).convert("RGB")
+
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+     # Background-removal model, loaded from the local checkpoint
+     transform_model = BackgroundRemover.create_instance_from_model_file("model_weights.pth")
+
+     # Set up data transformations
+     data_transforms = {
+         'train': transforms.Compose([
+             #transforms.Resize(512), #256
+             #transforms.CenterCrop(480), # 224
+             #transforms.Resize((256, 256)),
+             transforms.Resize((384, 384)),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+         ]),
+         'val': transforms.Compose([
+             #transforms.Resize(512), #256
+             #transforms.CenterCrop(480), # 224
+             #transforms.Resize((256, 256)),
+             transforms.Resize((384, 384)),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+         ]),
+     }
+
+     # Build a model with the same architecture as the trained detector
+     detect_model = models.resnet50(weights=None)  # 'weights' replaces the deprecated 'pretrained' argument
+     num_ftrs = detect_model.fc.in_features
+     num_classes = 2
+     detect_model.fc = nn.Linear(num_ftrs, num_classes)
+     detect_model = detect_model.to(device)
+
+     # Load the saved weights
+     model_weights_path = 'white_background_detection/resnet50_finetuned_weights.pth'
+     detect_model.load_state_dict(torch.load(model_weights_path, map_location=device))
+
+     # Switch the model to evaluation mode
+     detect_model.eval()
+
+     prediction, predicted_probability, inference_time = predict_single_image_detection(img, detect_model, data_transforms['val'], device)
+
+     if prediction == "real":
+         # Only run background removal when the detector predicts "real"
+         out = transform_model.predict(img, comparison=False)
+         return prediction, predicted_probability, img, out
+     else:
+         return prediction, predicted_probability, img, None
+
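+ # Note: do_predictions rebuilds and reloads both the background-removal model and the ResNet-50
+ # detector on every request; for a deployed app it would likely be faster to create them once at
+ # module scope and reuse them inside the function.
+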
+ iface = gr.Interface(fn=do_predictions, inputs="text", outputs=["text", "text", "image", "image"],
+                      examples=[["https://http2.mlstatic.com/D_NQ_NP_2X_823376-MLU29226703936_012019-F.webp"],
+                                ["https://http2.mlstatic.com/D_781350-MLA53584851929_022023-F.jpg"]])
+ #iface.outputs[0].set_title("Prediction")
+ #iface.outputs[1].set_title("Class")
+ #iface.outputs[2].set_title("Probability")
+ iface.launch(share=True)