Spaces:
Running
Running
ajaxwin commited on
Commit Β·
cf983b8
1
Parent(s): e8c9acc
task2 reviewed, semanticmatcher implemented
Browse files- .gitignore +1 -0
- data/data_loader.py +1 -43
- data/properties.csv +0 -0
- env/schemas.py +1 -1
- inference.py +1 -1
- requirements.txt +5 -0
- tasks/task1/grader.py +3 -3
- tasks/task2/actions.py +24 -59
- tasks/task2/environment.py +9 -9
- tasks/task2/grader.py +12 -142
- utils/__init__.py +4 -2
- utils/matcher.py +0 -29
- utils/propertyretriever.py +80 -0
- utils/semanticmatcher.py +230 -0
.gitignore
CHANGED
|
@@ -11,3 +11,4 @@ build/
|
|
| 11 |
baseline_scores.json
|
| 12 |
*.log
|
| 13 |
.pytest_cache/
|
|
|
|
|
|
| 11 |
baseline_scores.json
|
| 12 |
*.log
|
| 13 |
.pytest_cache/
|
| 14 |
+
MySolution.md
|
data/data_loader.py
CHANGED
|
@@ -110,7 +110,7 @@ def get_all_property_entries(
|
|
| 110 |
entries = []
|
| 111 |
for contract in contracts:
|
| 112 |
for fn in contract.get("functions", []):
|
| 113 |
-
if fn.get("property") is not None:
|
| 114 |
entries.append((contract, fn))
|
| 115 |
return entries
|
| 116 |
|
|
@@ -155,48 +155,6 @@ def get_related_functions(
|
|
| 155 |
|
| 156 |
return sorted(related)
|
| 157 |
|
| 158 |
-
|
| 159 |
-
# ! Function is completely wrong
|
| 160 |
-
|
| 161 |
-
def get_similar_rule(
|
| 162 |
-
contracts: List[Dict[str, Any]],
|
| 163 |
-
current_contract_name: str,
|
| 164 |
-
current_function_name: str,
|
| 165 |
-
) -> Optional[Dict[str, Any]]:
|
| 166 |
-
"""
|
| 167 |
-
Returns the similar_rule hint stored in the target function's property field,
|
| 168 |
-
enriched with the referenced function's natspec if available.
|
| 169 |
-
|
| 170 |
-
Returns a dict with keys: contract_name, function_name, property_hint, natspec.
|
| 171 |
-
Returns None if no similar_rule is defined.
|
| 172 |
-
"""
|
| 173 |
-
# Find target function
|
| 174 |
-
for contract in contracts:
|
| 175 |
-
if contract["contract_name"] == current_contract_name:
|
| 176 |
-
fn = get_function_by_name(contract, current_function_name)
|
| 177 |
-
if fn and fn.get("property") and fn["property"].get("similar_rule"): # ! There is no property or similar_rule field
|
| 178 |
-
sr = fn["property"]["similar_rule"]
|
| 179 |
-
# Look up the referenced function's natspec
|
| 180 |
-
for c2 in contracts:
|
| 181 |
-
if c2["contract_name"] == sr["contract_name"]:
|
| 182 |
-
ref_fn = get_function_by_name(c2, sr["function_name"])
|
| 183 |
-
if ref_fn:
|
| 184 |
-
return {
|
| 185 |
-
"contract_name": sr["contract_name"],
|
| 186 |
-
"function_name": sr["function_name"],
|
| 187 |
-
"property_hint": sr["property_hint"],
|
| 188 |
-
"natspec": ref_fn.get("natspec", ""),
|
| 189 |
-
}
|
| 190 |
-
# Referenced function not found β return hint only
|
| 191 |
-
return {
|
| 192 |
-
"contract_name": sr["contract_name"],
|
| 193 |
-
"function_name": sr["function_name"],
|
| 194 |
-
"property_hint": sr["property_hint"],
|
| 195 |
-
"natspec": "",
|
| 196 |
-
}
|
| 197 |
-
return None
|
| 198 |
-
|
| 199 |
-
|
| 200 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 201 |
# Task 3 helpers
|
| 202 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 110 |
entries = []
|
| 111 |
for contract in contracts:
|
| 112 |
for fn in contract.get("functions", []):
|
| 113 |
+
if fn.get("property", None) is not None and fn.get("vulnerable", False) is False:
|
| 114 |
entries.append((contract, fn))
|
| 115 |
return entries
|
| 116 |
|
|
|
|
| 155 |
|
| 156 |
return sorted(related)
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 159 |
# Task 3 helpers
|
| 160 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
data/properties.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
env/schemas.py
CHANGED
|
@@ -38,7 +38,7 @@ class ActionType(str, Enum):
|
|
| 38 |
GET_FILE_NATSPEC = "get_file_natspec" # -0.03
|
| 39 |
GET_FUNCTION_NATSPEC = "get_function_natspec" # -0.08
|
| 40 |
GET_RELATED_FUNCTIONS = "get_related_functions" # -0.06
|
| 41 |
-
|
| 42 |
SUBMIT_PROPERTY = "submit_property" # scored 0β5, one attempt
|
| 43 |
|
| 44 |
# ββ Task 3 β Rule Checker ββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 38 |
GET_FILE_NATSPEC = "get_file_natspec" # -0.03
|
| 39 |
GET_FUNCTION_NATSPEC = "get_function_natspec" # -0.08
|
| 40 |
GET_RELATED_FUNCTIONS = "get_related_functions" # -0.06
|
| 41 |
+
GET_SIGNATURE = "get_signature" # -0.04
|
| 42 |
SUBMIT_PROPERTY = "submit_property" # scored 0β5, one attempt
|
| 43 |
|
| 44 |
# ββ Task 3 β Rule Checker ββββββββββββββββββββββββββββββββββββββββββββββββ
|
inference.py
CHANGED
|
@@ -32,7 +32,7 @@ from tasks.task1.environment import Task1Environment
|
|
| 32 |
from tasks.task2.environment import Task2Environment
|
| 33 |
from tasks.task3.environment import Task3Environment
|
| 34 |
from env.schemas import Action, ActionType
|
| 35 |
-
from utils
|
| 36 |
|
| 37 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
# Configuration
|
|
|
|
| 32 |
from tasks.task2.environment import Task2Environment
|
| 33 |
from tasks.task3.environment import Task3Environment
|
| 34 |
from env.schemas import Action, ActionType
|
| 35 |
+
from utils import T1_SYSTEM, T2_SYSTEM, T3_SYSTEM
|
| 36 |
|
| 37 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
# Configuration
|
requirements.txt
CHANGED
|
@@ -5,3 +5,8 @@ openai==1.51.0
|
|
| 5 |
httpx==0.27.2
|
| 6 |
python-multipart==0.0.9
|
| 7 |
pyyaml==6.0.2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
httpx==0.27.2
|
| 6 |
python-multipart==0.0.9
|
| 7 |
pyyaml==6.0.2
|
| 8 |
+
pandas==2.2.2
|
| 9 |
+
numpy==2.1.1
|
| 10 |
+
scikit-learn==1.5.0
|
| 11 |
+
sentence-transformers==3.0.1
|
| 12 |
+
nltk==3.9.4
|
tasks/task1/grader.py
CHANGED
|
@@ -8,15 +8,15 @@ Deterministic grader. Score range: 0.0 β 1.0
|
|
| 8 |
0.0 β wrong function name
|
| 9 |
"""
|
| 10 |
from __future__ import annotations
|
| 11 |
-
from typing import Dict
|
| 12 |
-
from utils
|
| 13 |
from data.data_loader import load_vulnerabilities
|
| 14 |
|
| 15 |
def match_vuln_keywords(submitted: str, expected: str) -> bool:
|
| 16 |
"""Checks if the submitted vulnerability type matches the expected one using keyword matching."""
|
| 17 |
for types in load_vulnerabilities():
|
| 18 |
if types["vulnerability"] == expected:
|
| 19 |
-
return
|
| 20 |
return False
|
| 21 |
|
| 22 |
class Task1Grader:
|
|
|
|
| 8 |
0.0 β wrong function name
|
| 9 |
"""
|
| 10 |
from __future__ import annotations
|
| 11 |
+
from typing import Dict
|
| 12 |
+
from utils import SemanticMatcher
|
| 13 |
from data.data_loader import load_vulnerabilities
|
| 14 |
|
| 15 |
def match_vuln_keywords(submitted: str, expected: str) -> bool:
|
| 16 |
"""Checks if the submitted vulnerability type matches the expected one using keyword matching."""
|
| 17 |
for types in load_vulnerabilities():
|
| 18 |
if types["vulnerability"] == expected:
|
| 19 |
+
return SemanticMatcher().match(types["terms"], submitted)
|
| 20 |
return False
|
| 21 |
|
| 22 |
class Task1Grader:
|
tasks/task2/actions.py
CHANGED
|
@@ -1,24 +1,34 @@
|
|
| 1 |
"""
|
| 2 |
Actions for Task 2: Property Inference.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
from typing import Any, Dict, Tuple
|
| 6 |
-
from data.data_loader import get_function_by_name, get_related_functions
|
|
|
|
| 7 |
from env.schemas import ActionType, Reward
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
def get_function_code(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 10 |
"""Handle GET_FUNCTION_CODE action."""
|
| 11 |
if ctx._is_repeated(qkey):
|
| 12 |
return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
|
| 13 |
|
| 14 |
fn = ctx._target_fn
|
| 15 |
-
name = fn["name"]
|
| 16 |
code = fn.get("code", "// no code available")
|
| 17 |
return (
|
| 18 |
-
|
| 19 |
Reward(value=-0.06, reason="get_function_code cost"),
|
| 20 |
)
|
| 21 |
|
|
|
|
| 22 |
|
| 23 |
def get_function_natspec(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 24 |
"""Handle GET_FUNCTION_NATSPEC action."""
|
|
@@ -69,56 +79,28 @@ def get_related_functions_action(ctx: Any, qkey: str, params: Dict) -> Tuple[str
|
|
| 69 |
return text, Reward(value=-0.06, reason="get_related_functions cost")
|
| 70 |
|
| 71 |
|
| 72 |
-
def
|
| 73 |
-
"""Handle
|
| 74 |
if ctx._is_repeated(qkey):
|
| 75 |
return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
|
| 76 |
|
| 77 |
fn = ctx._target_fn
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
returns = fn.get("returns", "") or "void"
|
| 81 |
-
out_prop = fn.get("output_property", "")
|
| 82 |
-
visibility = fn.get("visibility", "")
|
| 83 |
-
modifiers = fn.get("modifiers", [])
|
| 84 |
-
|
| 85 |
-
lines = [f"Function: {fn.get('signature', name)}"]
|
| 86 |
-
lines.append(f"Visibility: {visibility}" + (f" Modifiers: {', '.join(modifiers)}" if modifiers else ""))
|
| 87 |
-
if params_list:
|
| 88 |
-
lines.append("Parameters:")
|
| 89 |
-
for p in params_list:
|
| 90 |
-
lines.append(f" β’ {p['type']} {p['name']}: {p.get('description','')}")
|
| 91 |
-
else:
|
| 92 |
-
lines.append("Parameters: none" + (" (payable)" if "payable" in fn.get("code", "") else ""))
|
| 93 |
-
lines.append(f"Returns: {returns}")
|
| 94 |
-
if out_prop:
|
| 95 |
-
lines.append(f"Expected behaviour: {out_prop}")
|
| 96 |
-
return "\n".join(lines), Reward(value=-0.04, reason="get_io cost")
|
| 97 |
|
| 98 |
-
# ! Wrong Function, there is no similar_rule field in the dataset. This function will always return "No similar rule available for this function."
|
| 99 |
|
| 100 |
def get_similar_rule_action(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 101 |
"""Handle GET_SIMILAR_RULE action."""
|
| 102 |
if ctx._is_repeated(qkey):
|
| 103 |
return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
ctx._contract["contract_name"],
|
| 108 |
-
ctx._target_fn["name"],
|
| 109 |
-
)
|
| 110 |
-
if sr is None:
|
| 111 |
return (
|
| 112 |
"No similar rule available for this function.",
|
| 113 |
Reward(value=-0.20, reason="get_similar_rule cost (not found)"),
|
| 114 |
)
|
| 115 |
-
|
| 116 |
-
f"Similar property from {sr['contract_name']}.{sr['function_name']}():",
|
| 117 |
-
f" {sr['property_hint']}",
|
| 118 |
-
]
|
| 119 |
-
if sr.get("natspec"):
|
| 120 |
-
lines.append(f"\nFunction NatSpec:\n {sr['natspec']}")
|
| 121 |
-
return "\n".join(lines), Reward(value=-0.20, reason="get_similar_rule cost")
|
| 122 |
|
| 123 |
|
| 124 |
def submit_property(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
|
@@ -129,6 +111,7 @@ def submit_property(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
|
| 129 |
"Only one submission is allowed.",
|
| 130 |
Reward(value=-1.0, reason="Second submit_property attempt", partial=False),
|
| 131 |
)
|
|
|
|
| 132 |
submitted_text = params.get("property", "").strip()
|
| 133 |
if not submitted_text:
|
| 134 |
return (
|
|
@@ -138,28 +121,10 @@ def submit_property(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
|
| 138 |
|
| 139 |
ctx._submitted = True
|
| 140 |
ctx._done = True
|
|
|
|
|
|
|
| 141 |
|
| 142 |
-
|
| 143 |
-
reward = ctx._grader.reward_for_score(score)
|
| 144 |
-
bd = ctx._grader.breakdown(submitted_text)
|
| 145 |
-
|
| 146 |
-
pct = int(score * 100)
|
| 147 |
-
if score >= 0.85:
|
| 148 |
-
emoji, label = "β
", "EXCELLENT"
|
| 149 |
-
elif score >= 0.60:
|
| 150 |
-
emoji, label = "π‘", "GOOD"
|
| 151 |
-
elif score >= 0.35:
|
| 152 |
-
emoji, label = "π ", "PARTIAL"
|
| 153 |
-
else:
|
| 154 |
-
emoji, label = "β", "POOR"
|
| 155 |
-
|
| 156 |
-
msg = (
|
| 157 |
-
f"{emoji} {label} β Score: {score:.2f}/1.00 β Reward: {reward:.2f}/5.00 ({pct}%)\n"
|
| 158 |
-
f"Key concepts matched : {len(bd['key_matched'])}/{len(bd['key_matched'])+len(bd['key_missed'])} "
|
| 159 |
-
f"{bd['key_matched']}\n"
|
| 160 |
-
f"Bonus concepts matched : {len(bd['bonus_matched'])}/{len(bd['bonus_matched'])+len(bd['bonus_missed'])} "
|
| 161 |
-
f"{bd['bonus_matched']}"
|
| 162 |
-
)
|
| 163 |
return msg, Reward(
|
| 164 |
value=reward,
|
| 165 |
reason=f"Property submission score={score:.3f}",
|
|
|
|
| 1 |
"""
|
| 2 |
Actions for Task 2: Property Inference.
|
| 3 |
+
Defines the logic for each action type that the agent can take in the environment, including
|
| 4 |
+
querying function code, NatSpec, related functions, and submitting the inferred property.
|
| 5 |
+
ctx: context object containing episode state and data
|
| 6 |
+
qkey: a unique key representing the specific query (used for tracking repeated queries)
|
| 7 |
+
params: parameters for the action, such as the submitted property text for SUBMIT_PROPERTY
|
| 8 |
"""
|
| 9 |
|
| 10 |
from typing import Any, Dict, Tuple
|
| 11 |
+
from data.data_loader import get_function_by_name, get_related_functions
|
| 12 |
+
from utils import PropertyRetriever
|
| 13 |
from env.schemas import ActionType, Reward
|
| 14 |
|
| 15 |
+
PropertyRetrieverInstance = PropertyRetriever() # Load once at module level
|
| 16 |
+
|
| 17 |
+
# TODO: Can separate into signature, visiblity
|
| 18 |
+
|
| 19 |
def get_function_code(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 20 |
"""Handle GET_FUNCTION_CODE action."""
|
| 21 |
if ctx._is_repeated(qkey):
|
| 22 |
return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
|
| 23 |
|
| 24 |
fn = ctx._target_fn
|
|
|
|
| 25 |
code = fn.get("code", "// no code available")
|
| 26 |
return (
|
| 27 |
+
code,
|
| 28 |
Reward(value=-0.06, reason="get_function_code cost"),
|
| 29 |
)
|
| 30 |
|
| 31 |
+
# TODO: Can separate comment and output_property(output_comment)
|
| 32 |
|
| 33 |
def get_function_natspec(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 34 |
"""Handle GET_FUNCTION_NATSPEC action."""
|
|
|
|
| 79 |
return text, Reward(value=-0.06, reason="get_related_functions cost")
|
| 80 |
|
| 81 |
|
| 82 |
+
def get_signature(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 83 |
+
"""Handle GET_SIGNATURE action."""
|
| 84 |
if ctx._is_repeated(qkey):
|
| 85 |
return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
|
| 86 |
|
| 87 |
fn = ctx._target_fn
|
| 88 |
+
sig = fn.get("signature")
|
| 89 |
+
return sig, Reward(value=-0.04, reason="get_signature cost")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
|
|
|
| 91 |
|
| 92 |
def get_similar_rule_action(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
| 93 |
"""Handle GET_SIMILAR_RULE action."""
|
| 94 |
if ctx._is_repeated(qkey):
|
| 95 |
return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
|
| 96 |
|
| 97 |
+
similar_rule = PropertyRetrieverInstance.get_similar_property(ctx._target_fn["code"])
|
| 98 |
+
if similar_rule is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
return (
|
| 100 |
"No similar rule available for this function.",
|
| 101 |
Reward(value=-0.20, reason="get_similar_rule cost (not found)"),
|
| 102 |
)
|
| 103 |
+
return similar_rule, Reward(value=-0.20, reason="get_similar_rule cost")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
def submit_property(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
|
|
|
|
| 111 |
"Only one submission is allowed.",
|
| 112 |
Reward(value=-1.0, reason="Second submit_property attempt", partial=False),
|
| 113 |
)
|
| 114 |
+
|
| 115 |
submitted_text = params.get("property", "").strip()
|
| 116 |
if not submitted_text:
|
| 117 |
return (
|
|
|
|
| 121 |
|
| 122 |
ctx._submitted = True
|
| 123 |
ctx._done = True
|
| 124 |
+
score, confidence = ctx._grader.grade(submitted_text)
|
| 125 |
+
reward = round(score * 5.0, 4)
|
| 126 |
|
| 127 |
+
msg = f'Score: {score:.2f}/1.00 β Confidence: {confidence:.2f}\n'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
return msg, Reward(
|
| 129 |
value=reward,
|
| 130 |
reason=f"Property submission score={score:.3f}",
|
tasks/task2/environment.py
CHANGED
|
@@ -50,7 +50,7 @@ AVAILABLE_ACTIONS = [
|
|
| 50 |
ActionType.GET_FUNCTION_NATSPEC,
|
| 51 |
ActionType.GET_FILE_NATSPEC,
|
| 52 |
ActionType.GET_RELATED_FUNCTIONS,
|
| 53 |
-
ActionType.
|
| 54 |
ActionType.GET_SIMILAR_RULE,
|
| 55 |
ActionType.SUBMIT_PROPERTY,
|
| 56 |
]
|
|
@@ -85,7 +85,7 @@ class Task2Environment(BaseEnv):
|
|
| 85 |
)
|
| 86 |
self._grader = Task2Grader(
|
| 87 |
function_name=self._target_fn["name"],
|
| 88 |
-
|
| 89 |
)
|
| 90 |
self._step_count = 0
|
| 91 |
self._cum_reward = 0.0
|
|
@@ -184,13 +184,13 @@ class Task2Environment(BaseEnv):
|
|
| 184 |
# Mapping from ActionType to handler function
|
| 185 |
# Each handler expects (ctx, qkey, params) and returns (str, Reward)
|
| 186 |
handlers = {
|
| 187 |
-
ActionType.GET_FUNCTION_CODE:
|
| 188 |
-
ActionType.GET_FUNCTION_NATSPEC:
|
| 189 |
-
ActionType.GET_FILE_NATSPEC:
|
| 190 |
-
ActionType.GET_RELATED_FUNCTIONS:
|
| 191 |
-
ActionType.
|
| 192 |
-
ActionType.GET_SIMILAR_RULE:
|
| 193 |
-
ActionType.SUBMIT_PROPERTY:
|
| 194 |
}
|
| 195 |
|
| 196 |
handler = handlers.get(at)
|
|
|
|
| 50 |
ActionType.GET_FUNCTION_NATSPEC,
|
| 51 |
ActionType.GET_FILE_NATSPEC,
|
| 52 |
ActionType.GET_RELATED_FUNCTIONS,
|
| 53 |
+
ActionType.GET_SIGNATURE,
|
| 54 |
ActionType.GET_SIMILAR_RULE,
|
| 55 |
ActionType.SUBMIT_PROPERTY,
|
| 56 |
]
|
|
|
|
| 85 |
)
|
| 86 |
self._grader = Task2Grader(
|
| 87 |
function_name=self._target_fn["name"],
|
| 88 |
+
property=self._target_fn["property"],
|
| 89 |
)
|
| 90 |
self._step_count = 0
|
| 91 |
self._cum_reward = 0.0
|
|
|
|
| 184 |
# Mapping from ActionType to handler function
|
| 185 |
# Each handler expects (ctx, qkey, params) and returns (str, Reward)
|
| 186 |
handlers = {
|
| 187 |
+
ActionType.GET_FUNCTION_CODE: actions.get_function_code,
|
| 188 |
+
ActionType.GET_FUNCTION_NATSPEC: actions.get_function_natspec,
|
| 189 |
+
ActionType.GET_FILE_NATSPEC: actions.get_file_natspec,
|
| 190 |
+
ActionType.GET_RELATED_FUNCTIONS: actions.get_related_functions_action,
|
| 191 |
+
ActionType.GET_SIGNATURE: actions.get_signature,
|
| 192 |
+
ActionType.GET_SIMILAR_RULE: actions.get_similar_rule_action,
|
| 193 |
+
ActionType.SUBMIT_PROPERTY: actions.submit_property,
|
| 194 |
}
|
| 195 |
|
| 196 |
handler = handlers.get(at)
|
tasks/task2/grader.py
CHANGED
|
@@ -2,101 +2,13 @@
|
|
| 2 |
grader.py (Task 2 β Property Discovery)
|
| 3 |
-----------------------------------------
|
| 4 |
Deterministic scorer for natural-language property submissions.
|
| 5 |
-
|
| 6 |
-
Score formula
|
| 7 |
-
βββββββββββββ
|
| 8 |
-
key_phrases weight = 0.70
|
| 9 |
-
bonus_phrases weight = 0.30
|
| 10 |
-
|
| 11 |
-
score = 0.70 * (matched_key / total_key)
|
| 12 |
-
+ 0.30 * (matched_bonus / total_bonus)
|
| 13 |
-
|
| 14 |
-
Phrase matching
|
| 15 |
-
βββββββββββββββ
|
| 16 |
-
A phrase is considered matched if ALL its words (after normalisation)
|
| 17 |
-
appear in the submitted text. This is intentionally lenient β it
|
| 18 |
-
doesn't require the words to be adjacent, so "balance increases by
|
| 19 |
-
msg.value" is matched by "the caller's vault balance increases by the
|
| 20 |
-
sent msg.value amount".
|
| 21 |
-
|
| 22 |
-
Synonym expansion allows common paraphrases to also match
|
| 23 |
-
(e.g. "caller" β "msg.sender", "sender", "user").
|
| 24 |
-
|
| 25 |
-
Terminal reward = score Γ 5.0 (range: 0.0 β 5.0)
|
| 26 |
One submission attempt per episode.
|
| 27 |
"""
|
| 28 |
|
| 29 |
from __future__ import annotations
|
| 30 |
|
| 31 |
-
import
|
| 32 |
-
from
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
# ββ Text normalisation ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 36 |
-
|
| 37 |
-
_PUNCT = str.maketrans("", "", string.punctuation)
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def _norm(text: str) -> str:
|
| 41 |
-
"""Lowercase, strip punctuation, collapse whitespace β word string."""
|
| 42 |
-
return " ".join(text.lower().translate(_PUNCT).split())
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _word_set(text: str) -> set:
|
| 46 |
-
"""Return normalised set of words."""
|
| 47 |
-
return set(_norm(text).split())
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
# ββ Synonym table βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 51 |
-
# Maps canonical word β list of accepted synonyms (all lowercase, no punct).
|
| 52 |
-
|
| 53 |
-
SYNONYMS: Dict[str, List[str]] = {
|
| 54 |
-
"caller": ["caller", "sender", "user", "msgsender", "msg sender"],
|
| 55 |
-
"balance": ["balance", "holdings", "amount held"],
|
| 56 |
-
"increases": ["increases", "incremented", "added", "grows", "rise"],
|
| 57 |
-
"decreases": ["decreases", "decremented", "reduced", "subtracted", "falls"],
|
| 58 |
-
"transfers": ["transfers", "sends", "moved", "forwarded", "sent"],
|
| 59 |
-
"reverts": ["reverts", "fails", "rejected", "throws", "require"],
|
| 60 |
-
"zero": ["zero", "0", "nothing", "none", "empty"],
|
| 61 |
-
"owner": ["owner", "admin", "authorized"],
|
| 62 |
-
"entire": ["entire", "full", "whole", "all", "total"],
|
| 63 |
-
"returned": ["returned", "sent back", "refunded", "transferred back"],
|
| 64 |
-
"reset": ["reset", "zeroed", "set to zero", "cleared"],
|
| 65 |
-
"only": ["only", "exclusively", "restricted"],
|
| 66 |
-
"price": ["price", "cost", "rate"],
|
| 67 |
-
"tokens": ["tokens", "token amount"],
|
| 68 |
-
"rewards": ["rewards", "reward tokens", "accrued"],
|
| 69 |
-
"staked": ["staked", "deposited", "locked"],
|
| 70 |
-
"winner": ["winner", "winning bidder", "successful bidder"],
|
| 71 |
-
}
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def _expand_words(phrase_words: List[str]) -> List[List[str]]:
|
| 75 |
-
"""
|
| 76 |
-
For each word in the phrase, generate synonym variants.
|
| 77 |
-
Returns a list of word-list variants to try.
|
| 78 |
-
Only substitutes ONE word at a time to avoid combinatorial explosion.
|
| 79 |
-
"""
|
| 80 |
-
variants = [phrase_words] # original
|
| 81 |
-
for i, word in enumerate(phrase_words):
|
| 82 |
-
if word in SYNONYMS:
|
| 83 |
-
for syn in SYNONYMS[word]:
|
| 84 |
-
syn_words = _norm(syn).split()
|
| 85 |
-
new_variant = phrase_words[:i] + syn_words + phrase_words[i + 1:]
|
| 86 |
-
variants.append(new_variant)
|
| 87 |
-
return variants
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def _phrase_matched(text_words: set, phrase: str) -> bool:
|
| 91 |
-
"""
|
| 92 |
-
True if ALL words in the phrase (or a synonym variant) appear in text_words.
|
| 93 |
-
Uses word-set containment, not substring adjacency.
|
| 94 |
-
"""
|
| 95 |
-
norm_words = _norm(phrase).split()
|
| 96 |
-
for variant_words in _expand_words(norm_words):
|
| 97 |
-
if all(w in text_words for w in variant_words):
|
| 98 |
-
return True
|
| 99 |
-
return False
|
| 100 |
|
| 101 |
|
| 102 |
# ββ Grader ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
@@ -108,64 +20,22 @@ class Task2Grader:
|
|
| 108 |
Parameters
|
| 109 |
----------
|
| 110 |
function_name : name of the target function
|
| 111 |
-
|
| 112 |
-
Must have: natural_language, key_phrases, bonus_phrases
|
| 113 |
"""
|
| 114 |
|
| 115 |
-
|
| 116 |
-
BONUS_WEIGHT = 0.30
|
| 117 |
-
|
| 118 |
-
def __init__(self, function_name: str, property_data: Dict) -> None:
|
| 119 |
self.function_name = function_name
|
| 120 |
-
self.
|
| 121 |
-
self.key_phrases = property_data.get("key_phrases", [])
|
| 122 |
-
self.bonus_phrases = property_data.get("bonus_phrases", [])
|
| 123 |
|
| 124 |
# ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 125 |
|
| 126 |
-
def grade(self, submitted: str) -> float:
|
| 127 |
"""Deterministic score in [0.0, 1.0]."""
|
| 128 |
if not submitted or not submitted.strip():
|
| 129 |
-
return 0.0
|
| 130 |
-
tw = _word_set(submitted)
|
| 131 |
-
key_score = self._phrase_score(tw, self.key_phrases)
|
| 132 |
-
bonus_score = self._phrase_score(tw, self.bonus_phrases)
|
| 133 |
-
raw = self.KEY_WEIGHT * key_score + self.BONUS_WEIGHT * bonus_score
|
| 134 |
-
return round(min(max(raw, 0.0), 1.0), 4)
|
| 135 |
-
|
| 136 |
-
def reward_for_score(self, score: float) -> float:
|
| 137 |
-
"""Maps [0.0, 1.0] β [0.0, 5.0]."""
|
| 138 |
-
return round(score * 5.0, 4)
|
| 139 |
-
|
| 140 |
-
def breakdown(self, submitted: str) -> Dict:
|
| 141 |
-
"""Detailed scoring breakdown for debugging."""
|
| 142 |
-
tw = _word_set(submitted)
|
| 143 |
-
key_hits = [p for p in self.key_phrases if _phrase_matched(tw, p)]
|
| 144 |
-
bonus_hits = [p for p in self.bonus_phrases if _phrase_matched(tw, p)]
|
| 145 |
-
score = self.grade(submitted)
|
| 146 |
-
return {
|
| 147 |
-
"score": score,
|
| 148 |
-
"reward": self.reward_for_score(score),
|
| 149 |
-
"key_matched": key_hits,
|
| 150 |
-
"key_missed": [p for p in self.key_phrases if p not in key_hits],
|
| 151 |
-
"bonus_matched": bonus_hits,
|
| 152 |
-
"bonus_missed": [p for p in self.bonus_phrases if p not in bonus_hits],
|
| 153 |
-
"key_score": self._phrase_score(tw, self.key_phrases),
|
| 154 |
-
"bonus_score": self._phrase_score(tw, self.bonus_phrases),
|
| 155 |
-
}
|
| 156 |
-
|
| 157 |
-
def get_canonical_answer(self) -> Dict:
|
| 158 |
-
"""For debugging / logging only."""
|
| 159 |
-
return {
|
| 160 |
-
"function": self.function_name,
|
| 161 |
-
"natural_language": self.natural_language,
|
| 162 |
-
"key_phrases": self.key_phrases,
|
| 163 |
-
}
|
| 164 |
-
|
| 165 |
-
# ββ Internal ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 166 |
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
| 2 |
grader.py (Task 2 β Property Discovery)
|
| 3 |
-----------------------------------------
|
| 4 |
Deterministic scorer for natural-language property submissions.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
One submission attempt per episode.
|
| 6 |
"""
|
| 7 |
|
| 8 |
from __future__ import annotations
|
| 9 |
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
from utils import SemanticMatcher
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
# ββ Grader ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 20 |
Parameters
|
| 21 |
----------
|
| 22 |
function_name : name of the target function
|
| 23 |
+
property : the 'property' field from the target function's data
|
|
|
|
| 24 |
"""
|
| 25 |
|
| 26 |
+
def __init__(self, function_name: str, property: str) -> None:
|
|
|
|
|
|
|
|
|
|
| 27 |
self.function_name = function_name
|
| 28 |
+
self.property = property
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# ββ Public API ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
|
| 32 |
+
def grade(self, submitted: str) -> Tuple[float, str]:
|
| 33 |
"""Deterministic score in [0.0, 1.0]."""
|
| 34 |
if not submitted or not submitted.strip():
|
| 35 |
+
return 0.0, "no_match"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
SemanticMatcherInstance = SemanticMatcher()
|
| 38 |
+
return (
|
| 39 |
+
SemanticMatcherInstance.matchscore(self.property, submitted),
|
| 40 |
+
SemanticMatcherInstance.confidence()
|
| 41 |
+
)
|
utils/__init__.py
CHANGED
|
@@ -1,4 +1,6 @@
|
|
| 1 |
# utils package
|
| 2 |
-
from utils.
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
|
|
|
|
| 1 |
# utils package
|
| 2 |
+
from utils.semanticmatcher import SemanticMatcher
|
| 3 |
+
from utils.prompts import T1_SYSTEM, T2_SYSTEM, T3_SYSTEM
|
| 4 |
+
from utils.propertyretriever import PropertyRetriever
|
| 5 |
|
| 6 |
+
__all__ = ["SemanticMatcher", "T1_SYSTEM", "T2_SYSTEM", "T3_SYSTEM", "PropertyRetriever"]
|
utils/matcher.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
from typing import List
|
| 2 |
-
|
| 3 |
-
def match_strings(terms: List[str], phrase: str, threshold: float = 0.6) -> bool:
|
| 4 |
-
"""
|
| 5 |
-
Determine if a short string (2-3 words) is likely present in a list of terms
|
| 6 |
-
using token set similarity
|
| 7 |
-
|
| 8 |
-
Args:
|
| 9 |
-
terms: List of term strings (preferably lowercased).
|
| 10 |
-
phrase: Input phrase (2-3 words, case-insensitive).
|
| 11 |
-
threshold: Minimum Jaccard similarity to count as a match (default 0.6).
|
| 12 |
-
|
| 13 |
-
Returns:
|
| 14 |
-
True if any term in the list has token set similarity >= threshold.
|
| 15 |
-
"""
|
| 16 |
-
phrase_lower = phrase.lower().strip()
|
| 17 |
-
phrase_tokens = set(phrase_lower.split())
|
| 18 |
-
|
| 19 |
-
for term in terms:
|
| 20 |
-
term_tokens = set(term.split())
|
| 21 |
-
# Jaccard similarity = size of intersection / size of union
|
| 22 |
-
if not phrase_tokens or not term_tokens:
|
| 23 |
-
continue
|
| 24 |
-
intersection = phrase_tokens & term_tokens
|
| 25 |
-
union = phrase_tokens | term_tokens
|
| 26 |
-
similarity = len(intersection) / len(union)
|
| 27 |
-
if similarity >= threshold:
|
| 28 |
-
return True
|
| 29 |
-
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/propertyretriever.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PropertyRetriever class for finding similar properties based on code embeddings.
|
| 3 |
+
It loads a dataset of properties, computes embeddings for their critical code sections,
|
| 4 |
+
and provides a method to retrieve the most similar property given a new code snippet.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
# ! Current data contains properties from contracts.json, making it more likely to find a exact match
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import numpy as np
|
| 11 |
+
from sentence_transformers import SentenceTransformer
|
| 12 |
+
from sklearn.preprocessing import normalize
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
DATA_DIR = os.path.join(os.path.dirname(__file__))
|
| 16 |
+
DEFAULT_CSV_PATH = os.path.join(DATA_DIR, "properties.csv")
|
| 17 |
+
SIMILARITY_THRESHOLD = 0.8 # Adjust as needed based on validation
|
| 18 |
+
|
| 19 |
+
# -------------------------------------------------------------------
|
| 20 |
+
# 1. Load the dataset and build the vector database (offline/once)
|
| 21 |
+
# -------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
class PropertyRetriever:
    """
    Embedding-based retriever over a CSV dataset of verification properties.

    The CSV at ``DEFAULT_CSV_PATH`` is expected to contain at least the columns:
        SpecHash, SpecIndex, Type, Name, StartLine, EndLine,
        MethodsInRule, RuleContent, RelatedFunctions,
        FunctionBodies, FilePath, Project, ContractCode,
        StateVarAssignment, RuleContentNL, Functionality

    For each row a "critical code" text is selected (FunctionBodies, falling back
    to RelatedFunctions, then RuleContent), embedded once at construction time,
    and L2-normalized so that a dot product equals cosine similarity.
    """

    def __init__(self):
        """Load the dataset, select critical code per row, and precompute embeddings."""
        self.df = pd.read_csv(DEFAULT_CSV_PATH)
        self.threshold = SIMILARITY_THRESHOLD

        # Lightweight, open-source embedding model (deterministic at inference time).
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')

        # Critical code per row, with fallbacks for missing/empty cells.
        self.critical_codes = [
            self._pick_critical_code(row) for _, row in self.df.iterrows()
        ]

        # Embed all critical codes once; normalize so dot product == cosine similarity.
        self.embeddings = self.embedder.encode(self.critical_codes, show_progress_bar=True)
        self.embeddings = normalize(self.embeddings, norm='l2')

    @staticmethod
    def _pick_critical_code(row) -> str:
        """Return the first non-empty of FunctionBodies, RelatedFunctions, RuleContent.

        str() guards against non-string cells (e.g. numbers produced by pandas'
        type inference), which would crash a bare ``.strip()`` call.
        """
        for column in ('FunctionBodies', 'RelatedFunctions', 'RuleContent'):
            value = row.get(column, '')
            if not pd.isna(value) and str(value).strip() != '':
                return str(value)
        return ''

    def get_similar_property(self, input_code: str) -> str:
        """
        Return the natural-language rule of the most similar property.

        Parameters
        ----------
        input_code : Solidity function source to match against the dataset.

        Returns
        -------
        str : the best match's RuleContentNL if its cosine similarity is at least
              ``self.threshold``; otherwise an empty string.
        """
        if not input_code or not isinstance(input_code, str):
            return ""
        # Guard: np.argmax raises on an empty similarity array.
        if len(self.critical_codes) == 0:
            return ""

        # Step 1: embed the subject code; normalize so dot product == cosine.
        query_emb = self.embedder.encode([input_code])
        query_emb = normalize(query_emb, norm='l2')

        # Step 2: cosine similarity against every stored property embedding.
        similarities = np.dot(self.embeddings, query_emb.T).flatten()

        # Step 3: pick the best match above threshold.
        best_idx = int(np.argmax(similarities))
        best_score = similarities[best_idx]

        if best_score >= self.threshold:
            # NOTE: the natural-language form (RuleContentNL) is returned,
            # not the raw RuleContent.
            return str(self.df.iloc[best_idx]['RuleContentNL'])
        # No sufficiently similar property found.
        return ""
|
utils/semanticmatcher.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
semanticmatcher.py
|
| 3 |
+
====================
|
| 4 |
+
Deterministic semantic string matcher for short strings (10β12 words).
|
| 5 |
+
|
| 6 |
+
Algorithm: Weighted ensemble of three independent signals
|
| 7 |
+
1. Lexical Jaccard β lemmatized token overlap (weight: 0.20)
|
| 8 |
+
2. Synonym Jaccard β WordNet-expanded token overlap (weight: 0.25)
|
| 9 |
+
3. Semantic Cosine β sentence-transformers embedding similarity (weight: 0.55)
|
| 10 |
+
|
| 11 |
+
All three layers are fully deterministic: same inputs β same score, always.
|
| 12 |
+
|
| 13 |
+
Install dependencies:
|
| 14 |
+
python -m nltk.downloader wordnet omw-1.4 stopwords punkt punkt_tab averaged_perceptron_tagger_eng
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import re
|
| 18 |
+
import string
|
| 19 |
+
from functools import lru_cache
|
| 20 |
+
|
| 21 |
+
import nltk
|
| 22 |
+
import numpy as np
|
| 23 |
+
from nltk.corpus import wordnet, stopwords
|
| 24 |
+
from nltk.stem import WordNetLemmatizer
|
| 25 |
+
from sentence_transformers import SentenceTransformer
|
| 26 |
+
|
| 27 |
+
# ββ Config ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 28 |
+
|
| 29 |
+
WEIGHTS = {
|
| 30 |
+
"lexical": 0.20, # Plain lemma overlap
|
| 31 |
+
"synonym": 0.25, # WordNet-expanded overlap
|
| 32 |
+
"semantic": 0.55, # Embedding cosine similarity
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
MATCH_THRESHOLD = 0.72 # Score β₯ this β strings "mean the same thing"
|
| 36 |
+
STRONG_THRESHOLD = 0.88 # Score β₯ this β high-confidence match
|
| 37 |
+
|
| 38 |
+
# Embedding model: deterministic, no sampling
|
| 39 |
+
_EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
|
| 40 |
+
|
| 41 |
+
# ββ Lazy singletons βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
|
| 43 |
+
_model: SentenceTransformer | None = None
|
| 44 |
+
_lemmatizer: WordNetLemmatizer | None = None
|
| 45 |
+
_stop_words: set[str] | None = None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _get_model() -> SentenceTransformer:
    """Return the shared SentenceTransformer instance, loading it on first use."""
    global _model
    if _model is not None:
        return _model
    _model = SentenceTransformer(_EMBEDDING_MODEL_NAME)
    return _model
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _get_lemmatizer() -> WordNetLemmatizer:
    """Return the shared WordNet lemmatizer, creating it on first use."""
    global _lemmatizer
    if _lemmatizer is not None:
        return _lemmatizer
    _lemmatizer = WordNetLemmatizer()
    return _lemmatizer
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _get_stopwords() -> set[str]:
    """Return the cached English stopword set, building it on first use."""
    global _stop_words
    if _stop_words is not None:
        return _stop_words
    _stop_words = set(stopwords.words("english"))
    return _stop_words
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ββ Text preprocessing ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
+
|
| 71 |
+
def _get_wordnet_pos(treebank_tag: str) -> str:
    """Translate a Penn Treebank POS tag into the matching WordNet POS constant."""
    # Only the first letter of the treebank tag determines the coarse POS class;
    # anything unrecognized (including an empty tag) defaults to NOUN.
    prefix_to_pos = {"J": wordnet.ADJ, "V": wordnet.VERB, "R": wordnet.ADV}
    return prefix_to_pos.get(treebank_tag[:1], wordnet.NOUN)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def normalize(text: str) -> str:
    """Return *text* lowercased, stripped of ASCII punctuation, whitespace-collapsed."""
    lowered = text.lower()
    no_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    # split()/join collapses any run of whitespace and trims both ends.
    return " ".join(no_punct.split())
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def tokenize_and_lemmatize(text: str) -> list[str]:
    """Normalize, tokenize, POS-tag, then lemmatize; drop stopwords and non-alpha tokens."""
    lemmatizer = _get_lemmatizer()
    stop_words = _get_stopwords()

    tagged = nltk.pos_tag(nltk.word_tokenize(normalize(text)))

    lemmas: list[str] = []
    for word, pos in tagged:
        if word.isalpha() and word not in stop_words:
            lemmas.append(lemmatizer.lemmatize(word, _get_wordnet_pos(pos)))
    return lemmas
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ββ WordNet synonym expansion βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 107 |
+
|
| 108 |
+
@lru_cache(maxsize=512)
def _synonyms(word: str) -> frozenset[str]:
    """Return every WordNet lemma name for *word* (the word itself included)."""
    names: set[str] = {
        lemma.name().replace("_", " ").lower()
        for synset in wordnet.synsets(word)
        for lemma in synset.lemmas()  # type: ignore
    }
    names.add(word)
    return frozenset(names)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def expand_with_synonyms(tokens: list[str]) -> set[str]:
    """Return the union of WordNet synonym sets over all *tokens* (tokens included)."""
    # union() with no arguments yields an empty set, so empty input is handled.
    return set().union(*map(_synonyms, tokens))
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ββ Similarity metrics ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 127 |
+
|
| 128 |
+
def jaccard(set_a: set[str], set_b: set[str]) -> float:
    """Jaccard index |A intersect B| / |A union B|; two empty sets count as identical (1.0)."""
    union = set_a | set_b
    if not union:
        # Both sets empty: define the similarity as perfect.
        return 1.0
    return len(set_a & set_b) / len(union)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def cosine_similarity(vec_a: np.ndarray, vec_b: np.ndarray) -> float:
    """Cosine of the angle between two vectors; 0.0 when either has zero norm."""
    # Norms are non-negative, so the product is zero iff either norm is zero.
    denom = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    if denom == 0:
        return 0.0
    return float(np.dot(vec_a, vec_b) / denom)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ββ Core matcher ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 147 |
+
|
| 148 |
+
class SemanticMatcher:
    """
    Deterministic semantic matcher for short strings.

    Blends three signals into one weighted score in [0.0, 1.0]:
      - "lexical":  Jaccard overlap of lemmatized tokens
      - "synonym":  Jaccard overlap of WordNet-expanded tokens
      - "semantic": sentence-embedding cosine similarity

    Usage:
        matcher = SemanticMatcher()
        result = matcher.match("The cat sat on the mat",
                               "A cat was sitting on the mat")
        print(result)
    """

    def __init__(
        self,
        match_threshold: float = MATCH_THRESHOLD,
        strong_threshold: float = STRONG_THRESHOLD,
        weights: dict[str, float] | None = None,
    ):
        """
        match_threshold  : score >= this => the two strings are considered a match
        strong_threshold : score >= this => high-confidence match
        weights          : per-layer weights for "lexical"/"synonym"/"semantic";
                           must sum to 1.0 (defaults to module-level WEIGHTS)

        Raises ValueError if the weights do not sum to 1.0.
        """
        self.match_threshold = match_threshold
        self.strong_threshold = strong_threshold
        self.weights = weights or WEIGHTS
        # Confidence of the most recent matchscore()/match() call.
        self.confidence_level: str = "no_match"

        # Validate with a real exception: `assert` is stripped under `python -O`.
        total = sum(self.weights.values())
        if abs(total - 1.0) >= 1e-6:
            raise ValueError(f"Weights must sum to 1.0 (got {total:.4f})")

    # -- Internal scoring layers ---------------------------------------------

    def _layer_lexical(self, tokens_a: list[str], tokens_b: list[str]) -> float:
        """Plain lemmatized-token Jaccard overlap."""
        return jaccard(set(tokens_a), set(tokens_b))

    def _layer_synonym(self, tokens_a: list[str], tokens_b: list[str]) -> float:
        """Jaccard overlap after WordNet synonym expansion."""
        expanded_a = expand_with_synonyms(tokens_a)
        expanded_b = expand_with_synonyms(tokens_b)
        return jaccard(expanded_a, expanded_b)

    def _layer_semantic(self, text_a: str, text_b: str) -> float:
        """Sentence-embedding cosine similarity (deterministic: fixed weights, no sampling)."""
        model = _get_model()
        embeddings = model.encode(
            [normalize(text_a), normalize(text_b)],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        return cosine_similarity(embeddings[0], embeddings[1])  # type: ignore

    # -- Public API -----------------------------------------------------------

    def matchscore(self, text_a: str, text_b: str) -> float:
        """
        Score how closely two strings match.

        Returns a float in [0.0, 1.0], where 1.0 is a perfect match. Side
        effect: ``self.confidence_level`` is updated to "strong", "moderate",
        or "no_match" according to the thresholds.
        """
        # Fast path: normalized exact match is a perfect, high-confidence score.
        # (Bug fix: this previously returned the bool True, violating the float
        # contract, and left confidence_level stale from a previous call.)
        if normalize(text_a) == normalize(text_b):
            self.confidence_level = "strong"
            return 1.0

        tokens_a = tokenize_and_lemmatize(text_a)
        tokens_b = tokenize_and_lemmatize(text_b)

        layer_scores = {
            "lexical": self._layer_lexical(tokens_a, tokens_b),
            "synonym": self._layer_synonym(tokens_a, tokens_b),
            "semantic": self._layer_semantic(text_a, text_b),
        }

        # Weighted ensemble of the three layers.
        score = sum(self.weights[k] * v for k, v in layer_scores.items())
        if score >= self.strong_threshold:
            self.confidence_level = "strong"
        elif score >= self.match_threshold:
            self.confidence_level = "moderate"
        else:
            self.confidence_level = "no_match"
        return score

    def match(self, text_a: str, text_b: str) -> bool:
        """Return True if the two texts score at or above ``match_threshold``."""
        score = self.matchscore(text_a, text_b)
        return score >= self.match_threshold

    def confidence(self) -> str:
        """Return the confidence of the last score: "strong", "moderate", or "no_match"."""
        return self.confidence_level
|