Jiahuita committed on
Commit
13ad768
·
1 Parent(s): d6f2234

Updated pipeline

Browse files
Files changed (1) hide show
  1. pipeline.py +18 -20
pipeline.py CHANGED
@@ -1,33 +1,31 @@
1
- from transformers import Pipeline
2
  from tensorflow.keras.models import load_model
3
  from tensorflow.keras.preprocessing.text import tokenizer_from_json
4
  from tensorflow.keras.preprocessing.sequence import pad_sequences
5
  import numpy as np
6
  import json
7
 
8
class NewsClassifierPipeline(Pipeline):
    """Pipeline that labels a news text as ``"foxnews"`` or ``"nbc"``.

    The Keras network is read from ``./news_classifier.h5`` and the Keras
    tokenizer from ``./tokenizer.json``; both paths are relative to the
    current working directory.
    """

    def __init__(self):
        """Initialise the base pipeline, then load the model and tokenizer."""
        # NOTE(review): the base ``Pipeline`` initialiser normally expects a
        # model argument and a ``_sanitize_parameters`` implementation --
        # confirm this argument-less call works with the installed
        # transformers version.
        super().__init__()
        # Load model and tokenizer
        self.model = load_model('./news_classifier.h5')
        with open('./tokenizer.json', 'r') as f:
            tokenizer_data = json.load(f)
        self.tokenizer = tokenizer_from_json(tokenizer_data)

    def preprocess(self, inputs):
        """Tokenize and pad a single input text.

        Args:
            inputs: The text to classify.

        Returns:
            The padded sequence batch produced by ``pad_sequences`` with
            ``maxlen=128`` (one row, because the text is wrapped in a list).
        """
        sequences = self.tokenizer.texts_to_sequences([inputs])
        padded = pad_sequences(sequences, maxlen=128)
        return padded

    def _forward(self, inputs):
        """Run the model and turn its first output row into a label.

        Args:
            inputs: Either the output of ``preprocess`` (what the pipeline
                call flow passes in) or a raw text string.

        Returns:
            A one-element list holding a dict with ``"label"`` and
            ``"score"`` keys.
        """
        # The pipeline already runs ``preprocess`` before ``_forward``, so
        # tokenizing again would feed a padded array back into the tokenizer.
        # Raw strings are still preprocessed here for direct callers.
        processed = self.preprocess(inputs) if isinstance(inputs, str) else inputs
        predictions = self.model.predict(processed)
        scores = predictions[0]
        label = "foxnews" if scores[0] > 0.5 else "nbc"
        # Report the score on the side of the chosen label.
        return [{"label": label, "score": float(scores[0] if label == "foxnews" else 1 - scores[0])}]

    def postprocess(self, model_outputs):
        """Return the output of ``_forward`` unchanged."""
        return model_outputs
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
  from tensorflow.keras.models import load_model
3
  from tensorflow.keras.preprocessing.text import tokenizer_from_json
4
  from tensorflow.keras.preprocessing.sequence import pad_sequences
5
  import numpy as np
6
  import json
7
 
8
class NewsClassifierConfig(PretrainedConfig):
    """Configuration for the news classifier model."""

    model_type = "news_classifier"

    def __init__(self, max_length=128, **kwargs):
        """Create the configuration.

        Args:
            max_length: Value stored as ``self.max_length``; defaults to 128.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # Run the base initialiser first. In many transformers releases it
        # assigns its own ``max_length`` (a generation default), which would
        # overwrite a value set before this call.
        super().__init__(**kwargs)
        # NOTE(review): ``max_length`` shares its name with a generation
        # setting on the base config -- confirm this is intentional.
        self.max_length = max_length
14
+
15
class NewsClassifier(PreTrainedModel):
    """Expose the Keras news classifier through the ``PreTrainedModel`` API.

    Construction reads the Keras network from ``./news_classifier.h5`` and
    the Keras tokenizer from ``./tokenizer.json``; both paths are relative to
    the current working directory.
    """

    config_class = NewsClassifierConfig

    def __init__(self, config):
        """Initialise the base model, then load the network and tokenizer.

        Args:
            config: Model configuration; ``forward`` reads
                ``config.max_length`` from it.
        """
        super().__init__(config)
        self.model = load_model('./news_classifier.h5')
        with open('./tokenizer.json', 'r') as handle:
            self.tokenizer = tokenizer_from_json(json.load(handle))

    def forward(self, inputs):
        """Classify one text as ``"foxnews"`` or ``"nbc"``.

        Args:
            inputs: The text to classify; it is wrapped in a one-element
                list before tokenization.

        Returns:
            A dict with a ``"label"`` key and a float ``"score"`` key.
        """
        token_ids = self.tokenizer.texts_to_sequences([inputs])
        batch = pad_sequences(token_ids, maxlen=self.config.max_length)
        first_row = self.model.predict(batch)[0]
        raw = first_row[0]
        # Above the 0.5 threshold the raw value is the score; otherwise its
        # complement is reported for the other label.
        if raw > 0.5:
            return {"label": "foxnews", "score": float(raw)}
        return {"label": "nbc", "score": float(1 - raw)}