Tsumugii24 commited on
Commit
95be552
1 Parent(s): 9a27fc0

add model auto downloads

Browse files
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -16,8 +16,9 @@ from ultralytics import YOLO
16
  ROOT_PATH = sys.path[0] # 项目根目录
17
 
18
  fonts_list = ["SimSun.ttf", "TimesNewRoman.ttf", "malgun.ttf"] # 字体列表
 
19
  fonts_directory_path = Path(ROOT_PATH, "fonts")
20
- models_directory_path = Path(ROOT_PATH) # 模型存放在项目的根目录
21
 
22
  data_url_dict = {
23
  "SimSun.ttf": "https://raw.githubusercontent.com/Tsumugii24/Typora-images/main/files/SimSun.ttf",
@@ -37,9 +38,9 @@ model_url_dict = {
37
  def is_fonts(fonts_dir):
38
  if fonts_dir.is_dir():
39
  # 如果本地字体库存在
40
- local_list = os.listdir(fonts_dir) # 本地字体库
41
 
42
- font_diff = list(set(fonts_list).difference(set(local_list)))
43
 
44
  if font_diff != []:
45
  # 缺失字体
@@ -55,19 +56,19 @@ def is_fonts(fonts_dir):
55
  def is_models(models_dir):
56
  if models_dir.is_dir():
57
  # 如果本地模型库存在
58
- local_list = os.listdir(models_dir) # 本地模型库
59
 
60
- model_diff = list(set(model_url_dict.keys()).difference(set(local_list)))
61
 
62
  if model_diff != []:
63
  # 缺失模型
64
  download_models(model_diff) # 下载缺失的模型
65
  else:
66
- print(f"{model_url_dict.keys()}[bold green]Required models already downloaded![/bold green]")
67
  else:
68
  # 本地模型库不存在,创建模型库
69
  print("[bold red]Local models library does not exist, creating now...[/bold red]")
70
- download_models(model_url_dict.keys()) # 创建模型库
71
 
72
  # 下载字体
73
  def download_fonts(font_diff):
@@ -78,9 +79,9 @@ def download_fonts(font_diff):
78
  font_name = v.split("/")[-1] # 字体名称
79
  fonts_directory_path.mkdir(parents=True, exist_ok=True) # 创建目录
80
 
81
- file_path = f"{ROOT_PATH}/fonts/{font_name}" # 字体路径
82
  # 下载字体文件
83
- wget.download(v, file_path)
84
 
85
  # 下载模型
86
  def download_models(model_diff):
@@ -90,9 +91,9 @@ def download_models(model_diff):
90
  v = model_url_dict[k]
91
  model_name = v.split("/")[-1] # 模型名称
92
 
93
- file_path = f"{ROOT_PATH}/{model_name}" # 模型路径
94
  # 下载模型文件
95
- wget.download(v, file_path)
96
 
97
 
98
  is_fonts(fonts_directory_path)
@@ -336,7 +337,7 @@ def seg_output(img_path, seg_mask_list, color_list, cls_list):
336
 
337
 
338
  # 目标检测和图像分割模型加载
339
- def model_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_model="yolov8n.pt"):
340
  model = YOLO(yolo_model)
341
 
342
  results = model(source=img_path, device=device_opt, imgsz=infer_size, conf=conf, iou=iou, max_det=max_det)
@@ -358,7 +359,7 @@ def yolo_det_img(img_path, model_name, device_opt, infer_size, conf, iou, max_de
358
  cls_index_det_stat = [] # 1
359
 
360
  # 模型加载
361
- predict_results = model_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_model=f"{model_name}.pt")
362
  # 检测参数
363
  xyxy_list = predict_results.boxes.xyxy.cpu().numpy().tolist()
364
  conf_list = predict_results.boxes.conf.cpu().numpy().tolist()
 
16
  ROOT_PATH = sys.path[0] # 项目根目录
17
 
18
  fonts_list = ["SimSun.ttf", "TimesNewRoman.ttf", "malgun.ttf"] # 字体列表
19
+ models_list = ["cnn_se.pt", "detr_based.pt", "vit_based.pt", "yolov5_based.pt", "yolov8_based.pt"] # 模型列表
20
  fonts_directory_path = Path(ROOT_PATH, "fonts")
21
+ models_directory_path = Path(ROOT_PATH, "models") # 模型存放在项目的根目录
22
 
23
  data_url_dict = {
24
  "SimSun.ttf": "https://raw.githubusercontent.com/Tsumugii24/Typora-images/main/files/SimSun.ttf",
 
38
  def is_fonts(fonts_dir):
39
  if fonts_dir.is_dir():
40
  # 如果本地字体库存在
41
+ local_font_list = os.listdir(fonts_dir) # 本地字体库
42
 
43
+ font_diff = list(set(fonts_list).difference(set(local_font_list)))
44
 
45
  if font_diff != []:
46
  # 缺失字体
 
56
  def is_models(models_dir):
57
  if models_dir.is_dir():
58
  # 如果本地模型库存在
59
+ local_model_list = os.listdir(models_dir) # 本地模型库
60
 
61
+ model_diff = list(set(models_list()).difference(set(local_model_list)))
62
 
63
  if model_diff != []:
64
  # 缺失模型
65
  download_models(model_diff) # 下载缺失的模型
66
  else:
67
+ print(f"{models_list}[bold green]Required models already downloaded![/bold green]")
68
  else:
69
  # 本地模型库不存在,创建模型库
70
  print("[bold red]Local models library does not exist, creating now...[/bold red]")
71
+ download_models(models_list) # 创建模型库
72
 
73
  # 下载字体
74
  def download_fonts(font_diff):
 
79
  font_name = v.split("/")[-1] # 字体名称
80
  fonts_directory_path.mkdir(parents=True, exist_ok=True) # 创建目录
81
 
82
+ font_file_path = f"{ROOT_PATH}/fonts/{font_name}" # 字体路径
83
  # 下载字体文件
84
+ wget.download(v, font_file_path)
85
 
86
  # 下载模型
87
  def download_models(model_diff):
 
91
  v = model_url_dict[k]
92
  model_name = v.split("/")[-1] # 模型名称
93
 
94
+ model_file_path = f"{ROOT_PATH}/models/{model_name}" # 模型路径
95
  # 下载模型文件
96
+ wget.download(v, model_file_path)
97
 
98
 
99
  is_fonts(fonts_directory_path)
 
337
 
338
 
339
  # 目标检测和图像分割模型加载
340
+ def model_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_model="yolov8_based.pt"):
341
  model = YOLO(yolo_model)
342
 
343
  results = model(source=img_path, device=device_opt, imgsz=infer_size, conf=conf, iou=iou, max_det=max_det)
 
359
  cls_index_det_stat = [] # 1
360
 
361
  # 模型加载
362
+ predict_results = model_loading(img_path, device_opt, conf, iou, infer_size, max_det, yolo_model=f"models/{model_name}.pt")
363
  # 检测参数
364
  xyxy_list = predict_results.boxes.xyxy.cpu().numpy().tolist()
365
  conf_list = predict_results.boxes.conf.cpu().numpy().tolist()