Arrcttacsrks commited on
Commit
4233a4b
·
verified ·
1 Parent(s): 9dd18ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -120
app.py CHANGED
@@ -6,13 +6,14 @@ from torch.autograd import Variable
6
  import numpy as np
7
  from huggingface_hub import hf_hub_download
8
  import gradio as gr
 
9
  import ezdxf
10
- from PIL import Image, UnidentifiedImageError
11
- import xml.etree.ElementTree as ET
12
 
 
13
  def normPRED(d):
14
  return (d - torch.min(d)) / (torch.max(d) - torch.min(d))
15
 
 
16
  def inference(net, input_img):
17
  input_img = input_img / np.max(input_img)
18
  tmpImg = np.zeros((input_img.shape[0], input_img.shape[1], 3))
@@ -21,115 +22,38 @@ def inference(net, input_img):
21
  tmpImg[:, :, 2] = (input_img[:, :, 0] - 0.485) / 0.229
22
  tmpImg = torch.from_numpy(tmpImg.transpose((2, 0, 1))[np.newaxis, :, :, :]).type(torch.FloatTensor)
23
  tmpImg = Variable(tmpImg.cuda() if torch.cuda.is_available() else tmpImg)
24
- d1, d2, d3, d4, d5, d6, d7 = net(tmpImg)
25
  pred = normPRED(1.0 - d1[:, 0, :, :])
26
  return pred.cpu().data.numpy().squeeze()
27
 
28
- def extract_contours(portrait_mask):
29
- """
30
- Trích xuất các đường nét (contours) từ ảnh chân dung.
31
-
32
- Parameters:
33
- portrait_mask (numpy.ndarray): Ảnh chân dung dạng binary (đen trắng).
34
-
35
- Returns:
36
- list: Danh sách các đường nét (contours) được trích xuất.
37
- """
38
- # Tìm các đường nét trong ảnh
39
- contours, _ = cv2.findContours(portrait_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
40
-
41
- # Làm sạch các đường nét
42
- contours = [cnt for cnt in contours if cv2.contourArea(cnt) > 100] # Loại bỏ các đường nét nhỏ
43
- contours = [cv2.approxPolyDP(cnt, 0.01 * cv2.arcLength(cnt, True), True) for cnt in contours] # Làm trơn đường nét
44
-
45
- print(f"Number of contours: {len(contours)}")
46
- for i, contour in enumerate(contours):
47
- print(f"Contour {i}: {contour.shape}")
48
 
49
- return contours
50
-
51
- def convert_to_dxf(contours, filename="output.dxf"):
52
- """
53
- Tạo file DXF từ các đường nét (contours).
54
 
55
- Parameters:
56
- contours (list): Danh sách các đường nét.
57
- filename (str): Tên file DXF.
58
 
59
- Returns:
60
- str: Đường dẫn đến file DXF.
61
- """
62
- doc = ezdxf.new(dxfversion="R2010")
63
  msp = doc.modelspace()
64
-
65
- # Thêm tất cả các đường nét vào file DXF
66
  for contour in contours:
67
- points = contour.reshape(-1, 2)
68
- msp.add_lwpolyline(points, close=True)
69
-
70
- doc.saveas(filename)
71
- return filename
72
-
73
- def convert_to_svg(contours, filename="output.svg"):
74
- """
75
- Tạo file SVG từ các đường nét (contours).
76
-
77
- Parameters:
78
- contours (list): Danh sách các đường nét.
79
- filename (str): Tên file SVG.
80
 
81
- Returns:
82
- str: Đường dẫn đến file SVG.
83
- """
84
- root = ET.Element("svg", {
85
- "xmlns": "http://www.w3.org/2000/svg",
86
- "width": "100%",
87
- "height": "100%",
88
- "viewBox": "0 0 100 100"
89
- })
90
-
91
- for contour in contours:
92
- if len(contour) > 1:
93
- path_data = "M"
94
- for point in contour:
95
- path_data += f" {point[0]/100} {point[1]/100}"
96
- path_data += " Z"
97
- path = ET.SubElement(root, "path", {
98
- "d": path_data,
99
- "fill": "none",
100
- "stroke": "black",
101
- "stroke-width": "0.5"
102
- })
103
- elif len(contour) == 1:
104
- # Handle the case where the contour has only one point
105
- point = contour[0]
106
- circle = ET.SubElement(root, "circle", {
107
- "cx": f"{point[0][0]/100}",
108
- "cy": f"{point[0][1]/100}",
109
- "r": "0.5",
110
- "fill": "black"
111
- })
112
 
113
- tree = ET.ElementTree(root)
114
- tree.write(filename)
115
- return filename
116
-
117
- def process_image(img, bw_option):
118
- try:
119
- img = Image.open(img).convert("RGB")
120
- img = np.array(img)
121
- if bw_option:
122
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
123
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
124
- result = inference(u2net, img)
125
- result_img = (result * 255).astype(np.uint8)
126
- contours = extract_contours(result_img)
127
- dxf_path = convert_to_dxf(contours, filename="output.dxf")
128
- svg_path = convert_to_svg(contours, filename="output.svg")
129
- return result_img, dxf_path, svg_path
130
- except UnidentifiedImageError:
131
- return "Error: Unable to identify the image file. Please ensure the input file is a valid image.", None, None
132
 
 
133
  def load_u2net_model():
134
  model_path = hf_hub_download(repo_id="Arrcttacsrks/U2net", filename="u2net_portrait.pth", use_auth_token=os.getenv("HF_TOKEN"))
135
  net = U2NET(3, 1)
@@ -137,25 +61,21 @@ def load_u2net_model():
137
  net.eval()
138
  return net
139
 
 
140
  u2net = load_u2net_model()
141
 
142
- def main():
143
- iface = gr.Interface(
144
- fn=process_image,
145
- inputs=[
146
- gr.Image(type="filepath", label="Upload your image"),
147
- gr.Checkbox(label="Convert to black and white?", value=False)
148
- ],
149
- outputs=[
150
- gr.Image(type="numpy", label="Portrait result"),
151
- gr.File(label="Download DXF file"),
152
- gr.File(label="Download SVG file")
153
- ],
154
- title="Create Portrait Images, DXF and SVG Files from Images",
155
- description="Upload an image to generate a portrait image, a DXF file, and an SVG file from it."
156
- )
157
-
158
- iface.launch(share=True)
159
-
160
- if __name__ == "__main__":
161
- main()
 
6
  import numpy as np
7
  from huggingface_hub import hf_hub_download
8
  import gradio as gr
9
+ import math
10
  import ezdxf
 
 
11
 
12
+ # Chuẩn hóa dự đoán
13
  def normPRED(d):
14
  return (d - torch.min(d)) / (torch.max(d) - torch.min(d))
15
 
16
+ # Hàm suy luận với U2NET
17
  def inference(net, input_img):
18
  input_img = input_img / np.max(input_img)
19
  tmpImg = np.zeros((input_img.shape[0], input_img.shape[1], 3))
 
22
  tmpImg[:, :, 2] = (input_img[:, :, 0] - 0.485) / 0.229
23
  tmpImg = torch.from_numpy(tmpImg.transpose((2, 0, 1))[np.newaxis, :, :, :]).type(torch.FloatTensor)
24
  tmpImg = Variable(tmpImg.cuda() if torch.cuda.is_available() else tmpImg)
25
+ d1, *, *, *, *, *, * = net(tmpImg)
26
  pred = normPRED(1.0 - d1[:, 0, :, :])
27
  return pred.cpu().data.numpy().squeeze()
28
 
29
+ # Hàm chính để xử lý ảnh đầu vào và trả về ảnh chân dung và DWF file
30
+ def process_image(img, bw_option):
31
+ # Chuyển đổi ảnh thành đen trắng nếu được chọn
32
+ if bw_option:
33
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
34
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) # Chuyển lại thành ảnh 3 kênh cho mô hình
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
+ # Chạy suy luận để tạo ảnh chân dung
37
+ result = inference(u2net, img)
 
 
 
