Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -6,13 +6,14 @@ from torch.autograd import Variable
|
|
6 |
import numpy as np
|
7 |
from huggingface_hub import hf_hub_download
|
8 |
import gradio as gr
|
|
|
9 |
import ezdxf
|
10 |
-
from PIL import Image, UnidentifiedImageError
|
11 |
-
import xml.etree.ElementTree as ET
|
12 |
|
|
|
13 |
def normPRED(d):
|
14 |
return (d - torch.min(d)) / (torch.max(d) - torch.min(d))
|
15 |
|
|
|
16 |
def inference(net, input_img):
|
17 |
input_img = input_img / np.max(input_img)
|
18 |
tmpImg = np.zeros((input_img.shape[0], input_img.shape[1], 3))
|
@@ -21,115 +22,38 @@ def inference(net, input_img):
|
|
21 |
tmpImg[:, :, 2] = (input_img[:, :, 0] - 0.485) / 0.229
|
22 |
tmpImg = torch.from_numpy(tmpImg.transpose((2, 0, 1))[np.newaxis, :, :, :]).type(torch.FloatTensor)
|
23 |
tmpImg = Variable(tmpImg.cuda() if torch.cuda.is_available() else tmpImg)
|
24 |
-
d1,
|
25 |
pred = normPRED(1.0 - d1[:, 0, :, :])
|
26 |
return pred.cpu().data.numpy().squeeze()
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
Returns:
|
36 |
-
list: Danh sách các đường nét (contours) được trích xuất.
|
37 |
-
"""
|
38 |
-
# Tìm các đường nét trong ảnh
|
39 |
-
contours, _ = cv2.findContours(portrait_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
40 |
-
|
41 |
-
# Làm sạch các đường nét
|
42 |
-
contours = [cnt for cnt in contours if cv2.contourArea(cnt) > 100] # Loại bỏ các đường nét nhỏ
|
43 |
-
contours = [cv2.approxPolyDP(cnt, 0.01 * cv2.arcLength(cnt, True), True) for cnt in contours] # Làm trơn đường nét
|
44 |
-
|
45 |
-
print(f"Number of contours: {len(contours)}")
|
46 |
-
for i, contour in enumerate(contours):
|
47 |
-
print(f"Contour {i}: {contour.shape}")
|
48 |
|
49 |
-
|
50 |
-
|
51 |
-
def convert_to_dxf(contours, filename="output.dxf"):
|
52 |
-
"""
|
53 |
-
Tạo file DXF từ các đường nét (contours).
|
54 |
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
|
59 |
-
|
60 |
-
|
61 |
-
"""
|
62 |
-
doc = ezdxf.new(dxfversion="R2010")
|
63 |
msp = doc.modelspace()
|
64 |
-
|
65 |
-
# Thêm tất cả các đường nét vào file DXF
|
66 |
for contour in contours:
|
67 |
-
points =
|
68 |
-
msp.
|
69 |
-
|
70 |
-
doc.saveas(filename)
|
71 |
-
return filename
|
72 |
-
|
73 |
-
def convert_to_svg(contours, filename="output.svg"):
|
74 |
-
"""
|
75 |
-
Tạo file SVG từ các đường nét (contours).
|
76 |
-
|
77 |
-
Parameters:
|
78 |
-
contours (list): Danh sách các đường nét.
|
79 |
-
filename (str): Tên file SVG.
|
80 |
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
root = ET.Element("svg", {
|
85 |
-
"xmlns": "http://www.w3.org/2000/svg",
|
86 |
-
"width": "100%",
|
87 |
-
"height": "100%",
|
88 |
-
"viewBox": "0 0 100 100"
|
89 |
-
})
|
90 |
-
|
91 |
-
for contour in contours:
|
92 |
-
if len(contour) > 1:
|
93 |
-
path_data = "M"
|
94 |
-
for point in contour:
|
95 |
-
path_data += f" {point[0]/100} {point[1]/100}"
|
96 |
-
path_data += " Z"
|
97 |
-
path = ET.SubElement(root, "path", {
|
98 |
-
"d": path_data,
|
99 |
-
"fill": "none",
|
100 |
-
"stroke": "black",
|
101 |
-
"stroke-width": "0.5"
|
102 |
-
})
|
103 |
-
elif len(contour) == 1:
|
104 |
-
# Handle the case where the contour has only one point
|
105 |
-
point = contour[0]
|
106 |
-
circle = ET.SubElement(root, "circle", {
|
107 |
-
"cx": f"{point[0][0]/100}",
|
108 |
-
"cy": f"{point[0][1]/100}",
|
109 |
-
"r": "0.5",
|
110 |
-
"fill": "black"
|
111 |
-
})
|
112 |
|
113 |
-
|
114 |
-
tree.write(filename)
|
115 |
-
return filename
|
116 |
-
|
117 |
-
def process_image(img, bw_option):
|
118 |
-
try:
|
119 |
-
img = Image.open(img).convert("RGB")
|
120 |
-
img = np.array(img)
|
121 |
-
if bw_option:
|
122 |
-
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
123 |
-
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
124 |
-
result = inference(u2net, img)
|
125 |
-
result_img = (result * 255).astype(np.uint8)
|
126 |
-
contours = extract_contours(result_img)
|
127 |
-
dxf_path = convert_to_dxf(contours, filename="output.dxf")
|
128 |
-
svg_path = convert_to_svg(contours, filename="output.svg")
|
129 |
-
return result_img, dxf_path, svg_path
|
130 |
-
except UnidentifiedImageError:
|
131 |
-
return "Error: Unable to identify the image file. Please ensure the input file is a valid image.", None, None
|
132 |
|
|
|
133 |
def load_u2net_model():
|
134 |
model_path = hf_hub_download(repo_id="Arrcttacsrks/U2net", filename="u2net_portrait.pth", use_auth_token=os.getenv("HF_TOKEN"))
|
135 |
net = U2NET(3, 1)
|
@@ -137,25 +61,21 @@ def load_u2net_model():
|
|
137 |
net.eval()
|
138 |
return net
|
139 |
|
|
|
140 |
u2net = load_u2net_model()
|
141 |
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
iface.launch(share=True)
|
159 |
-
|
160 |
-
if __name__ == "__main__":
|
161 |
-
main()
|
|
|
6 |
import numpy as np
|
7 |
from huggingface_hub import hf_hub_download
|
8 |
import gradio as gr
|
9 |
+
import math
|
10 |
import ezdxf
|
|
|
|
|
11 |
|
12 |
+
# Chuẩn hóa dự đoán
|
13 |
def normPRED(d):
|
14 |
return (d - torch.min(d)) / (torch.max(d) - torch.min(d))
|
15 |
|
16 |
+
# Hàm suy luận với U2NET
|
17 |
def inference(net, input_img):
|
18 |
input_img = input_img / np.max(input_img)
|
19 |
tmpImg = np.zeros((input_img.shape[0], input_img.shape[1], 3))
|
|
|
22 |
tmpImg[:, :, 2] = (input_img[:, :, 0] - 0.485) / 0.229
|
23 |
tmpImg = torch.from_numpy(tmpImg.transpose((2, 0, 1))[np.newaxis, :, :, :]).type(torch.FloatTensor)
|
24 |
tmpImg = Variable(tmpImg.cuda() if torch.cuda.is_available() else tmpImg)
|
25 |
+
d1, *, *, *, *, *, * = net(tmpImg)
|
26 |
pred = normPRED(1.0 - d1[:, 0, :, :])
|
27 |
return pred.cpu().data.numpy().squeeze()
|
28 |
|
29 |
+
# Hàm chính để xử lý ảnh đầu vào và trả về ảnh chân dung và DWF file
|
30 |
+
def process_image(img, bw_option):
|
31 |
+
# Chuyển đổi ảnh thành đen trắng nếu được chọn
|
32 |
+
if bw_option:
|
33 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
34 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) # Chuyển lại thành ảnh 3 kênh cho mô hình
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
+
# Chạy suy luận để tạo ảnh chân dung
|
37 |
+
result = inference(u2net, img)
|
|
|
|
|
|
|
38 |
|
39 |
+
# Phát hiện và lấy contours từ ảnh chân dung
|
40 |
+
_, threshold = cv2.threshold(np.uint8(result * 255), 0, 255, cv2.THRESH_BINARY)
|
41 |
+
contours, _ = cv2.findContours(threshold, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
42 |
|
43 |
+
# Tạo DWF file từ các contours
|
44 |
+
doc = ezdxf.new('R2010')
|
|
|
|
|
45 |
msp = doc.modelspace()
|
|
|
|
|
46 |
for contour in contours:
|
47 |
+
points = [tuple(p[0]) for p in contour]
|
48 |
+
msp.add_polyline2d(points)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
+
# Lưu DWF file
|
51 |
+
dwf_file = 'portrait_result.dxf'
|
52 |
+
doc.saveas(dwf_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
+
return (result * 255).astype(np.uint8), dwf_file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
# Tải mô hình từ Hugging Face Hub
|
57 |
def load_u2net_model():
|
58 |
model_path = hf_hub_download(repo_id="Arrcttacsrks/U2net", filename="u2net_portrait.pth", use_auth_token=os.getenv("HF_TOKEN"))
|
59 |
net = U2NET(3, 1)
|
|
|
61 |
net.eval()
|
62 |
return net
|
63 |
|
64 |
+
# Khởi tạo mô hình U2NET
|
65 |
u2net = load_u2net_model()
|
66 |
|
67 |
+
# Tạo giao diện với Gradio
|
68 |
+
iface = gr.Interface(
|
69 |
+
fn=process_image,
|
70 |
+
inputs=[
|
71 |
+
gr.Image(type="numpy", label="Upload your image"),
|
72 |
+
gr.Checkbox(label="Convert to Black & White?", value=False)
|
73 |
+
],
|
74 |
+
outputs=[
|
75 |
+
gr.Image(type="numpy", label="Portrait Result"),
|
76 |
+
gr.File(label="DWF File")
|
77 |
+
],
|
78 |
+
title="Portrait Generation with U2NET",
|
79 |
+
description="Upload an image to generate its portrait and DWF file."
|
80 |
+
)
|
81 |
+
iface.launch()
|
|
|
|
|
|
|
|
|
|