|
|
import os
|
|
|
import torch
|
|
|
from transformers import AutoTokenizer, T5ForSequenceClassification
|
|
|
from typing import Dict, List, Any
|
|
|
|
|
|
class EndpointHandler:
    """HuggingFace Inference Endpoint handler for Java vulnerability detection.

    Serves a CodeT5-based sequence-classification model (LoRA fine-tuned).
    The endpoint runtime calls the instance with each request payload.
    """

    # Token budget for the prompt; longer inputs are truncated by the tokenizer.
    MAX_LENGTH = 512

    # Display names for class indices; any other index falls back to "LABEL_<i>".
    LABEL_MAP = {0: "safe", 1: "vulnerable"}

    # Raised when the payload is not a dict, a non-empty list, or a str.
    _INVALID_FORMAT_MSG = (
        "Invalid request format. Expected {'inputs': 'Java code here'} "
        "or {'code': 'Java code here'}"
    )

    def __init__(self, path="."):
        """Load the tokenizer and the classification model.

        Args:
            path (str): Location of the model and tokenizer files. On
                HuggingFace Inference Endpoints this is set automatically.
        """
        print(f"Loading Java Vulnerability Detection Model from {path}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Half precision on GPU, full precision on CPU.
        self.model = T5ForSequenceClassification.from_pretrained(
            path,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        )
        self.model.to(self.device)
        self.model.eval()  # inference mode: disables dropout

        print("Model loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one Java snippet; entry point invoked by the endpoint runtime.

        Args:
            data (dict): Request payload with the Java source under
                ``"inputs"`` or ``"code"``. A list or a bare string is also
                accepted (see ``preprocess``).

        Returns:
            list: A single-element list holding the prediction dict built by
            ``postprocess``.

        Raises:
            ValueError: If the payload is malformed or contains no code.
        """
        inputs = self.preprocess(data)
        logits = self.inference(inputs)
        return self.postprocess(logits)

    @staticmethod
    def _extract_code(request: Any) -> str:
        """Pull the Java source string out of a request payload.

        Accepted shapes:
            * ``{"inputs": code}`` or ``{"code": code}`` ("inputs" wins when
              both are present and non-empty),
            * a non-empty list whose first item is such a dict or a plain
              string (only the first item is used),
            * a plain string.

        Raises:
            ValueError: For any other shape, when no code is present, when
                the code is blank, or when the extracted value is not a str.
        """
        # Only the first item of a list payload is classified.
        if isinstance(request, list) and request:
            request = request[0]

        if isinstance(request, dict):
            code = request.get("inputs") or request.get("code")
        elif isinstance(request, str):
            code = request
        else:
            raise ValueError(EndpointHandler._INVALID_FORMAT_MSG)

        if not code or (isinstance(code, str) and not code.strip()):
            raise ValueError("No code provided in request")
        if not isinstance(code, str):
            # Reject e.g. batched lists instead of str()-ing them into the prompt.
            raise ValueError(
                f"Expected Java code as a string, got {type(code).__name__}"
            )
        return code

    def preprocess(self, request: Any) -> Dict[str, torch.Tensor]:
        """Turn a request payload into model-ready tensors on ``self.device``.

        Args:
            request: API request payload; see ``_extract_code`` for the
                accepted shapes.

        Returns:
            dict: Tokenizer outputs (e.g. ``input_ids``, ``attention_mask``)
            as tensors moved to ``self.device``.

        Raises:
            ValueError: If the payload is malformed or contains no code.
        """
        code = self._extract_code(request)

        # Presumably the prompt format used during fine-tuning — keep it exact.
        input_text = f"Is this Java code vulnerable?:\n{code}"

        encoded = self.tokenizer(
            input_text,
            max_length=self.MAX_LENGTH,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        return {name: tensor.to(self.device) for name, tensor in encoded.items()}

    def inference(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the model forward pass without tracking gradients.

        Args:
            inputs (dict): Preprocessed input tensors from ``preprocess``.

        Returns:
            torch.Tensor: The model's output logits.
        """
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits

    def postprocess(self, logits: torch.Tensor) -> List[Dict[str, Any]]:
        """Convert model logits into a human-readable prediction.

        A single output column is treated as a sigmoid score for the
        "vulnerable" class; otherwise a softmax over the columns is used.
        Only the first row is reported (``preprocess`` builds one example).

        Args:
            logits (torch.Tensor): Model output logits.

        Returns:
            list: One dict with ``label``, ``score`` (confidence of the
            predicted class), ``probabilities`` keyed ``LABEL_<i>``, and a
            ``details`` summary.
        """
        # On CUDA the model runs in fp16, which keeps only ~3 significant
        # digits; upcast so the reported probabilities are float32-accurate.
        scores = logits.detach().float()
        row = scores[0] if scores.dim() > 1 else scores

        if row.shape[-1] == 1:
            prob = torch.sigmoid(row[0]).item()
            predicted_class = 1 if prob > 0.5 else 0
            probabilities = {"LABEL_0": 1 - prob, "LABEL_1": prob}
        else:
            probs = torch.softmax(row, dim=-1).tolist()
            predicted_class = int(torch.argmax(row).item())
            probabilities = {f"LABEL_{i}": p for i, p in enumerate(probs)}

        confidence = probabilities[f"LABEL_{predicted_class}"]

        result = {
            "label": self.LABEL_MAP.get(predicted_class, f"LABEL_{predicted_class}"),
            "score": confidence,
            "probabilities": probabilities,
            "details": {
                "is_vulnerable": predicted_class == 1,
                "confidence_percentage": f"{confidence * 100:.2f}%",
                "safe_probability": probabilities.get("LABEL_0", 0),
                "vulnerable_probability": probabilities.get("LABEL_1", 0),
            },
        }
        return [result]
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Local smoke test: load the model from the current directory and
    # classify a snippet that concatenates user input into a SQL query.
    handler = EndpointHandler(path=".")

    test_code = """
import java.sql.*;

public class SQLInjectionVulnerable {
    public void getUser(String userInput) {
        String query = "SELECT * FROM users WHERE username = '" + userInput + "'";
        Statement statement = connection.createStatement();
        ResultSet resultSet = statement.executeQuery(query);
    }
}
"""

    request = {"inputs": test_code}
    result = handler(request)

    print("\nTest Result:")
    print(result)