khizon commited on
Commit
f640140
·
1 Parent(s): f5669c3

classifier demo

Browse files
Files changed (4) hide show
  1. .gitignore +134 -0
  2. README.md +4 -4
  3. app.py +121 -0
  4. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ data
132
+ artifacts/
133
+ wandb/
134
+ results
README.md CHANGED
@@ -1,11 +1,11 @@
1
  ---
2
- title: UnreliableNews
3
- emoji: 📉
4
- colorFrom: yellow
5
  colorTo: pink
6
  sdk: streamlit
7
  app_file: app.py
8
- pinned: false
9
  ---
10
 
11
  # Configuration
 
1
  ---
2
+ title: Unreliable News Classifier
3
+ emoji: 📰
4
+ colorFrom: red
5
  colorTo: pink
6
  sdk: streamlit
7
  app_file: app.py
8
+ pinned: true
9
  ---
10
 
11
  # Configuration
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ import pandas as pd
5
+
6
+ import streamlit as st
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
10
+
11
+ @st.cache(allow_output_mutation=True)
12
+ def init_model():
13
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
14
+ model = DistilBertForSequenceClassification.from_pretrained('khizon/distilbert-unreliable-news-eng-4L', num_labels = 2)
15
+
16
+ return tokenizer, model
17
+
18
+ def download_dataset():
19
+ url = 'https://drive.google.com/drive/folders/11mRvsHAkggFEJvG4axH4mmWI6FHMQp7X?usp=sharing'
20
+ data = 'data/nela_gt_2018_site_split'
21
+
22
+ os.system(f'gdown --folder {url} -O {data}')
23
+
24
+ @st.cache(allow_output_mutation=True)
25
+ def jsonl_to_df(file_path):
26
+ with open(file_path) as f:
27
+ lines = f.read().splitlines()
28
+
29
+ df_inter = pd.DataFrame(lines)
30
+ df_inter.columns = ['json_element']
31
+
32
+ df_inter['json_element'].apply(json.loads)
33
+
34
+ return pd.json_normalize(df_inter['json_element'].apply(json.loads))
35
+
36
+ @st.cache
37
+ def load_test_df():
38
+ file_path = os.path.join('data', 'nela_gt_2018_site_split', 'test.jsonl')
39
+ test_df = jsonl_to_df(file_path)
40
+ test_df = pd.get_dummies(test_df, columns = ['label'])
41
+ return test_df
42
+
43
+ @st.cache(allow_output_mutation=True)
44
+ def predict(model, tokenizer, data):
45
+
46
+ labels = data[['label_0', 'label_1']]
47
+ labels = torch.tensor(labels, dtype=torch.float32)
48
+ encoding = tokenizer.encode_plus(
49
+ data['title'],
50
+ ' [SEP] ' + data['content'],
51
+ add_special_tokens=True,
52
+ max_length = 512,
53
+ return_token_type_ids = False,
54
+ padding = 'max_length',
55
+ truncation = 'only_second',
56
+ return_attention_mask = True,
57
+ return_tensors = 'pt'
58
+ )
59
+
60
+ output = model(**encoding)
61
+ return correct_preds(output['logits'], labels)
62
+
63
+ @st.cache(allow_output_mutation=True)
64
+ def predict_new(model, tokenizer, title, content):
65
+ encoding = tokenizer.encode_plus(
66
+ title,
67
+ ' [SEP] ' + content,
68
+ add_special_tokens=True,
69
+ max_length = 512,
70
+ return_token_type_ids = False,
71
+ padding = 'max_length',
72
+ truncation = 'only_second',
73
+ return_attention_mask = True,
74
+ return_tensors = 'pt'
75
+ )
76
+ output = model(**encoding)
77
+ preds = F.softmax(output['logits'], dim = 1)
78
+ p_idx = torch.argmax(preds, dim = 1)
79
+ return 'reliable' if p_idx > 0 else 'unreliable'
80
+
81
+ def correct_preds(preds, labels):
82
+ preds = torch.nn.functional.softmax(preds, dim = 1)
83
+ p_idx = torch.argmax(preds, dim=1)
84
+ l_idx = torch.argmax(labels, dim=0)
85
+
86
+ pred_label = 'reliable' if p_idx > 0 else 'unreliable'
87
+ correct = True if (p_idx == l_idx).sum().item() > 0 else False
88
+ return pred_label, correct
89
+
90
+
91
+ if __name__ == '__main__':
92
+ if not os.path.exists('data/nela_gt_2018_site_split/test.jsonl'):
93
+ download_dataset()
94
+ df = load_test_df()
95
+ tokenizer, model = init_model()
96
+
97
+ st.title("Unreliable News classifier")
98
+ mode = st.radio(
99
+ '', ('Test article', 'Input own article')
100
+ )
101
+ if mode == 'Test article':
102
+ if st.button('Get random article'):
103
+ idx = np.random.randint(0, len(df))
104
+ sample = df.iloc[idx]
105
+
106
+ prediction, correct = predict(model, tokenizer, sample)
107
+ label = 'reliable' if sample['label_1'] > sample['label_0'] else 'unreliable'
108
+ st.header(sample['title'])
109
+ if correct:
110
+ st.success(f'Prediction: {prediction}')
111
+ else:
112
+ st.error(f'Prediction: {prediction}')
113
+ st.caption(f'Source: {sample["source"]} ({label})')
114
+ st.markdown(sample['content'])
115
+ else:
116
+ title = st.text_input('Article title', 'Test title')
117
+ content = st.text_area('Article content', 'Lorem ipsum')
118
+ if st.button('Submit'):
119
+ pred = predict_new(model, tokenizer, title, content)
120
+ st.markdown(f'Prediction: {pred}')
121
+ # st.success('success')
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ -f https://download.pytorch.org/whl/cu113/torch_stable.html
2
+ gdown==4.2.0
3
+ numpy==1.21.4
4
+ pandas==1.3.4
5
+ torch==1.10.1
6
+ transformers==4.13.0