Jiahuita committed on
Commit
c356db2
1 Parent(s): 13ad768

Update pipeline to resolve transformer issue

Browse files
Files changed (1) hide show
  1. pipeline.py +28 -7
pipeline.py CHANGED
@@ -2,6 +2,7 @@ from transformers import PreTrainedModel, PretrainedConfig
2
  from tensorflow.keras.models import load_model
3
  from tensorflow.keras.preprocessing.text import tokenizer_from_json
4
  from tensorflow.keras.preprocessing.sequence import pad_sequences
 
5
  import numpy as np
6
  import json
7
 
@@ -17,15 +18,35 @@ class NewsClassifier(PreTrainedModel):
17
 
18
  def __init__(self, config):
19
  super().__init__(config)
20
- self.model = load_model('./news_classifier.h5')
21
- with open('./tokenizer.json', 'r') as f:
 
 
 
22
  tokenizer_data = json.load(f)
23
  self.tokenizer = tokenizer_from_json(tokenizer_data)
 
 
 
 
 
 
24
 
25
- def forward(self, inputs):
26
- sequences = self.tokenizer.texts_to_sequences([inputs])
27
  padded = pad_sequences(sequences, maxlen=self.config.max_length)
28
  predictions = self.model.predict(padded)
29
- scores = predictions[0]
30
- label = "foxnews" if scores[0] > 0.5 else "nbc"
31
- return {"label": label, "score": float(scores[0] if label == "foxnews" else 1 - scores[0])}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from tensorflow.keras.models import load_model
3
  from tensorflow.keras.preprocessing.text import tokenizer_from_json
4
  from tensorflow.keras.preprocessing.sequence import pad_sequences
5
+ import os
6
  import numpy as np
7
  import json
8
 
 
18
 
19
  def __init__(self, config):
20
  super().__init__(config)
21
+ model_path = os.path.join(os.path.dirname(__file__), 'news_classifier.h5')
22
+ tokenizer_path = os.path.join(os.path.dirname(__file__), 'tokenizer.json')
23
+
24
+ self.model = load_model(model_path)
25
+ with open(tokenizer_path, 'r') as f:
26
  tokenizer_data = json.load(f)
27
  self.tokenizer = tokenizer_from_json(tokenizer_data)
28
+
29
+ def forward(self, text_input):
30
+ if isinstance(text_input, str):
31
+ sequences = self.tokenizer.texts_to_sequences([text_input])
32
+ else:
33
+ sequences = self.tokenizer.texts_to_sequences(text_input)
34
 
 
 
35
  padded = pad_sequences(sequences, maxlen=self.config.max_length)
36
  predictions = self.model.predict(padded)
37
+
38
+ results = []
39
+ for score in predictions:
40
+ label = "foxnews" if score[0] > 0.5 else "nbc"
41
+ results.append({
42
+ "label": label,
43
+ "score": float(score[0] if label == "foxnews" else 1 - score[0])
44
+ })
45
+
46
+ return results[0] if isinstance(text_input, str) else results
47
+
48
+ @classmethod
49
+ def from_pretrained(cls, model_path, **kwargs):
50
+ config = NewsClassifierConfig.from_pretrained(model_path)
51
+ model = cls(config)
52
+ return model