munzirmuneer commited on
Commit
4047a9c
1 Parent(s): 991a833

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +26 -0
inference.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from peft import PeftModel
5
+
6
+ # Load model and tokenizer
7
+ model_name = "munzirmuneer/phishing_url_gemma_pytorch" # Replace with your specific model
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
10
+ model = PeftModel.from_pretrained(model, model_name)
11
+
12
+ def predict(input_text):
13
+ # Tokenize input
14
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
15
+
16
+ # Run inference
17
+ with torch.no_grad():
18
+ outputs = model(**inputs)
19
+
20
+ # Get logits and probabilities
21
+ logits = outputs.logits
22
+ probs = F.softmax(logits, dim=-1)
23
+
24
+ # Get the predicted class (highest probability)
25
+ pred_class = torch.argmax(probs, dim=-1)
26
+ return pred_class.item(), probs[0].tolist()