merve HF staff commited on
Commit
eb496ad
1 Parent(s): a7ce59e

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +12 -9
pipeline.py CHANGED
@@ -18,16 +18,17 @@ class PreTrainedPipeline(Pipeline):
18
 
19
 
20
  # Reload Keras SavedModel
21
- self.model = from_pretrained_keras(model_id)
22
 
23
  # Number of labels
24
  self.num_labels = self.model.output_shape[1]
25
 
26
  # Config is required to know the mapping to label.
27
- config_file = hf_hub_download(model_id, filename=CONFIG_FILENAME)
28
- with open(config_file) as config:
29
- config = json.load(config)
30
-
 
31
  self.id2label = config.get(
32
  "id2label", {str(i): f"LABEL_{i}" for i in range(self.num_labels)}
33
  )
@@ -59,12 +60,14 @@ class PreTrainedPipeline(Pipeline):
59
  self.single_output_unit = (
60
  self.model.output_shape[1] == 1
61
  ) # if there are two classes
62
-
 
63
  if self.single_output_unit:
64
  score = predictions[0][0]
65
- labels = [
66
- {"label": str(self.id2label["1"]), "score": float(score)},
67
- {"label": str(self.id2label["0"]), "score": float(1 - score)},
 
68
  ]
69
  else:
70
  labels = [
 
18
 
19
 
20
  # Reload Keras SavedModel
21
+ self.model = keras.models.load_model('./model.h5')
22
 
23
  # Number of labels
24
  self.num_labels = self.model.output_shape[1]
25
 
26
  # Config is required to know the mapping to label.
27
+ #config_file = hf_hub_download(model_id, filename=CONFIG_FILENAME)
28
+ #with open(config_file) as config:
29
+ # config = json.load(config)
30
+
31
+ self.num_labels = 3
32
  self.id2label = config.get(
33
  "id2label", {str(i): f"LABEL_{i}" for i in range(self.num_labels)}
34
  )
 
60
  self.single_output_unit = (
61
  self.model.output_shape[1] == 1
62
  ) # if there are two classes
63
+
64
+
65
  if self.single_output_unit:
66
  score = predictions[0][0]
67
+ labels = [{"label":"pet", "score":1.0}, {"label":"other", "score":1.0}]
68
+ #labels = [
69
+ # {"label": str(self.id2label["1"]), "score": float(score)},
70
+ # {"label": str(self.id2label["0"]), "score": float(1 - score)},
71
  ]
72
  else:
73
  labels = [