"""
Author: [egrt]
Date: 2022-08-14 09:37:12
LastEditors: [egrt]
LastEditTime: 2022-08-17 20:34:36
Description: Run a pickled AutoML model on a single artwork record whose
             categorical inputs are label-encoded against a training CSV.
"""
import numpy as np
import pandas as pd
import pickle
from sklearn.preprocessing import LabelEncoder


def show_config(**kwargs):
    """Print each keyword argument as one row of a fixed-width key/value table."""
    print('Configurations:')
    print('-' * 70)
    print('|%25s | %40s|' % ('keys', 'values'))
    print('-' * 70)
    for key, value in kwargs.items():
        print('|%25s | %40s|' % (str(key), str(value)))
    print('-' * 70)


class Classification(object):
    """Wrap a pickled AutoML model plus the CSV used to encode its inputs.

    Keyword arguments passed to the constructor override the entries in
    ``_defaults``; both end up as instance attributes.
    """

    _defaults = {
        # Path of the pickled model loaded by ``generate``. Point this at your
        # own trained model file when using a different model.
        "model_path": 'model_data/automl_v2.pkl',
        # CSV read in ``__init__``; used to look up artist IDs and to fit the
        # label encoders in ``detect_one``.
        "train_path": 'datasets/archive/artworks.csv',
        # Whether to use CUDA (set to False when no GPU is available).
        # NOTE(review): not read anywhere in this file.
        "cuda": False,
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value for attribute ``n``.

        Unknown names return an error-message string instead of raising.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        """Apply defaults and overrides, read the training CSV, load the model."""
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        # Not read in this file; presumably consumed by callers — confirm.
        self.num_classes = 1
        self.train_data = pd.read_csv(self.train_path)
        self.generate()
        # Show the effective configuration (defaults merged with overrides),
        # not merely the class-level defaults.
        show_config(**{**self._defaults, **kwargs})

    def generate(self):
        """Load the pickled model at ``self.model_path`` into ``self.automl``."""
        # SECURITY: pickle.load can execute arbitrary code while unpickling;
        # only load model files that come from a trusted source.
        with open(self.model_path, "rb") as f:
            self.automl = pickle.load(f)

    def _encode(self, column, value):
        """Label-encode ``value`` using an encoder fitted on ``column`` of the training data.

        Returns a length-1 array of codes. LabelEncoder raises ValueError if
        ``value`` never occurs in that column.
        """
        encoder = LabelEncoder()
        encoder.fit(self.train_data[column])
        return encoder.transform([value])

    def detect_one(self, name, date, level, classification, height, width):
        """Build a one-row feature frame for an artwork and return the model's prediction.

        Args:
            name: artist name, matched exactly against the "Name" column to
                find the "Artist ID".
            date: value used for the "Date" feature.
            level: label encoded against the "Department" column.
            classification: label encoded against the "Classification" column.
            height: value used for "Height (cm)".
            width: value used for "Width (cm)".

        Returns:
            The first prediction of ``self.automl.predict`` converted to int.

        Raises:
            ValueError: if ``name`` is not in the training data, or if
                ``level``/``classification`` are labels unseen in training.
        """
        train_data = self.train_data
        matches = train_data.loc[train_data["Name"] == name, "Artist ID"]
        if matches.empty:
            raise ValueError("Unknown artist name: %r" % (name,))
        # Positional lookup: the filtered Series keeps the original row labels,
        # so label-based ``[0]`` would KeyError unless the match is row 0.
        artist_id = matches.iloc[0]

        # Column order matches the original feature layout fed to the model.
        test_data = pd.DataFrame({
            'Artist ID': [artist_id],
            'Date': [date],
            'Department': self._encode("Department", level),
            'Classification': self._encode("Classification", classification),
            "Height (cm)": [height],
            "Width (cm)": [width],
        })
        pred = self.automl.predict(test_data)
        return int(pred[0])


if __name__ == "__main__":
    classifier = Classification()
    result = classifier.detect_one("陈冠夫", 1975, '国家级', '中国山水画', 50, 50)
    print(result)