bgspaditya commited on
Commit
c84686e
1 Parent(s): 2d2eb5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -2,19 +2,25 @@ import gradio as gr
2
  from transformers import pipeline, set_seed
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
 
5
  set_seed(42)
6
- num_labels=2
7
- id2label = {0:'benign',1:'phishing'}
8
- label2id = {'benign':0,'phishing':1}
9
  checkpoint = 'bgspaditya/distilbert-phish'
 
 
10
  tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True, force_download=True)
11
  model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels, id2label=id2label, label2id=label2id, force_download=True)
12
 
 
13
  def predict(url):
14
  url_classifier = pipeline(task='text-classification', model=model, tokenizer=tokenizer)
15
  result = url_classifier(url)
16
- return {'label': result[0]['label'], 'score': result[0]['score']}
 
17
 
 
18
  gradio_app = gr.Interface(
19
  predict,
20
  inputs=gr.Textbox(label="Enter URL"),
@@ -22,5 +28,6 @@ gradio_app = gr.Interface(
22
  title="Phishing URL Detection",
23
  )
24
 
 
25
  if __name__ == "__main__":
26
  gradio_app.launch()
 
2
  from transformers import pipeline, set_seed
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
+ # Set seed and define model parameters
6
  set_seed(42)
7
+ num_labels = 2
8
+ id2label = {0: 'benign', 1: 'phishing'}
9
+ label2id = {'benign': 0, 'phishing': 1}
10
  checkpoint = 'bgspaditya/distilbert-phish'
11
+
12
+ # Load tokenizer and model
13
  tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=True, force_download=True)
14
  model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels, id2label=id2label, label2id=label2id, force_download=True)
15
 
16
+ # Define predict function
17
  def predict(url):
18
  url_classifier = pipeline(task='text-classification', model=model, tokenizer=tokenizer)
19
  result = url_classifier(url)
20
+ predicted_label = result[0]['label']
21
+ return predicted_label
22
 
23
+ # Define Gradio interface
24
  gradio_app = gr.Interface(
25
  predict,
26
  inputs=gr.Textbox(label="Enter URL"),
 
28
  title="Phishing URL Detection",
29
  )
30
 
31
+ # Launch the Gradio interface
32
  if __name__ == "__main__":
33
  gradio_app.launch()