Spaces:
Runtime error
Runtime error
'''
Author: [egrt]
Date: 2022-08-14 09:37:12
LastEditors: [egrt]
LastEditTime: 2022-08-17 20:34:36
Description: Loads a pickled AutoML model and uses it to predict a value for a
             single artwork described by its metadata.
'''
import numpy as np | |
import pandas as pd | |
import pickle | |
from sklearn.preprocessing import LabelEncoder | |
def show_config(**kwargs):
    """Print the given keyword arguments as a two-column table on stdout.

    The table starts with a ``Configurations:`` title and a ``keys`` /
    ``values`` header row, followed by one row per keyword argument.
    Keys are right-aligned in 25 columns and values in 40 columns; both
    are converted with ``str`` and are not truncated when longer.
    """
    rule = '-' * 70

    def _row(left, right):
        # Same layout as the classic '|%25s | %40s|' format string.
        return f'|{left:>25} | {right:>40}|'

    lines = ['Configurations:', rule, _row('keys', 'values'), rule]
    lines.extend(_row(str(key), str(value)) for key, value in kwargs.items())
    lines.append(rule)
    print('\n'.join(lines))
#--------------------------------------------#
#   To predict with your own trained model, update the
#   "model_path" and "train_path" settings of the class below.
#--------------------------------------------#
class Classification(object):
    """Predict a value for one artwork using a pickled model.

    On construction the instance reads the CSV at ``train_path`` into
    ``self.train_data`` and unpickles the model at ``model_path`` into
    ``self.automl``. ``detect_one`` then turns one artwork's metadata into a
    single-row feature frame and returns the model's prediction.

    Any entry of ``_defaults`` can be overridden through keyword arguments
    to the constructor.
    """

    _defaults = {
        #--------------------------------------------------------------------------#
        #   model_path: pickled model file loaded by generate().
        #   train_path: CSV used to look up artist IDs and to fit the label
        #   encoders (presumably the data the model was trained on -- confirm).
        #--------------------------------------------------------------------------#
        "model_path"    : 'model_data/automl_v2.pkl',
        "train_path"    : 'datasets/archive/artworks.csv',
        #-------------------------------#
        #   Whether to use CUDA; set to False when no GPU is available.
        #   NOTE(review): not referenced by any code in this class.
        #-------------------------------#
        "cuda"          : False
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of setting ``n``.

        Unknown names do not raise; a descriptive string is returned instead.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialise the classifier
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Apply defaults and overrides, then load the CSV and the model.

        Args:
            **kwargs: Settings overriding entries of ``_defaults``; each is
                stored as an instance attribute.
        """
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.num_classes = 1
        self.train_data = pd.read_csv(self.train_path)
        self.generate()
        # Report the settings actually in effect (defaults plus overrides),
        # not just the class-level defaults.
        show_config(**{key: getattr(self, key) for key in self._defaults})

    #---------------------------------------------------#
    #   Load the model
    #---------------------------------------------------#
    def generate(self):
        """Unpickle the model stored at ``self.model_path`` into ``self.automl``."""
        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file; only point model_path at files from a trusted source.
        with open(self.model_path, "rb") as f:
            self.automl = pickle.load(f)

    def _lookup_artist_id(self, name):
        """Return the "Artist ID" of the first row whose "Name" equals ``name``.

        Raises:
            KeyError: If no row matches ``name``.
        """
        matches = self.train_data.loc[self.train_data["Name"] == name, "Artist ID"]
        if matches.empty:
            raise KeyError("No artist named %r found in the training data" % (name,))
        # Positional access: the filtered Series keeps the original row
        # labels, so a plain [0] would only work if the match sat in row 0.
        return matches.iloc[0]

    def _encode_label(self, column, value):
        """Encode ``value`` with a LabelEncoder fitted on ``train_data[column]``.

        Returns the length-1 array produced by ``LabelEncoder.transform``.
        """
        encoder = LabelEncoder()
        encoder.fit(self.train_data[column])
        return encoder.transform([value])

    def detect_one(self, name, date, level, classification, height, width):
        """Predict for a single artwork and return the result as an ``int``.

        Args:
            name: Artist name; must match a value in the "Name" column.
            date: Value passed to the model as the "Date" feature.
            level: Value encoded against the "Department" column.
            classification: Value encoded against the "Classification" column.
            height: Value passed as the "Height (cm)" feature.
            width: Value passed as the "Width (cm)" feature.

        Returns:
            The first element of the model's prediction, converted to ``int``.

        Raises:
            KeyError: If ``name`` is not present in the "Name" column.
        """
        artist_id = self._lookup_artist_id(name)
        department_code = self._encode_label("Department", level)
        classification_code = self._encode_label("Classification", classification)
        # NOTE(review): an encoder for the "Catalogue" column used to be
        # fitted here, but its output was never passed to the model; it was
        # dropped as dead code -- confirm the model does not expect it.
        # The length-1 encoded arrays fix the frame at one row; the scalar
        # entries are broadcast to that row.
        features = pd.DataFrame({
            'Artist ID': artist_id,
            'Date': date,
            'Department': department_code,
            'Classification': classification_code,
            "Height (cm)": height,
            "Width (cm)": width,
        })
        pred = self.automl.predict(features)
        return int(pred[0])
if __name__ == "__main__":
    # Smoke run: build the predictor with its default settings and score one
    # sample artwork. The returned prediction is not used.
    classifier = Classification()
    classifier.detect_one(
        name="陈冠夫",
        date=1975,
        level='国家级',
        classification='中国山水画',
        height=50,
        width=50,
    )