kmack commited on
Commit
43d7102
1 Parent(s): 9e6277b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -2
app.py CHANGED
@@ -1,3 +1,51 @@
1
- import gradio as gr
 
 
 
2
 
3
- gr.load("models/kmack/malicious-url-detection").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Load model directly
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import numpy as np
4
+ import torch
5
 
6
+ # Check if CUDA is available
7
+ if torch.cuda.is_available():
8
+ # Choose a specific GPU or use the default
9
+ device = torch.device("cuda:0")
10
+ else:
11
+ # Or CPU
12
+ device = torch.device("cpu")
13
+
14
+ tokenizer = AutoTokenizer.from_pretrained("kmack/malicious-url-detection")
15
+ model = AutoModelForSequenceClassification.from_pretrained("kmack/malicious-url-detection")
16
+
17
+ # set Model to cude
18
+ model = model.to(device)
19
+
20
+ # predict function
21
+ def get_predit(input_text: str) -> dict:
22
+ label2id = model.config.label2id
23
+ inputs = tokenizer(input_text, return_tensors='pt', truncation=True)
24
+ inputs = inputs.to(device)
25
+ outputs = model(**inputs)
26
+ logits = outputs.logits
27
+ sigmoid = torch.nn.Sigmoid()
28
+ probs = sigmoid(logits.squeeze().cpu())
29
+ probs = probs.detach().numpy()
30
+ for i, k in enumerate(label2id.keys()):
31
+ label2id[k] = probs[i]
32
+ label2id = {k: float(v) for k, v in sorted(label2id.items(), key=lambda item: item[1].item(), reverse=True)}
33
+ return label2id
34
+
35
+ # Define example URLs
36
+ example_url_1 = 'https://medium.com'
37
+ example_url_2 = 'http://google.com-redirect@valimail.com'
38
+ example_url_3 = 'https://a101-nisan-kampanyalari.com'
39
+
40
+ # Create the Gradio interface
41
+ demo = gr.Interface(
42
+ fn=get_predit,
43
+ inputs=gr.components.Textbox(label='Input', placeholder='Enter URL here...'),
44
+ outputs=gr.components.Label(label='Predictions', num_top_classes=5),
45
+ title='kmack/malicious-url-detection',
46
+ description='Detects whether a given URL is benign or potentially malicious.',
47
+ examples=[[example_url_1], [example_url_2], [example_url_3]],
48
+ allow_flagging='never'
49
+ )
50
+
51
+ demo.launch()