Rifky commited on
Commit
ef01f5b
1 Parent(s): 38b86bb

cleaner code

Browse files
Files changed (1) hide show
  1. app.py +26 -37
app.py CHANGED
@@ -6,10 +6,10 @@ import time
6
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer
7
  from Scraper import Scrap
8
 
 
9
 
10
  model_checkpoint = "Rifky/FND"
11
- label = {0: "Fakta", 1: "Hoax"}
12
-
13
 
14
  @st.cache(show_spinner=False, allow_output_mutation=True)
15
  def load_model():
@@ -17,23 +17,21 @@ def load_model():
17
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, fast=True)
18
  return Trainer(model=model), tokenizer
19
 
 
 
20
 
21
- st.write('# Fake News Detection AI')
 
22
 
23
  with st.spinner("Loading Model..."):
24
  model, tokenizer = load_model()
25
 
26
- user_input = st.text_area("Put article url or the full text", help="the text you want to analyze", height=200)
27
- submit = st.button("submit")
28
 
29
- def sigmoid(x):
30
- return 1 / (1 + np.exp(-x))
31
 
32
  if submit:
33
  last_time = time.time()
34
-
35
- text = ""
36
-
37
  with st.spinner("Reading Article..."):
38
  if user_input:
39
  if user_input[:4] == 'http':
@@ -45,33 +43,24 @@ if submit:
45
  text = re.sub(r'\n', ' ', text)
46
 
47
  with st.spinner("Computing..."):
48
- text_len = len(text.split(" "))
49
- if text_len > 512:
50
- texts = []
51
- for i in range(text_len // 512):
52
- texts.append(" ".join(text.split(" ")[i * 512:(i + 1) * 512]))
53
-
54
- texts.append(" ".join(text.split(" ")[(text_len // 512) + 1:text_len % 512]))
55
-
56
- for i in range(len(texts)):
57
- texts[i] = tokenizer(texts[i], max_length=512, truncation=True, padding="max_length")
58
-
59
- results = model.predict(texts)[0]
60
- result = [0, 0]
61
- for i in range(len(results)):
62
- result[0] += sigmoid(results[i][0])
63
- result[1] += sigmoid(results[i][1])
64
-
65
- result[0] /= len(results)
66
- result[1] /= len(results)
67
-
68
- else:
69
- text = tokenizer(text, max_length=512, truncation=True, padding="max_length")
70
- result = model.predict([text])[0][0]
71
 
72
  print (f'\nresult: {result}')
73
-
74
- st.markdown(f"<small>Compute Finished in {int(time.time() - last_time)} seconds</small>", unsafe_allow_html=True)
75
-
76
  prediction = np.argmax(result, axis=-1)
77
- st.success(f"Prediction: {label[prediction]}")
 
 
 
6
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer
7
  from Scraper import Scrap
8
 
9
+ st.set_page_config(layout="wide")
10
 
11
  model_checkpoint = "Rifky/FND"
12
+ label = {0: "valid", 1: "fake"}
 
13
 
14
  @st.cache(show_spinner=False, allow_output_mutation=True)
15
  def load_model():
 
17
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, fast=True)
18
  return Trainer(model=model), tokenizer
19
 
20
+ def sigmoid(x):
21
+ return 1 / (1 + np.exp(-x))
22
 
23
+ input_column, reference_column = st.columns(2, gap="medium")
24
+ input_column.write('# Fake News Detection AI')
25
 
26
  with st.spinner("Loading Model..."):
27
  model, tokenizer = load_model()
28
 
29
+ user_input = input_column.text_input("Article url")
30
+ submit = input_column.button("submit")
31
 
 
 
32
 
33
  if submit:
34
  last_time = time.time()
 
 
 
35
  with st.spinner("Reading Article..."):
36
  if user_input:
37
  if user_input[:4] == 'http':
 
43
  text = re.sub(r'\n', ' ', text)
44
 
45
  with st.spinner("Computing..."):
46
+ text = text.split()
47
+ text_len = len(text)
48
+
49
+ sequences = []
50
+ for i in range(text_len // 512):
51
+ sequences.append(" ".join(text[i * 512: (i + 1) * 512]))
52
+ sequences.append(" ".join(text[text_len - (text_len % 512) : text_len]))
53
+ sequences = [tokenizer(i, max_length=512, truncation=True, padding="max_length") for i in sequences]
54
+
55
+ predictions = model.predict(sequences)[0]
56
+ result = [
57
+ np.sum([sigmoid(i[0]) for i in predictions]) / len(predictions),
58
+ np.sum([sigmoid(i[1]) for i in predictions]) / len(predictions)
59
+ ]
 
 
 
 
 
 
 
 
 
60
 
61
  print (f'\nresult: {result}')
62
+ input_column.markdown(f"<small>Compute Finished in {int(time.time() - last_time)} seconds</small>", unsafe_allow_html=True)
 
 
63
  prediction = np.argmax(result, axis=-1)
64
+ input_column.success(f"This news is {label[prediction]}.")
65
+ st.text(f"{int(result[prediction]*100)}% confidence")
66
+ input_column.progress(result[prediction])