38
 
39
+ # Phát hiện và lấy contours từ ảnh chân dung
40
+ _, threshold = cv2.threshold(np.uint8(result * 255), 0, 255, cv2.THRESH_BINARY)
41
+ contours, _ = cv2.findContours(threshold, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
42
 
43
+ # Tạo DWF file từ các contours
44
+ doc = ezdxf.new('R2010')
 
 
45
  msp = doc.modelspace()
 
 
46
  for contour in contours:
47
+ points = [tuple(p[0]) for p in contour]
48
+ msp.add_polyline2d(points)
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+ # Lưu DWF file
51
+ dwf_file = 'portrait_result.dxf'
52
+ doc.saveas(dwf_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ return (result * 255).astype(np.uint8), dwf_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ # Tải mô hình từ Hugging Face Hub
57
  def load_u2net_model():
58
  model_path = hf_hub_download(repo_id="Arrcttacsrks/U2net", filename="u2net_portrait.pth", use_auth_token=os.getenv("HF_TOKEN"))
59
  net = U2NET(3, 1)
 
61
  net.eval()
62
  return net
63
 
64
+ # Khởi tạo mô hình U2NET
65
  u2net = load_u2net_model()
66
 
67
+ # Tạo giao diện với Gradio
68
+ iface = gr.Interface(
69
+ fn=process_image,
70
+ inputs=[
71
+ gr.Image(type="numpy", label="Upload your image"),
72
+ gr.Checkbox(label="Convert to Black & White?", value=False)
73
+ ],
74
+ outputs=[
75
+ gr.Image(type="numpy", label="Portrait Result"),
76
+ gr.File(label="DWF File")
77
+ ],
78
+ title="Portrait Generation with U2NET",
79
+ description="Upload an image to generate its portrait and DWF file."
80
+ )
81
+ iface.launch()