File size: 3,964 Bytes
c9843cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
'''
Author: [egrt]
Date: 2022-08-14 09:37:12
LastEditors: [egrt]
LastEditTime: 2022-08-17 20:34:36
Description: 
'''
import numpy as np
import pandas as pd
import pickle
from sklearn.preprocessing import LabelEncoder

def show_config(**kwargs):
    """Print the given keyword arguments as a two-column text table.

    The table is 70 characters wide: names are right-aligned in a
    25-character column and values in a 40-character column. Longer
    entries are not truncated; they simply widen their row.

    Args:
        **kwargs: Settings to display. Both the name and the value are
            converted with ``str`` before printing.
    """
    rule = '-' * 70

    def row(left, right):
        # One table line; both cells are right-aligned.
        return f'|{left:>25} | {right:>40}|'

    print('Configurations:')
    print(rule)
    print(row('keys', 'values'))
    print(rule)
    for key in kwargs:
        print(row(str(key), str(kwargs[key])))
    print(rule)
    
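#--------------------------------------------#
#   To predict with your own trained model, three parameters must be changed:
#   model_path, classes_path and backbone all need to be updated!
#   NOTE(review): only model_path is defined in this file; classes_path and
#   backbone do not appear here -- confirm whether this header is stale.
#--------------------------------------------#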
class Classification(object):
    """Single-sample predictor backed by a pickled model and a training CSV.

    On construction the training CSV is read with pandas and the pickled
    model is loaded. ``detect_one`` then turns one set of raw inputs into a
    one-row DataFrame and asks the model for a prediction.

    Any entry of ``_defaults`` can be overridden by passing a keyword
    argument of the same name to the constructor.
    """

    _defaults = {
        #--------------------------------------------------------------------------#
        #   model_path : pickled model file loaded by generate().
        #   train_path : CSV read in __init__; detect_one() uses it to look up
        #                the artist id and to fit the label encoders.
        #--------------------------------------------------------------------------#
        "model_path"        : 'model_data/automl_v2.pkl',
        "train_path"        : 'datasets/archive/artworks.csv',
        #-------------------------------#
        #   Whether to use CUDA.
        #   Not read by any code in this class.
        #-------------------------------#
        "cuda"              : False
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of setting ``n``.

        Args:
            n: Name of the setting.

        Returns:
            The default value, or the string
            ``"Unrecognized attribute name '<n>'"`` when ``n`` is not a
            known setting (no exception is raised in that case).
        """
        if n in cls._defaults:
            return cls._defaults[n]
        return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialise the classifier
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Apply settings, read the training CSV and load the model.

        Args:
            **kwargs: Overrides for entries of ``_defaults``; each one is
                also set as an instance attribute.
        """
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        self.num_classes = 1
        self.train_data  = pd.read_csv(self.train_path)
        self.generate()

        # Show the configuration actually in effect (defaults overlaid with
        # the caller's overrides) without mutating the shared class dict.
        show_config(**{**self._defaults, **kwargs})

    #---------------------------------------------------#
    #   Load the model
    #---------------------------------------------------#
    def generate(self):
        """Load the pickled model from ``self.model_path`` into ``self.automl``."""
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file. Only point model_path at files from a trusted source.
        with open(self.model_path, "rb") as f:
            self.automl = pickle.load(f)

    def _lookup_artist_id(self, name):
        """Return the "Artist ID" of the first training row whose "Name" is ``name``."""
        matches = self.train_data.loc[self.train_data["Name"] == name, "Artist ID"]
        if matches.empty:
            raise KeyError("No artist named %r in the training data" % (name,))
        # Positional access: the filtered Series keeps the original row
        # labels, so matches[0] only works if the match sits on label 0.
        return matches.iloc[0]

    def _encode_label(self, column, value):
        """Encode ``value`` with a LabelEncoder fitted on a training column.

        Returns the length-1 array produced by ``LabelEncoder.transform``.
        """
        encoder = LabelEncoder()
        encoder.fit(self.train_data[column])
        return encoder.transform([value])

    def detect_one(self, name, date, level, classification, height, width):
        """Predict for a single sample.

        Args:
            name: Artist name, matched against the "Name" column of the
                training data to obtain the "Artist ID" feature.
            date: Value passed to the model as the "Date" feature.
            level: Raw label encoded against the "Department" column.
            classification: Raw label encoded against the "Classification"
                column.
            height: Value passed as the "Height (cm)" feature.
            width: Value passed as the "Width (cm)" feature.

        Returns:
            The model's first prediction for the one-row input, as ``int``.

        Raises:
            KeyError: If ``name`` does not occur in the training data.
        """
        artist_id           = self._lookup_artist_id(name)
        department_code     = self._encode_label("Department", level)
        classification_code = self._encode_label("Classification", classification)

        # The encoded columns are length-1 arrays; pandas broadcasts the
        # scalar entries against them, giving a single-row frame.
        test_data = pd.DataFrame({
            'Artist ID': artist_id,
            'Date': date,
            'Department': department_code,
            'Classification': classification_code,
            "Height (cm)": height,
            "Width (cm)": width,
        })

        pred = self.automl.predict(test_data)
        return int(pred[0])

if __name__ == "__main__":
    # Manual smoke run: build the predictor with its default settings and
    # print the prediction for one hard-coded sample so the result is visible.
    classifier = Classification()
    prediction = classifier.detect_one("陈冠夫", 1975, '国家级', '中国山水画', 50, 50)
    print(prediction)