Jiahuita commited on
Commit
838a3ce
1 Parent(s): 5fda167

Add custom pipeline and fix configs

Browse files
Files changed (2) hide show
  1. pipeline.py +11 -6
  2. requirements.txt +0 -1
pipeline.py CHANGED
@@ -8,21 +8,26 @@ import json
8
  class NewsClassifierPipeline(Pipeline):
9
  def __init__(self):
10
  super().__init__()
 
11
  self.model = load_model('./news_classifier.h5')
12
  with open('./tokenizer.json', 'r') as f:
13
  tokenizer_data = json.load(f)
14
  self.tokenizer = tokenizer_from_json(tokenizer_data)
15
 
16
  def preprocess(self, inputs):
 
17
  sequences = self.tokenizer.texts_to_sequences([inputs])
18
- return pad_sequences(sequences, maxlen=128)
 
19
 
20
  def _forward(self, inputs):
 
21
  processed = self.preprocess(inputs)
22
  predictions = self.model.predict(processed)
23
- label = "foxnews" if predictions[0][0] > 0.5 else "nbc"
24
- score = predictions[0][0] if label == "foxnews" else 1 - predictions[0][0]
25
- return [{"label": label, "score": float(score)}]
26
 
27
- def postprocess(self, outputs):
28
- return outputs
 
 
8
  class NewsClassifierPipeline(Pipeline):
9
  def __init__(self):
10
  super().__init__()
11
+ # Load model and tokenizer
12
  self.model = load_model('./news_classifier.h5')
13
  with open('./tokenizer.json', 'r') as f:
14
  tokenizer_data = json.load(f)
15
  self.tokenizer = tokenizer_from_json(tokenizer_data)
16
 
17
  def preprocess(self, inputs):
18
+ """Tokenizes and pads the input text."""
19
  sequences = self.tokenizer.texts_to_sequences([inputs])
20
+ padded = pad_sequences(sequences, maxlen=128)
21
+ return padded
22
 
23
  def _forward(self, inputs):
24
+ """Runs the model prediction."""
25
  processed = self.preprocess(inputs)
26
  predictions = self.model.predict(processed)
27
+ scores = predictions[0]
28
+ label = "foxnews" if scores[0] > 0.5 else "nbc"
29
+ return [{"label": label, "score": float(scores[0] if label == "foxnews" else 1 - scores[0])}]
30
 
31
+ def postprocess(self, model_outputs):
32
+ """Returns the processed output."""
33
+ return model_outputs
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  tensorflow>=2.10.0
2
  transformers>=4.30.0
3
- torch>=2.0.0
4
  numpy>=1.19.2
5
  scikit-learn>=0.24.2
 
1
  tensorflow>=2.10.0
2
  transformers>=4.30.0
 
3
  numpy>=1.19.2
4
  scikit-learn>=0.24.2