import numpy as np
import streamlit as st
import torch
from transformers import AutoTokenizer, BertForSequenceClassification
st.set_page_config(layout='wide', initial_sidebar_state='expanded')

# Two-column layout: title on the left, text input and button on the right
col1, col2 = st.columns(2)
with col2:
    text = st.text_input("Enter the text you'd like to analyze for spam.")
    aButton = st.button('Analyze')
with col1:
    st.title("Spam Detector")
# Turkish BERT tokenizer and the fine-tuned spam-classification model
tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-turkish-uncased")
model = BertForSequenceClassification.from_pretrained("NimaKL/spamd_model")
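# The classifier returns two logits per input; index 1 is treated as 'Spam' in predict() below.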
def preprocessing(input_text, tokenizer):
    '''
    Returns <class transformers.tokenization_utils_base.BatchEncoding> with the following fields:
      - input_ids: list of token ids
      - token_type_ids: list of token type ids
      - attention_mask: list of indices (0, 1) specifying which tokens should be considered by the model (return_attention_mask=True).
    '''
    return tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        max_length=32,
        padding='max_length',      # replaces the deprecated pad_to_max_length=True
        truncation=True,           # cut longer inputs down to max_length
        return_attention_mask=True,
        return_tensors='pt'
    )
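# Note: with max_length=32, padding='max_length', and return_tensors='pt', the returned
# 'input_ids' and 'attention_mask' are each a torch tensor of shape (1, 32) for a single
# input string, which is what predict() below concatenates and feeds to the model.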
device = 'cpu'

def predict(new_sentence):
    # We need token IDs and an attention mask for inference on the new sentence
    test_ids = []
    test_attention_mask = []
    # Apply the tokenizer
    encoding = preprocessing(new_sentence, tokenizer)
    # Extract IDs and attention mask
    test_ids.append(encoding['input_ids'])
    test_attention_mask.append(encoding['attention_mask'])
    test_ids = torch.cat(test_ids, dim=0)
    test_attention_mask = torch.cat(test_attention_mask, dim=0)
    # Forward pass, calculate logit predictions
    with torch.no_grad():
        output = model(test_ids.to(device),
                       token_type_ids=None,
                       attention_mask=test_attention_mask.to(device))
    prediction = 'Spam' if np.argmax(output.logits.cpu().numpy(), axis=1).item() == 1 else 'Normal'
    return 'Predicted Class: ' + prediction
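# Example usage (hypothetical input; the returned label depends entirely on the model weights):
#     predict("some message text")  # -> 'Predicted Class: Spam' or 'Predicted Class: Normal'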
if text or aButton:
    with col2:
        with st.spinner('Wait for it...'):
            st.success(predict(text))
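# To try the app locally (assuming this file is saved as app.py and streamlit, torch,
# numpy, and transformers are installed):
#     streamlit run app.py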