File size: 4,222 Bytes
33e71b2
 
 
 
 
 
215e62d
 
 
 
 
33e71b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Streamlit YOLOv5 Model2X v0.1
# 创建人:曾逸夫
# 创建时间:2022-07-14
# 功能描述:多选,多项模型转换和打包下载

import os

# Environment bootstrap, executed at import time: upgrade pip, then install
# nvidia-pyindex followed by nvidia-tensorrt (order matters: same sequence as
# listed). Presumably required by the TensorRT ("engine") export — confirm.
# Exit codes of these commands are not checked.
for _pip_cmd in (
    "pip install pip -U",
    "pip install nvidia-pyindex",
    "pip install nvidia-tensorrt",
):
    os.system(_pip_cmd)
del _pip_cmd  # keep the module namespace free of the loop variable

import shutil
import subprocess
import time
import zipfile

import streamlit as st


def dir_opt(target_dir):
    """Recreate ``target_dir`` as an empty directory.

    If the path already exists it is removed recursively with
    ``shutil.rmtree``; the directory is then created with ``os.mkdir``, so its
    parent directory must already exist.
    """
    # Start from a clean slate: drop any previous contents first.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.mkdir(target_dir)


def download_file(uploaded_file):
    """Offer a file from local disk to the user via a Streamlit download button.

    Args:
        uploaded_file: Path of the file on disk to serve (despite the name this
            is a filesystem path such as ``"convert_weights.zip"``, not a
            Streamlit upload object).
    """
    path = f"{uploaded_file}"
    # --------------- Download ---------------
    # Read the whole file (the converted model / archive) into memory; the
    # with-block closes the handle, no explicit close() is needed.
    with open(path, 'rb') as fmodel:
        f_download_model = fmodel.read()
    # Suggest only the base name to the browser, never the directory part.
    # Label text: "Download the converted model".
    st.download_button(label='下载转换后的模型', data=f_download_model, file_name=os.path.basename(path))


def zipDir(origin_dir, compress_file):
    """Compress every file below ``origin_dir`` into a ZIP archive.

    Member names are the file paths relative to ``origin_dir`` (the directory's
    own name is not included). If ``compress_file`` lies inside
    ``origin_dir`` it is skipped, so the archive never tries to contain itself.

    Args:
        origin_dir: Directory to archive, walked recursively.
        compress_file: Path of the ``.zip`` file to create (overwritten if present).
    """
    origin_dir = f"{origin_dir}"
    compress_file = f"{compress_file}"
    archive_path = os.path.abspath(compress_file)
    # --------------- Compress ---------------
    # `with` guarantees the archive is finalized and closed even if a write fails.
    with zipfile.ZipFile(compress_file, "w", zipfile.ZIP_DEFLATED) as zf:
        for path, _dirnames, filenames in os.walk(origin_dir):
            for filename in filenames:
                full_path = os.path.join(path, filename)
                if os.path.abspath(full_path) == archive_path:
                    continue
                # relpath removes exactly the origin prefix; str.replace would
                # also strip later occurrences of the same text in the path.
                zf.write(full_path, os.path.relpath(full_path, origin_dir))


def cb_opt(weight_name, btn_model_list, params_include_list):
    """Export the uploaded weights to every selected format, then offer a ZIP.

    Args:
        weight_name: File name of the weights inside ``./weights``. It comes
            from the user's upload, so it is never passed through a shell.
        btn_model_list: Checkbox states, one per entry of ``params_include_list``.
        params_include_list: ``export.py --include`` values in the same order,
            e.g. ``["torchscript", "onnx", "openvino", "engine", "coreml",
            "saved_model", "pb", "tflite", "tfjs"]``.

    Each selected format is exported by a separate ``python export.py`` run.
    The "engine" (TensorRT) export additionally gets ``--device 0``. Afterwards
    the whole ``./weights`` directory (original + exported files) is zipped to
    ``convert_weights.zip`` and offered for download.
    """
    for selected, fmt in zip(btn_model_list, params_include_list):
        if not selected:
            continue
        st.info(f"正在转换{fmt}......")  # "Converting <fmt>......"
        # Argument list + no shell: weight_name is user-controlled, so shell
        # metacharacters in it must not be interpreted.
        cmd = ["python", "export.py", "--weights", f"./weights/{weight_name}", "--include", fmt]
        if fmt == "engine":
            cmd += ["--device", "0"]
        s = time.perf_counter()
        result = subprocess.run(cmd)
        e = time.perf_counter()
        if result.returncode == 0:
            # "<fmt> conversion finished, took N seconds"
            st.success(f"{fmt}转换完成,用时{round((e-s), 2)}秒")
        else:
            # "<fmt> conversion failed, exit code N"
            st.error(f"{fmt}转换失败,退出码{result.returncode}")

    zipDir("./weights", "convert_weights.zip")  # bundle original + converted weights
    download_file("convert_weights.zip")  # offer the bundle for download


def main():
    """Render the Streamlit page: upload a .pt file, tick export formats, convert.

    ``./weights`` is recreated empty on every script run. The uploaded weights
    are written into it, and pressing the convert button runs ``cb_opt`` for
    the ticked formats.
    """
    # (checkbox label, export.py --include value) in display order; a single
    # table keeps the checkboxes and the format list from drifting apart.
    # Edge TPU ("TensorFlow Edge TPU") is currently not offered.
    export_options = [
        ('TorchScript', "torchscript"),
        ('ONNX', "onnx"),
        ('OpenVINO', "openvino"),
        ('TensorRT', "engine"),
        ('CoreML', "coreml"),
        ('TensorFlow SavedModel', "saved_model"),
        ('TensorFlow GraphDef', "pb"),
        ('TensorFlow Lite', "tflite"),
        ('TensorFlow.js', "tfjs"),
    ]

    with st.container():
        st.title("Streamlit YOLOv5 Model2X")
        st.subheader('创建人:曾逸夫(Zeng Yifu)')
        st.text("基于Streamlit的YOLOv5模型转换工具")

        st.write("-------------------------------------------------------------")

        dir_opt("./weights")

        uploaded_file = st.file_uploader("选择YOLOv5模型文件(.pt)")
        if uploaded_file is not None:
            # Keep only the final path component so a crafted upload name
            # cannot write outside ./weights.
            weight_name = os.path.basename(uploaded_file.name)

            st.info(f"正在写入{weight_name}......")  # "Writing <name>......"

            with open(f"./weights/{weight_name}", 'wb') as fb:
                fb.write(uploaded_file.getvalue())
            st.success(f"{weight_name}写入成功!")  # "<name> written successfully!"

            st.text("请选择转换的类型:")  # "Choose the formats to convert to:"
            btn_model_list = [st.checkbox(label) for label, _ in export_options]

            btn_convert = st.button('转换')  # "Convert"

            params_include_list = [param for _, param in export_options]

            if btn_convert:
                cb_opt(weight_name, btn_model_list, params_include_list)

    st.write("-------------------------------------------------------------")


# Entry point: build the page when this file is executed as a script.
if __name__ == "__main__":
    main()