|
|
|
from datasets import load_dataset |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, TextClassificationPipeline |
|
from sklearn.metrics import accuracy_score, precision_recall_fscore_support |
|
import gradio as gr |
|
|
|
|
|
# Load the GonzaloA/fake_news dataset from the Hugging Face Hub.
# The code below uses its 'train' and 'test' splits; each example has a
# 'text' field (consumed by tokenize_function) and — presumably — a 0/1
# 'label' column matching label_mapping further down (TODO confirm schema).
ds = load_dataset("GonzaloA/fake_news")

# BERT-base (uncased) tokenizer; the matching model checkpoint is loaded
# later with AutoModelForSequenceClassification.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
|
|
|
|
|
def tokenize_function(examples, tok=None, max_length=128):
    """Tokenize a batch of examples for sequence classification.

    Backward-compatible generalization of the original single-argument
    form used by ``ds.map(tokenize_function, batched=True)``: the new
    parameters have defaults that reproduce the original behavior.

    Args:
        examples: Batch dict with a ``'text'`` list of strings.
        tok: Optional tokenizer callable; defaults to the module-level
            ``tokenizer`` so existing callers are unaffected.
        max_length: Pad/truncate every sequence to exactly this length.
            ``padding='max_length'`` keeps all rows the same size, so no
            dynamic-padding data collator is needed.

    Returns:
        The tokenizer's encoding dict (input_ids, attention_mask, ...).
    """
    if tok is None:
        tok = tokenizer  # fall back to the module-level BERT tokenizer
    return tok(examples['text'], padding='max_length', truncation=True, max_length=max_length)
|
|
|
|
|
# Tokenize every split in one pass; batched=True hands the tokenizer lists
# of texts, which is much faster than per-example calls.
tokenized_datasets = ds.map(tokenize_function, batched=True)

# BERT-base encoder with a freshly-initialized 2-way classification head
# (indices decoded as fake/true by label_mapping further down).
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

training_args = TrainingArguments(
    output_dir='./results',  # checkpoints are written here
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Run an evaluation pass at the end of every epoch.
    # NOTE(review): this argument was renamed to `eval_strategy` in recent
    # transformers releases — confirm the installed version accepts this name.
    evaluation_strategy='epoch',
    logging_dir='./logs',
)

trainer = Trainer(
    model=model,
    args=training_args,
    # Subsample 1000 examples per split to keep this demo run short.
    # NOTE(review): shuffle() is unseeded, so the subsets (and therefore the
    # results) differ between runs — pass seed=... for reproducibility.
    train_dataset=tokenized_datasets['train'].shuffle().select(range(1000)),
    eval_dataset=tokenized_datasets['test'].shuffle().select(range(1000)),
)

# Fine-tune. Per-epoch evaluations here report loss only, since no
# compute_metrics has been attached to the trainer yet at this point.
trainer.train()
|
|
|
|
|
def compute_metrics(pred):
    """Compute binary-classification metrics for Trainer.evaluate().

    Args:
        pred: EvalPrediction-like object carrying ``label_ids`` (gold
            labels) and ``predictions`` (per-class logits).

    Returns:
        dict with 'accuracy', 'f1', 'precision' and 'recall' values.
    """
    y_true = pred.label_ids
    # Predicted class = highest-scoring index along the last (class) axis.
    y_pred = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary'
    )
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }
|
|
|
|
|
# Attach the metrics function after training. Trainer reads the
# `compute_metrics` attribute at evaluate() time, so the final evaluation
# below gets full metrics — but the per-epoch evaluations during train()
# above reported loss only, because this was not set yet. Passing
# compute_metrics to the Trainer constructor would cover both.
trainer.compute_metrics = compute_metrics

# Final evaluation on the 1000-example held-out subset.
eval_result = trainer.evaluate()
print(eval_result)

# Persist model and tokenizer into the same directory so it can be
# reloaded as a self-contained checkpoint with from_pretrained().
trainer.save_model('TeamQuad-fine-tuned-bert')
tokenizer.save_pretrained('TeamQuad-fine-tuned-bert')

# Reload from disk — verifies the saved artifacts are usable on their own.
new_model = AutoModelForSequenceClassification.from_pretrained('TeamQuad-fine-tuned-bert')
new_tokenizer = AutoTokenizer.from_pretrained('TeamQuad-fine-tuned-bert')

# Inference pipeline. The model was built with num_labels=2 but no
# id2label mapping, so the pipeline emits generic 'LABEL_0'/'LABEL_1'
# names, which classify_news decodes below.
classifier = TextClassificationPipeline(model=new_model, tokenizer=new_tokenizer)

# Human-readable names for the two class indices.
label_mapping = {0: 'fake', 1: 'true'}
|
|
|
|
|
def classify_news(text, classifier_fn=None):
    """Classify a news text as fake or true and format the result.

    Args:
        text: News headline or article body.
        classifier_fn: Optional pipeline-compatible callable returning
            ``[{'label': 'LABEL_<idx>', 'score': float}]``. Defaults to
            the module-level ``classifier`` pipeline, so existing callers
            (the Gradio interface) are unaffected.

    Returns:
        A string like ``"Label: fake, Score: 0.9876"``.
    """
    if classifier_fn is None:
        classifier_fn = classifier  # module-level TextClassificationPipeline
    result = classifier_fn(text)[0]
    raw_label = result['label']
    score = result['score']
    # The head was created without an id2label mapping, so the pipeline
    # emits generic 'LABEL_<idx>' names. Take the last '_' segment (robust
    # to multi-underscore names) and fall back to the raw label instead of
    # raising KeyError on an unexpected index.
    names = {0: 'fake', 1: 'true'}
    label = names.get(int(raw_label.split('_')[-1]), raw_label)
    return f"Label: {label}, Score: {score:.4f}"
|
|
|
|
|
# Minimal Gradio UI: one multi-line text input, one plain-text output.
iface = gr.Interface(
    fn=classify_news,
    inputs=gr.Textbox(lines=10, placeholder="Enter a news headline or article to classify..."),
    outputs="text",
    title="Fake News Detection",
    description="Enter a news headline or article and see whether the model classifies it as 'Fake News' or 'True News'.",
)

# share=True publishes a temporary public *.gradio.live URL — anyone with
# the link can query the model. Drop it (or pass share=False) for
# local-only use.
iface.launch(share=True)
|
|