Jiahuita
commited on
Commit
•
838a3ce
1
Parent(s):
5fda167
Add custom pipeline and fix configs
Browse files- pipeline.py +11 -6
- requirements.txt +0 -1
pipeline.py
CHANGED
@@ -8,21 +8,26 @@ import json
|
|
8 |
class NewsClassifierPipeline(Pipeline):
|
9 |
def __init__(self):
|
10 |
super().__init__()
|
|
|
11 |
self.model = load_model('./news_classifier.h5')
|
12 |
with open('./tokenizer.json', 'r') as f:
|
13 |
tokenizer_data = json.load(f)
|
14 |
self.tokenizer = tokenizer_from_json(tokenizer_data)
|
15 |
|
16 |
def preprocess(self, inputs):
|
|
|
17 |
sequences = self.tokenizer.texts_to_sequences([inputs])
|
18 |
-
|
|
|
19 |
|
20 |
def _forward(self, inputs):
|
|
|
21 |
processed = self.preprocess(inputs)
|
22 |
predictions = self.model.predict(processed)
|
23 |
-
|
24 |
-
|
25 |
-
return [{"label": label, "score": float(
|
26 |
|
27 |
-
def postprocess(self,
|
28 |
-
|
|
|
|
8 |
class NewsClassifierPipeline(Pipeline):
|
9 |
def __init__(self):
|
10 |
super().__init__()
|
11 |
+
# Load model and tokenizer
|
12 |
self.model = load_model('./news_classifier.h5')
|
13 |
with open('./tokenizer.json', 'r') as f:
|
14 |
tokenizer_data = json.load(f)
|
15 |
self.tokenizer = tokenizer_from_json(tokenizer_data)
|
16 |
|
17 |
def preprocess(self, inputs):
|
18 |
+
"""Tokenizes and pads the input text."""
|
19 |
sequences = self.tokenizer.texts_to_sequences([inputs])
|
20 |
+
padded = pad_sequences(sequences, maxlen=128)
|
21 |
+
return padded
|
22 |
|
23 |
def _forward(self, inputs):
|
24 |
+
"""Runs the model prediction."""
|
25 |
processed = self.preprocess(inputs)
|
26 |
predictions = self.model.predict(processed)
|
27 |
+
scores = predictions[0]
|
28 |
+
label = "foxnews" if scores[0] > 0.5 else "nbc"
|
29 |
+
return [{"label": label, "score": float(scores[0] if label == "foxnews" else 1 - scores[0])}]
|
30 |
|
31 |
+
def postprocess(self, model_outputs):
|
32 |
+
"""Returns the processed output."""
|
33 |
+
return model_outputs
|
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
tensorflow>=2.10.0
|
2 |
transformers>=4.30.0
|
3 |
-
torch>=2.0.0
|
4 |
numpy>=1.19.2
|
5 |
scikit-learn>=0.24.2
|
|
|
1 |
tensorflow>=2.10.0
|
2 |
transformers>=4.30.0
|
|
|
3 |
numpy>=1.19.2
|
4 |
scikit-learn>=0.24.2
|