FND / app.py
Rifky's picture
Update app.py
4d5820d
raw
history blame
2.4 kB
import streamlit as st
import numpy as np
import pandas as pd
import re
import time
import os
from transformers import AutoModelForSequenceClassification, AutoModel, AutoTokenizer
from Scraper import Scrap
st.set_page_config(layout="wide")
model_checkpoint = "Rifky/FND"
data_checkpoint = "Rifky/turnbackhoax-encoded"
label = {0: "valid", 1: "fake"}
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model():
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, fast=True)
return model, tokenizer
def sigmoid(x):
return 1 / (1 + np.exp(-x))
input_column, reference_column = st.columns(2)
input_column.write('# Fake News Detection AI')
with st.spinner("Loading Model..."):
model, tokenizer = load_model()
user_input = input_column.text_input("Article url")
submit = input_column.button("submit")
if submit:
last_time = time.time()
with st.spinner("Reading Article..."):
if user_input:
if user_input[:4] == 'http':
text = Scrap(user_input)
else:
text = user_input
if text:
text = re.sub(r'\n', ' ', text)
with st.spinner("Computing..."):
token = text.split()
text_len = len(token)
sequences = []
for i in range(text_len // 512):
sequences.append(" ".join(token[i * 512: (i + 1) * 512]))
sequences.append(" ".join(token[text_len - (text_len % 512) : text_len]))
sequences = tokenizer(sequences, max_length=512, truncation=True, padding="max_length", return_tensors='pt')
predictions = model(**sequences)[0].detach().numpy()
result = [
np.sum([sigmoid(i[0]) for i in predictions]) / len(predictions),
np.sum([sigmoid(i[1]) for i in predictions]) / len(predictions)
]
print (f'\nresult: {result}')
input_column.markdown(f"<small>Compute Finished in {int(time.time() - last_time)} seconds</small>", unsafe_allow_html=True)
prediction = np.argmax(result, axis=-1)
input_column.success(f"This news is {label[prediction]}.")
st.text(f"{int(result[prediction]*100)}% confidence")
input_column.progress(result[prediction])