# Artwork_Valuation / AutoML.py
# 白鹭先生
# init
# c9843cd
'''
Author: [egrt]
Date: 2022-08-14 09:37:12
LastEditors: [egrt]
LastEditTime: 2022-08-17 20:34:36
Description:
'''
import numpy as np
import pandas as pd
import pickle
from sklearn.preprocessing import LabelEncoder
def show_config(**kwargs):
    """Pretty-print configuration options as a fixed-width ``keys | values`` table.

    Each keyword argument becomes one row; keys and values are converted with
    ``str`` and right-aligned in 25 and 40 columns respectively.
    """
    separator = '-' * 70

    def render(left, right):
        # Right-align the key in 25 columns and the value in 40.
        return '|%25s | %40s|' % (left, right)

    print('Configurations:')
    print(separator)
    print(render('keys', 'values'))
    print(separator)
    for option, setting in kwargs.items():
        print(render(str(option), str(setting)))
    print(separator)
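#--------------------------------------------#
# To predict with your own trained model, update model_path (the pickled
# AutoML model) and train_path (the CSV used for artist-ID lookup and
# label encoding) in Classification._defaults.
#--------------------------------------------#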
class Classification(object):
    """Predict a value for a single artwork with a pickled AutoML model.

    The training CSV is kept in memory. It is used to look up an artist's ID
    and to label-encode categorical inputs before calling the model.
    """

    _defaults = {
        #----------------------------------------------------------------------#
        # When predicting with your own trained model, update these paths.
        # model_path: pickled AutoML model loaded by ``generate``.
        # train_path: CSV used to look up artist IDs and to fit the label
        #             encoders applied to the inputs in ``detect_one``.
        #----------------------------------------------------------------------#
        "model_path" : 'model_data/automl_v2.pkl',
        "train_path" : 'datasets/archive/artworks.csv',
        #-------------------------------#
        # Whether to use CUDA; set to False when no GPU is available.
        # NOTE(review): not read anywhere in this class.
        #-------------------------------#
        "cuda" : False
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of option ``n``.

        Unknown names return an error-message string instead of raising;
        this is kept for backward compatibility with existing callers.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        """Apply defaults and overrides, read the training CSV and load the model.

        Args:
            **kwargs: Values that override entries of ``_defaults``. Any other
                names are also set as attributes on the instance.
        """
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.num_classes = 1
        self.train_data = pd.read_csv(self.train_path)
        self.generate()
        # Print the effective configuration (defaults overridden by kwargs)
        # rather than only the class-level defaults.
        config = {key: getattr(self, key) for key in self._defaults}
        config.update(kwargs)
        show_config(**config)

    def generate(self):
        """Load the pickled AutoML model from ``self.model_path`` into ``self.automl``."""
        # SECURITY: pickle.load can execute arbitrary code; only point
        # model_path at model files you trust.
        with open(self.model_path, "rb") as f:
            self.automl = pickle.load(f)

    def _artist_id(self, name):
        """Return the "Artist ID" of the first training row whose "Name" equals ``name``.

        Raises:
            KeyError: if no training row has that name.
        """
        train_data = self.train_data
        matches = train_data.loc[train_data["Name"] == name, "Artist ID"]
        if matches.empty:
            raise KeyError("Unknown artist name: %r" % (name,))
        # Positional access: the filtered Series keeps the original row labels,
        # so a label lookup like ``[0]`` only works if the match is row 0.
        return matches.iloc[0]

    def _encode(self, column, value):
        """Label-encode ``value`` with an encoder fitted on ``self.train_data[column]``."""
        encoder = LabelEncoder()
        encoder.fit(self.train_data[column])
        return encoder.transform([value])[0]

    def detect_one(self, name, date, level, classification, height, width):
        """Predict for one artwork and return the model output as an ``int``.

        Args:
            name: Artist name, looked up in the training "Name" column.
            date: Value passed through as the "Date" feature.
            level: Label encoded against the training "Department" column.
            classification: Label encoded against the training
                "Classification" column.
            height: Value for the "Height (cm)" feature.
            width: Value for the "Width (cm)" feature.

        Returns:
            ``int`` of the first element of ``self.automl.predict(...)``.

        Raises:
            KeyError: if ``name`` is not present in the training data.
            ValueError: presumably raised by sklearn's ``LabelEncoder`` when
                ``level`` or ``classification`` was not seen in training.
        """
        # Build a single-row frame explicitly; column order matches the
        # feature order used previously.
        test_data = pd.DataFrame({
            "Artist ID": [self._artist_id(name)],
            "Date": [date],
            "Department": [self._encode("Department", level)],
            "Classification": [self._encode("Classification", classification)],
            "Height (cm)": [height],
            "Width (cm)": [width],
        })
        pred = self.automl.predict(test_data)
        return int(pred[0])
if __name__ == "__main__":
    # Demo: build the predictor with default settings and score one artwork.
    predictor = Classification()
    predictor.detect_one("陈冠夫", 1975, '国家级', '中国山水画', 50, 50)