Jiahuita committed on
Commit
5aafe28
·
1 Parent(s): 8cc42bc

Modified app and readme

Browse files
Files changed (2) hide show
  1. README.md +37 -2
  2. app.py +76 -7
README.md CHANGED
@@ -56,8 +56,43 @@ You can use this model directly with a FastAPI endpoint:
56
  ```python
57
  import requests
58
 
 
59
  response = requests.post(
60
- "https://huggingface.co/Jiahuita/NewsSourceClassification",
61
  json={"text": "Your news headline here"}
62
  )
63
- print(response.json())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  ```python
57
  import requests
58
 
59
+ # Make a prediction
60
  response = requests.post(
61
+ "https://huggingface.co/Jiahuita/NewsSourceClassification/predict",
62
  json={"text": "Your news headline here"}
63
  )
64
+ print(response.json())
65
+ ```
66
+
67
+ Or use it locally:
68
+
69
+ ```python
70
+ from transformers import pipeline
71
+
72
+ classifier = pipeline("text-classification", model="Jiahuita/NewsSourceClassification")
73
+ result = classifier("Your news headline here")
74
+ print(result)
75
+ ```
76
+
77
+ Example response:
78
+ ```json
79
+ {
80
+ "label": "foxnews",
81
+ "score": 0.875
82
+ }
83
+ ```
84
+
85
+ ## Limitations and Bias
86
+
87
+ This model has been trained on news headlines from specific sources and time periods, which may introduce certain biases. Users should be aware of these limitations when using the model.
88
+
89
+ ## Training
90
+
91
+ The model was trained using:
92
+ - TensorFlow 2.13.0
93
+ - LSTM architecture
94
+ - Binary cross-entropy loss
95
+ - Adam optimizer
96
+
97
+ ## License
98
+ This project is licensed under the MIT License.
app.py CHANGED
@@ -1,15 +1,84 @@
1
- from transformers import pipeline
2
- from fastapi import FastAPI
3
  from pydantic import BaseModel
4
-
5
- app = FastAPI()
 
 
 
6
 
7
  class TextInput(BaseModel):
8
  text: str
9
 
10
- classifier = pipeline("text-classification", model="./")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  @app.post("/predict")
13
  async def predict(input_data: TextInput):
14
- result = classifier(input_data.text)
15
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
 
2
  from pydantic import BaseModel
3
+ from transformers import Pipeline
4
+ import tensorflow as tf
5
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
6
+ import json
7
+ import os
8
 
9
  class TextInput(BaseModel):
10
  text: str
11
 
12
+ app = FastAPI(
13
+ title="News Source Classifier",
14
+ description="A model to classify news headlines as either Fox News or NBC News",
15
+ version="1.0.0"
16
+ )
17
+
18
+ class NewsClassificationPipeline(Pipeline):
19
+ def __init__(self):
20
+ super().__init__()
21
+ model_path = os.path.join(os.path.dirname(__file__), 'news_classifier.h5')
22
+ self.model = tf.keras.models.load_model(model_path)
23
+
24
+ tokenizer_path = os.path.join(os.path.dirname(__file__), 'tokenizer.json')
25
+ with open(tokenizer_path, 'r') as f:
26
+ tokenizer_data = json.load(f)
27
+ self.tokenizer = tf.keras.preprocessing.text.tokenizer_from_json(tokenizer_data)
28
+
29
+ def __call__(self, text):
30
+ if isinstance(text, str):
31
+ text = [text]
32
+
33
+ sequences = self.tokenizer.texts_to_sequences(text)
34
+ padded = pad_sequences(sequences, maxlen=128)
35
+
36
+ predictions = self.model.predict(padded)
37
+
38
+ results = []
39
+ for pred in predictions:
40
+ label = "foxnews" if pred[0] > 0.5 else "nbc"
41
+ score = float(pred[0] if label == "foxnews" else 1 - pred[0])
42
+ results.append({"label": label, "score": score})
43
+
44
+ return results[0] if len(results) == 1 else results
45
+
46
+ try:
47
+ classifier = NewsClassificationPipeline()
48
+ except Exception as e:
49
+ print(f"Error initializing model: {str(e)}")
50
+ raise
51
+
52
+ @app.get("/")
53
+ async def root():
54
+ return {
55
+ "message": "News Source Classification API",
56
+ "usage": "Send POST request to /predict with {'text': 'your news headline'}"
57
+ }
58
 
59
  @app.post("/predict")
60
  async def predict(input_data: TextInput):
61
+ try:
62
+ result = classifier(input_data.text)
63
+ return result
64
+ except Exception as e:
65
+ raise HTTPException(status_code=500, detail=str(e))
66
+
67
+ @app.get("/examples")
68
+ async def examples():
69
+ return {
70
+ "examples": [
71
+ {
72
+ "title": "Crime News Headline",
73
+ "text": "Wife of murdered Minnesota pastor hired 3 men to kill husband after affair: police"
74
+ },
75
+ {
76
+ "title": "Science News Headline",
77
+ "text": "Scientists discover breakthrough in renewable energy research"
78
+ },
79
+ {
80
+ "title": "Political News Headline",
81
+ "text": "Presidential candidates face off in heated debate over climate policies"
82
+ }
83
+ ]
84
+ }