lycaoduong commited on
Commit
7fcf77d
1 Parent(s): b4b217e

Upload engine.py

Browse files

Fixed name errors

Files changed (1) hide show
  1. FcgEngine/engine.py +4 -2
FcgEngine/engine.py CHANGED
@@ -59,7 +59,8 @@ class PredictorCls(object):
59
  self.model = AutoModelForImageClassification.from_pretrained(model_path, trust_remote_code=True)
60
  self.model.to(device)
61
  config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
62
- self.ids = list(config.cls_name.keys())
 
63
  self.device = device
64
 
65
  def __call__(self, spectra):
@@ -78,7 +79,8 @@ class PredictorCls(object):
78
  fcn_groups = []
79
  probabilities = []
80
  for ids in predict_cls:
81
- cls_name = self.ids[ids]
 
82
  prob = result[ids]
83
  fcn_groups.append(cls_name)
84
  probabilities.append(prob)
 
59
  self.model = AutoModelForImageClassification.from_pretrained(model_path, trust_remote_code=True)
60
  self.model.to(device)
61
  config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
62
+ self.name_list = list(config.cls_name.keys())
63
+ self.id_list = list(config.cls_name.values())
64
  self.device = device
65
 
66
  def __call__(self, spectra):
 
79
  fcn_groups = []
80
  probabilities = []
81
  for ids in predict_cls:
82
+ position = self.id_list.index(ids)
83
+ cls_name = self.name_list[position]
84
  prob = result[ids]
85
  fcn_groups.append(cls_name)
86
  probabilities.append(prob)