bishalshrestha commited on
Commit
b05b1d8
1 Parent(s): 78c82e0

Initial commit

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +92 -0
  3. requirements.txt +2 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ venv
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Importing necessary libraries
2
+ from datasets import load_dataset
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, TextClassificationPipeline
4
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
5
+ import gradio as gr
6
+
7
+ # Load the dataset
8
+ ds = load_dataset("GonzaloA/fake_news")
9
+
10
+ # Load pre-trained tokenizer
11
+ tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
12
+
13
+ # Define tokenization function
14
+ def tokenize_function(examples):
15
+ return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=128)
16
+
17
+ # Apply tokenization
18
+ tokenized_datasets = ds.map(tokenize_function, batched=True)
19
+
20
+ # Load pre-trained BERT model for sequence classification
21
+ model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
22
+
23
+ # Define training arguments
24
+ training_args = TrainingArguments(
25
+ output_dir='./results',
26
+ num_train_epochs=3,
27
+ per_device_train_batch_size=8,
28
+ per_device_eval_batch_size=8,
29
+ evaluation_strategy='epoch',
30
+ logging_dir='./logs',
31
+ )
32
+
33
+ # Create trainer instance
34
+ trainer = Trainer(
35
+ model=model,
36
+ args=training_args,
37
+ train_dataset=tokenized_datasets['train'].shuffle().select(range(1000)),
38
+ eval_dataset=tokenized_datasets['test'].shuffle().select(range(1000)),
39
+ )
40
+
41
+ # Start training
42
+ trainer.train()
43
+
44
+ # Define function to compute metrics
45
+ def compute_metrics(pred):
46
+ labels = pred.label_ids
47
+ preds = pred.predictions.argmax(-1)
48
+ precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
49
+ acc = accuracy_score(labels, preds)
50
+ return {'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall}
51
+
52
+ # Update trainer to include custom metrics
53
+ trainer.compute_metrics = compute_metrics
54
+
55
+ # Evaluate the model
56
+ eval_result = trainer.evaluate()
57
+ print(eval_result)
58
+
59
+ # Save the fine-tuned model and tokenizer
60
+ trainer.save_model('TeamQuad-fine-tuned-bert')
61
+ tokenizer.save_pretrained('TeamQuad-fine-tuned-bert')
62
+
63
+ # Load the fine-tuned model and tokenizer
64
+ new_model = AutoModelForSequenceClassification.from_pretrained('TeamQuad-fine-tuned-bert')
65
+ new_tokenizer = AutoTokenizer.from_pretrained('TeamQuad-fine-tuned-bert')
66
+
67
+ # Create a classification pipeline
68
+ classifier = TextClassificationPipeline(model=new_model, tokenizer=new_tokenizer)
69
+
70
+ # Add label mapping for fake news detection (assuming LABEL_0 = 'fake' and LABEL_1 = 'true')
71
+ label_mapping = {0: 'fake', 1: 'true'}
72
+
73
+ # Function to classify input text
74
+ def classify_news(text):
75
+ result = classifier(text)
76
+ # Extract the label and score
77
+ label = result[0]['label'] # 'LABEL_0' or 'LABEL_1'
78
+ score = result[0]['score'] # Confidence score
79
+ mapped_result = {'label': label_mapping[int(label.split('_')[1])], 'score': score}
80
+ return f"Label: {mapped_result['label']}, Score: {mapped_result['score']:.4f}"
81
+
82
+ # Create a Gradio interface
83
+ iface = gr.Interface(
84
+ fn=classify_news, # The function to process the input
85
+ inputs=gr.Textbox(lines=10, placeholder="Enter a news headline or article to classify..."),
86
+ outputs="text", # Output will be displayed as text
87
+ title="Fake News Detection",
88
+ description="Enter a news headline or article and see whether the model classifies it as 'Fake News' or 'True News'.",
89
+ )
90
+
91
+ # Launch the interface
92
+ iface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ torch