# Author: aryadytm
# Commit 2e5179b: "Add application file"
from collections import Counter

import gradio as gr
import torch
import torch.nn.functional as F
from gradio.components import Textbox
from transformers import (
    AlbertTokenizerFast, DistilBertTokenizerFast, RobertaTokenizerFast,
    AlbertForSequenceClassification, DistilBertForSequenceClassification, RobertaForSequenceClassification
)
# Ensemble members, loaded once at import time.
# Tokenizers are fetched by checkpoint name; each tokenizer must match the
# architecture of the classifier it feeds below.
albert_tokenizer, bert_tokenizer, roberta_tokenizer = (
    AlbertTokenizerFast.from_pretrained('albert-base-v2'),
    DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased'),
    RobertaTokenizerFast.from_pretrained('distilroberta-base'),
)
# Sequence-classification models are read from local directories
# (presumably fine-tuned checkpoints produced for this app — confirm).
albert_model, bert_model, roberta_model = (
    AlbertForSequenceClassification.from_pretrained("./model_albert"),
    DistilBertForSequenceClassification.from_pretrained("./model_bert"),
    RobertaForSequenceClassification.from_pretrained("./model_roberta"),
)
def majority_voting(votes):
    """Return the most frequent label in *votes*.

    Ties are broken in favour of the label that appears first in *votes*,
    so the result is deterministic and independent of hash ordering.

    Args:
        votes: A non-empty iterable of hashable labels (here, class ids).

    Returns:
        The label with the highest count.

    Raises:
        ValueError: If *votes* is empty.
    """
    # Counter keeps first-insertion order and most_common() orders equal
    # counts by first occurrence, which gives a stable tie-break.
    ranked = Counter(votes).most_common(1)
    if not ranked:
        raise ValueError("majority_voting() requires at least one vote")
    return ranked[0][0]
def predict_news(news_title, news_text):
    """Classify a news article as REAL or FAKE with a three-model ensemble.

    The title and body are joined into one string and scored independently by
    the ALBERT, DistilBERT and DistilRoBERTa sequence classifiers loaded at
    module level. The final label is the majority vote of the three per-model
    argmax predictions.

    Args:
        news_title: Headline text from the "News Title" box.
        news_text: Article body from the "News Text" box.

    Returns:
        A 5-tuple of strings for the Gradio output boxes:
        (prediction, average confidence, ALBERT confidence, BERT confidence,
        RoBERTa confidence), each confidence formatted with two decimals.
        Each per-model confidence is that model's probability for its own
        predicted class. The average confidence is the mean probability the
        three models assign to the final (majority) class.
    """
    combined_text = f"{news_title} - {news_text}"

    # One (tokenizer, model) pair per ensemble member, in output order.
    members = (
        (albert_tokenizer, albert_model),
        (bert_tokenizer, bert_model),
        (roberta_tokenizer, roberta_model),
    )

    member_probs = []
    with torch.no_grad():
        for tokenizer, model in members:
            encoded = tokenizer(combined_text, return_tensors='pt', truncation=True, padding=True)
            logits = model(**encoded).logits
            # A single text is encoded, so keep only row 0 of the batch.
            member_probs.append(F.softmax(logits, dim=-1)[0])

    pred_classes = [probs.argmax(-1).item() for probs in member_probs]
    pred_probs = [probs[cls].item() for probs, cls in zip(member_probs, pred_classes)]

    final_pred_class = majority_voting(pred_classes)
    # Average each model's probability for the *final* class. Averaging
    # pred_probs instead would count a dissenting model's confidence in the
    # opposite label as support for the majority label.
    average_confidence = sum(probs[final_pred_class].item() for probs in member_probs) / len(member_probs)

    # Class id 1 is reported as REAL, anything else as FAKE
    # (assumes this matches the checkpoints' label mapping — TODO confirm).
    prediction = 'REAL' if final_pred_class == 1 else 'FAKE'
    albert_pred_prob, bert_pred_prob, roberta_pred_prob = pred_probs
    return prediction, f"{average_confidence:.2f}", f"{albert_pred_prob:.2f}", f"{bert_pred_prob:.2f}", f"{roberta_pred_prob:.2f}"
# Gradio UI: two free-text inputs (title, body) feed predict_news. The five
# output boxes receive its return tuple, in the same order.
news_inputs = [
    Textbox(lines=2, label="News Title", placeholder="Enter News Title Here..."),
    Textbox(lines=7, label="News Text", placeholder="Enter News Text Here..."),
]
result_outputs = [
    Textbox(label=output_label)
    for output_label in (
        "Prediction",
        "Average Confidence",
        "ALBERT Confidence",
        "BERT Confidence",
        "RoBERTa Confidence",
    )
]
iface = gr.Interface(
    fn=predict_news,
    inputs=news_inputs,
    outputs=result_outputs,
    title="Fake News Classification with Ensemble Learning Transformer Models",
    description="UAS Deep Learning - 2501985836 - Arya Adyatma",
)
iface.launch()