# spamd / app.py
# Author: NimaKL (Hugging Face Space "spamd")
# Commit: d307ad4 ("Update app.py"), file size 2.82 kB
import streamlit as st
from transformers import pipeline
from textblob import TextBlob
from transformers import BertForSequenceClassification, AdamW, BertConfig
# Kept as the first Streamlit call: older Streamlit versions require
# set_page_config to run before any other st.* command.
st.set_page_config(layout='wide', initial_sidebar_state='expanded')

# Two-column layout: column 1 describes the tool, column 2 holds the
# input widgets (and, further down, the prediction result).
col1, col2 = st.columns(2)

with col2:
    # `text` and `aButton` are read later in the script to decide whether to run a prediction.
    text = st.text_input("Enter the text you'd like to analyze for spam.")
    aButton = st.button('Analyze')

with col1:
    st.title("Spamd: Turkish Spam Detector")
    st.markdown(
        "Message spam detection tool for the Turkish language. Due to the small size "
        "of the dataset, I decided to go with transformers technology Google BERT. "
        "Using the Turkish pre-trained model BERTurk, I improved the accuracy of the "
        "tool by 18 percent compared to the previous model which used fastText."
    )
    st.markdown("Original file is located at")
    st.markdown("https://colab.research.google.com/drive/1QuorqAuLsmomesZHsaQHEZgzbPEM8YTH")
import torch
import numpy as np
from transformers import AutoTokenizer
from transformers import AutoModel


@st.cache_resource
def _load_tokenizer_and_model():
    """Load the BERTurk tokenizer and the fine-tuned spam classifier.

    Streamlit re-executes this whole script on every widget interaction;
    caching the loaded objects with ``st.cache_resource`` means reruns reuse
    them instead of reloading the model weights each time.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained("dbmdz/bert-base-turkish-uncased")
    mdl = BertForSequenceClassification.from_pretrained("NimaKL/spamd_model")
    # Inference-only app: make sure dropout etc. are disabled.
    mdl.eval()
    return tok, mdl


tokenizer, model = _load_tokenizer_and_model()

# NOTE(review): these two module-level lists are not used anywhere in this
# script; kept only in case something external imports them.
token_id = []
attention_masks = []
def preprocessing(input_text, tokenizer, max_length=32):
    """Encode ``input_text`` into fixed-length PyTorch tensors for the classifier.

    Args:
        input_text: The raw message to encode.
        tokenizer: A Hugging Face tokenizer exposing ``encode_plus``.
        max_length: Number of tokens to pad/truncate to. Defaults to 32,
            the length this app has always used.

    Returns:
        A ``transformers.tokenization_utils_base.BatchEncoding`` with:
        - input_ids: token ids
        - token_type_ids: token type ids
        - attention_mask: 1 for tokens the model should attend to, 0 for padding
        each returned as a PyTorch tensor (``return_tensors='pt'``).
    """
    return tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        max_length=max_length,
        # `pad_to_max_length` is deprecated; `padding='max_length'` is the
        # supported equivalent. Truncation is made explicit (it was applied
        # implicitly, with a warning, when only max_length was given).
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
device = 'cpu'


def predict(new_sentence):
    """Classify ``new_sentence`` with the spam model and return a display label.

    Args:
        new_sentence: The raw message text to classify.

    Returns:
        ``'Predicted Class: Spam'`` when the predicted class index is 1,
        otherwise ``'Predicted Class: Normal'``.
    """
    # Tokenize a single sentence; the encoding already holds a batch of one,
    # so no list-building or concatenation is needed.
    encoding = preprocessing(new_sentence, tokenizer)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        output = model(input_ids, token_type_ids=None, attention_mask=attention_mask)

    # Argmax over the class dimension (not over the flattened array), then
    # take the single row's result.
    predicted_class = output.logits.argmax(dim=-1)[0].item()
    prediction = 'Spam' if predicted_class == 1 else 'Normal'
    return 'Predicted Class: ' + prediction
# Run when the user submits text (Enter in the input) or clicks "Analyze".
if text or aButton:
    with col2:
        if not text.strip():
            # Avoid classifying an empty/whitespace-only message when the
            # button is pressed with no input.
            st.warning("Please enter some text to analyze.")
        else:
            with st.spinner('Wait for it...'):
                st.success(predict(text))