Spaces:
Runtime error
Runtime error
Tsumugii24
commited on
Commit
•
95be552
1
Parent(s):
9a27fc0
add model auto downloads
Browse files
app.py
CHANGED
@@ -16,8 +16,9 @@ from ultralytics import YOLO
|
|
16 |
ROOT_PATH = sys.path[0] # 项目根目录
|
17 |
|
18 |
fonts_list = ["SimSun.ttf", "TimesNewRoman.ttf", "malgun.ttf"] # 字体列表
|
|
|
19 |
fonts_directory_path = Path(ROOT_PATH, "fonts")
|
20 |
-
models_directory_path = Path(ROOT_PATH) # 模型存放在项目的根目录
|
21 |
|
22 |
data_url_dict = {
|
23 |
"SimSun.ttf": "https://raw.githubusercontent.com/Tsumugii24/Typora-images/main/files/SimSun.ttf",
|
@@ -37,9 +38,9 @@ model_url_dict = {
|
|
37 |
def is_fonts(fonts_dir):
|
38 |
if fonts_dir.is_dir():
|
39 |
# 如果本地字体库存在
|
40 |
-
|
41 |
|
42 |
-
font_diff = list(set(fonts_list).difference(set(
|
43 |
|
44 |
if font_diff != []:
|
45 |
# 缺失字体
|
@@ -55,19 +56,19 @@ def is_fonts(fonts_dir):
|
|
55 |
def is_models(models_dir):
|
56 |
if models_dir.is_dir():
|
57 |
# 如果本地模型库存在
|
58 |
-
|
59 |
|
60 |
-
model_diff = list(set(
|
61 |
|
62 |
if model_diff != []:
|
63 |
# 缺失模型
|
64 |
download_models(model_diff) # 下载缺失的模型
|
65 |
else:
|
66 |
-
print(f"{
|
67 |
else:
|
68 |
# 本地模型库不存在,创建模型库
|
69 |
print("[bold red]Local models library does not exist, creating now...[/bold red]")
|
70 |
-
download_models(
|
71 |
|
72 |
# 下载字体
|
73 |
def download_fonts(font_diff):
|
@@ -78,9 +79,9 @@ def download_fonts(font_diff):
|
|
78 |
font_name = v.split("/")[-1] # 字体名称
|
79 |
fonts_directory_path.mkdir(parents=True, exist_ok=True) # 创建目录
|
80 |
|
81 |
-
|
82 |
# 下载字体文件
|
83 |
-
wget.download(v,
|
84 |
|
85 |
# 下载模型
|
86 |
def download_models(model_diff):
|
@@ -90,9 +91,9 @@ def download_models(model_diff):
|
|
90 |
v = model_url_dict[k]
|
91 |
model_name = v.split("/")[-1] # 模型名称
|
92 |
|
93 |
-
|
94 |
# 下载模型文件
|
95 |
-
wget.download(v,
|
96 |
|
97 |
|
98 |
is_fonts(fonts_directory_path)
|
@@ -336,7 +337,7 @@ def seg_output(img_path, seg_mask_list, color_list, cls_list):
|
|
336 |
|
337 |
|
338 |
# 目标检测和图像分割模型加载
|
339 |
-
def model_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_model="
|
340 |
model = YOLO(yolo_model)
|
341 |
|
342 |
results = model(source=img_path, device=device_opt, imgsz=infer_size, conf=conf, iou=iou, max_det=max_det)
|
@@ -358,7 +359,7 @@ def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_de
|
|
358 |
cls_index_det_stat = [] # 1
|
359 |
|
360 |
# 模型加载
|
361 |
-
predict_results = model_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_model=f"{model_name}.pt")
|
362 |
# 检测参数
|
363 |
xyxy_list = predict_results.boxes.xyxy.cpu().numpy().tolist()
|
364 |
conf_list = predict_results.boxes.conf.cpu().numpy().tolist()
|
|
|
16 |
ROOT_PATH = sys.path[0] # 项目根目录
|
17 |
|
18 |
fonts_list = ["SimSun.ttf", "TimesNewRoman.ttf", "malgun.ttf"] # 字体列表
|
19 |
+
models_list = ["cnn_se.pt", "detr_based.pt", "vit_based.pt", "yolov5_based.pt", "yolov8_based.pt"] # 模型列表
|
20 |
fonts_directory_path = Path(ROOT_PATH, "fonts")
|
21 |
+
models_directory_path = Path(ROOT_PATH, "models") # 模型存放在项目的根目录
|
22 |
|
23 |
data_url_dict = {
|
24 |
"SimSun.ttf": "https://raw.githubusercontent.com/Tsumugii24/Typora-images/main/files/SimSun.ttf",
|
|
|
38 |
def is_fonts(fonts_dir):
|
39 |
if fonts_dir.is_dir():
|
40 |
# 如果本地字体库存在
|
41 |
+
local_font_list = os.listdir(fonts_dir) # 本地字体库
|
42 |
|
43 |
+
font_diff = list(set(fonts_list).difference(set(local_font_list)))
|
44 |
|
45 |
if font_diff != []:
|
46 |
# 缺失字体
|
|
|
56 |
def is_models(models_dir):
|
57 |
if models_dir.is_dir():
|
58 |
# 如果本地模型库存在
|
59 |
+
local_model_list = os.listdir(models_dir) # 本地模型库
|
60 |
|
61 |
+
model_diff = list(set(models_list()).difference(set(local_model_list)))
|
62 |
|
63 |
if model_diff != []:
|
64 |
# 缺失模型
|
65 |
download_models(model_diff) # 下载缺失的模型
|
66 |
else:
|
67 |
+
print(f"{models_list}[bold green]Required models already downloaded![/bold green]")
|
68 |
else:
|
69 |
# 本地模型库不存在,创建模型库
|
70 |
print("[bold red]Local models library does not exist, creating now...[/bold red]")
|
71 |
+
download_models(models_list) # 创建模型库
|
72 |
|
73 |
# 下载字体
|
74 |
def download_fonts(font_diff):
|
|
|
79 |
font_name = v.split("/")[-1] # 字体名称
|
80 |
fonts_directory_path.mkdir(parents=True, exist_ok=True) # 创建目录
|
81 |
|
82 |
+
font_file_path = f"{ROOT_PATH}/fonts/{font_name}" # 字体路径
|
83 |
# 下载字体文件
|
84 |
+
wget.download(v, font_file_path)
|
85 |
|
86 |
# 下载模型
|
87 |
def download_models(model_diff):
|
|
|
91 |
v = model_url_dict[k]
|
92 |
model_name = v.split("/")[-1] # 模型名称
|
93 |
|
94 |
+
model_file_path = f"{ROOT_PATH}/models/{model_name}" # 模型路径
|
95 |
# 下载模型文件
|
96 |
+
wget.download(v, model_file_path)
|
97 |
|
98 |
|
99 |
is_fonts(fonts_directory_path)
|
|
|
337 |
|
338 |
|
339 |
# 目标检测和图像分割模型加载
|
340 |
+
def model_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_model="yolov8_based.pt"):
|
341 |
model = YOLO(yolo_model)
|
342 |
|
343 |
results = model(source=img_path, device=device_opt, imgsz=infer_size, conf=conf, iou=iou, max_det=max_det)
|
|
|
359 |
cls_index_det_stat = [] # 1
|
360 |
|
361 |
# 模型加载
|
362 |
+
predict_results = model_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_model=f"models/{model_name}.pt")
|
363 |
# 检测参数
|
364 |
xyxy_list = predict_results.boxes.xyxy.cpu().numpy().tolist()
|
365 |
conf_list = predict_results.boxes.conf.cpu().numpy().tolist()
|