Spaces:
Running
Running
| """Expression parser for filter expressions with date filtering and sorting. | |
| Required Syntax (Full Column Names Only): | |
| gap_pct > 10 # Gap % > 10% | |
| run_pct > 20 # Run % > 20% | |
| change_pct > 5 # Change % > 5% | |
| volume > 5M # Volume > 5,000,000 | |
| gap_pct > 10 in 5d # Gap > 10% in last 5 days | |
| $close[1] # Close 1 day after event | |
| close[-1] # Yesterday's close | |
| max(high, 20) # 20-day max (excludes today) | |
| volume > 5M sort volume desc | |
| """ | |
import re
from dataclasses import dataclass, field
@dataclass
class FilterCondition:
    """A single filter condition (e.g., date >= '2026-01-01').

    Restored the @dataclass decorator: the parser constructs instances with
    keyword arguments (FilterCondition(column=..., operator=..., value=...)),
    which requires the generated __init__.
    """

    column: str    # column name the condition applies to (e.g. "date")
    operator: str  # comparison operator: >=, <=, ==, >, <, or =
    value: str     # right-hand side as a string (e.g. an ISO date)
@dataclass
class SortSpec:
    """Sort specification.

    Restored the @dataclass decorator: the parser constructs instances with
    keyword arguments (SortSpec(column=..., direction=...)), which requires
    the generated __init__.
    """

    column: str             # Full expression (e.g., "close", "close / close[-10]")
    direction: str = "asc"  # "asc" or "desc"; defaults to ascending
@dataclass
class ParsedExpression:
    """Result of parsing a filter expression.

    Restored the @dataclass decorator: every field uses
    field(default_factory=...) and the parser builds instances with keyword
    arguments, both of which require the generated __init__.
    """

    # Date clauses stripped from the expression (column is always "date").
    date_conditions: list[FilterCondition] = field(default_factory=list)
    sort: SortSpec | None = None
    remaining_filter: str = ""  # For pandas evaluation
    # Detected features for SQL generation
    metrics: set[tuple[str, int]] = field(default_factory=set)  # (col, offset)
    aggregations: set[tuple[str, str, int]] = field(default_factory=set)  # (func, col, lookback)
    window_chained: set[tuple[str, str, int, int]] = field(default_factory=set)  # (func, col, lookback, offset)
    chained_aggs: set[tuple[str, str, str, int]] = field(default_factory=set)  # (func, col1, col2, offset)
    binary_aggs: set[tuple[str, str, str]] = field(default_factory=set)  # (func, col1, col2)
    # Event features: "in N days" windows and $column[offset] references
    event_windows: list[tuple[str, int]] = field(default_factory=list)  # [(condition, days)]
    event_refs: list[tuple[str, int]] = field(default_factory=list)  # [(column, offset)]

    def get_start_date(self) -> str | None:
        """Extract start date from conditions (>= or >)."""
        for cond in self.date_conditions:
            if cond.operator in (">=", ">"):
                return cond.value
        return None

    def get_end_date(self) -> str | None:
        """Extract end date from conditions (<= or <)."""
        for cond in self.date_conditions:
            if cond.operator in ("<=", "<"):
                return cond.value
        return None

    def get_exact_date(self) -> str | None:
        """Extract exact date from conditions (= or ==)."""
        for cond in self.date_conditions:
            if cond.operator in ("=", "=="):
                return cond.value
        return None
class ExpressionParser:
    """Parse filter expressions into structured data for SQL and Pandas.

    Full column names required:
    - gap_pct, run_pct, change_pct, range_pct (percentage columns)
    - gap_dollar, run_dollar, change_dollar, range_dollar (dollar columns)
    - volume (not vol)
    - streak_run_pct, rel_vol, vol_ratio_52wk
    - up_streak, down_streak
    """

    # Structural patterns
    METRIC_PATTERN = r"(\w+)\[(-?\d+)\]"                                     # close[-1]
    AGG_PATTERN = r"(max|min|avg)\((\w+),\s*(\d+)\)"                         # max(high, 20)
    WINDOW_CHAINED_PATTERN = r"(max|min|avg)\((\w+),\s*(\d+)\)\[(-?\d+)\]"   # max(high, 20)[-1]
    CHAINED_AGG_PATTERN = r"(max|min|avg)\(([\w]+),\s*([\w]+)\)\[(-?\d+)\]"  # max(open, close)[-1]
    BINARY_AGG_PATTERN = r"(max|min|avg)\(([\w]+),\s*([a-zA-Z_][\w]*)\)(?!\s*\[)"  # max(open, close), no [offset]
    # UI/UX Patterns
    SORT_PATTERN = r"(?:^|\s+)sort\s+(.+?)(?:\s+(asc|desc))?\s*$"
    DATE_PATTERN = r'date\s*(>=|<=|==|>|<|=)\s*[\'"](\d{4}-\d{2}-\d{2})[\'"]'
    # Event patterns
    EVENT_WINDOW_PATTERN = r"\s+in\s+(\d+)d\b"
    EVENT_REF_PATTERN = r"\$([a-z_][\w]*)\[(-?\d+)\]"
    EVENT_REF_SIMPLE = r"\$([a-z_][\w]+)"
    # Percentage columns (full names only)
    PCT_COLS = ["gap_pct", "run_pct", "change_pct", "range_pct", "streak_run_pct"]

    def parse(self, expr: str) -> ParsedExpression:
        """Fully parse expression into structured features.

        Pipeline: normalize trader shorthand, peel off sort/date clauses,
        convert percentage/dollar shorthand to plain numerics, then scan
        the remainder for SQL-generation features and event features.
        """
        if not expr or not expr.strip():
            return ParsedExpression()
        # Normalize trader syntax (lowercase, K/M/B suffixes)
        processed = self._normalize(expr)
        # Structural parsing
        sort_spec = self._parse_sort(processed)
        expr_without_sort = self._remove_sort(processed)
        date_conditions = self._parse_dates(expr_without_sort)
        remaining = self._remove_dates(expr_without_sort)
        remaining = self._cleanup_expression(remaining)
        # Convert percentages without % suffix to decimal
        remaining = self._convert_percentages(remaining)
        # Convert % suffix to decimal: gap_pct > 10% -> gap_pct > 0.10
        remaining = self._convert_percent_suffix(remaining)
        # Convert $amount: gap_pct > $5 -> gap_dollar > 5
        remaining = self._convert_dollar(remaining)
        # Extract features
        # NOTE(review): METRIC_PATTERN also matches the col[off] inside event
        # refs like $close[1] (the $ is not excluded) — preserved as-is.
        metrics = {(m, int(off)) for m, off in re.findall(self.METRIC_PATTERN, remaining)}
        aggs = {(f.lower(), m, int(lb)) for f, m, lb in re.findall(self.AGG_PATTERN, remaining, re.IGNORECASE)}
        window_chained = {
            (f.lower(), c, int(lb), int(off))
            for f, c, lb, off in re.findall(self.WINDOW_CHAINED_PATTERN, remaining, re.IGNORECASE)
        }
        chained_aggs = {
            (f.lower(), c1, c2, int(off))
            for f, c1, c2, off in re.findall(self.CHAINED_AGG_PATTERN, remaining, re.IGNORECASE)
        }
        binary_aggs = {
            (f.lower(), c1, c2) for f, c1, c2 in re.findall(self.BINARY_AGG_PATTERN, remaining, re.IGNORECASE)
        }
        # Extract event features (may rewrite `remaining` further)
        remaining, event_windows, event_refs = self._extract_event_features(remaining)
        return ParsedExpression(
            date_conditions=date_conditions,
            sort=sort_spec,
            remaining_filter=remaining,
            metrics=metrics,
            aggregations=aggs,
            window_chained=window_chained,
            chained_aggs=chained_aggs,
            binary_aggs=binary_aggs,
            event_windows=event_windows,
            event_refs=event_refs,
        )

    def _normalize(self, expr: str) -> str:
        """Normalize trader syntax: lowercase, K/M/B suffixes.

        String literals are shielded behind placeholders so their contents
        survive the lowercasing, then restored afterwards.
        """
        # Protect string literals first
        temp_strings = {}
        result = expr
        for counter, match in enumerate(re.finditer(r"'(?:[^'\\]|\\.)*'|\"(?:[^\"\\]|\\.)*\"", result)):
            placeholder = f"__STR_{counter}__"
            temp_strings[placeholder] = match.group(0)
            result = result.replace(match.group(0), placeholder)
        # Lowercase
        result = result.lower()

        # K/M/B suffixes for numbers (not after $)
        def replace_suffix(match):
            num = int(match.group(1))
            suffix = match.group(2)
            multipliers = {"k": 1_000, "m": 1_000_000, "b": 1_000_000_000}
            return str(num * multipliers.get(suffix, 1))

        # (?!\s*\[) keeps offsets like close[-1] from being misread as suffixes
        result = re.sub(r"\b(\d+)([kmb])\b(?!\s*\[)", replace_suffix, result, flags=re.IGNORECASE)
        # Restore string literals (placeholders were lowercased along with the rest)
        for placeholder, original in temp_strings.items():
            result = result.replace(placeholder.lower(), original)
        return result

    def _convert_dollar(self, expr: str) -> str:
        """Convert $number to _dollar column for percentage columns.

        Example: gap_pct > $5 -> gap_dollar > 5
        """

        def dollar_to_column(match):
            col = match.group(1)
            op = match.group(2)
            num = match.group(3)
            dollar_col = col.replace("_pct", "_dollar")
            return f"{dollar_col} {op} {num}"

        # Match: column op $number (where column is a pct column)
        pct_cols_pattern = r"\b(gap_pct|run_pct|change_pct|range_pct)\s*([><=!]+)\s*\$\s*(\d+\.?\d*)"
        expr = re.sub(pct_cols_pattern, dollar_to_column, expr, flags=re.IGNORECASE)
        return expr

    def _convert_percentages(self, expr: str) -> str:
        """Convert percentage columns without % suffix to decimal.

        Rules:
        - gap_pct > 5    -> gap_pct > 0.05 (5%, divide by 100 because 5 > 1)
        - gap_pct > 5%   -> gap_pct > 0.05 (already converted by % handling)
        - gap_pct > 0.10 -> gap_pct > 0.10 (already decimal, no conversion)
        - gap_pct > 0.05 -> gap_pct > 0.05 (already decimal, no conversion)

        Values > 1 are divided by 100.
        Values <= 1 or with decimal point are kept as-is.
        Values with % suffix are handled separately.
        """
        pct_cols_pattern = "|".join(self.PCT_COLS)

        def convert_to_decimal(match):
            col = match.group(1)
            offset = match.group(2) if match.group(2) else ""  # Handle None when no offset
            op = match.group(3)
            num = match.group(4)
            if "." in num:
                return f"{col}{offset} {op} {num}"
            if float(num) <= 1:
                return f"{col}{offset} {op} {num}"
            decimal = float(num) / 100
            return f"{col}{offset} {op} {decimal}"

        # Pattern: column[offset] operator number (not followed by %)
        # Support negative numbers and offsets
        pattern = rf"\b({pct_cols_pattern})(\[-?\d+\])?\s*([><=!]+)\s*(-?\d+\.?\d*)\b(?!\s*%)"
        expr = re.sub(pattern, convert_to_decimal, expr, flags=re.IGNORECASE)
        return expr

    def _convert_percent_suffix(self, expr: str) -> str:
        """Convert percentage columns with % suffix to decimal.

        Examples:
        - gap_pct > 10%     -> gap_pct > 0.10
        - gap_pct[-1] > 5%  -> gap_pct[-1] > 0.05
        - gap_pct > %5      -> gap_pct > 0.05 (old format)
        """

        def percent_to_decimal(match):
            col = match.group(1)
            offset = match.group(2) if match.group(2) else ""  # Handle None when no offset
            op = match.group(3)
            num = float(match.group(4))
            decimal = num / 100
            return f"{col}{offset} {op} {decimal}"

        pct_cols_pattern = "|".join(self.PCT_COLS)
        # New format: column op number% (e.g., gap_pct > 10%)
        pct_cols_pattern_new = rf"\b({pct_cols_pattern})(\[-?\d+\])?\s*([><=!]+)\s*(-?\d+\.?\d*)\s*%"
        expr = re.sub(pct_cols_pattern_new, percent_to_decimal, expr, flags=re.IGNORECASE)
        # Old format: column op %number (e.g., gap_pct > %5)
        pct_cols_pattern_old = rf"\b({pct_cols_pattern})(\[-?\d+\])?\s*([><=!]+)\s*%\s*(-?\d+\.?\d*)"
        expr = re.sub(pct_cols_pattern_old, percent_to_decimal, expr, flags=re.IGNORECASE)
        return expr

    def _parse_sort(self, expr: str) -> SortSpec | None:
        """Return the trailing `sort <expr> [asc|desc]` clause, if present."""
        match = re.search(self.SORT_PATTERN, expr, re.IGNORECASE)
        if match:
            return SortSpec(column=match.group(1), direction=(match.group(2) or "asc").lower())
        return None

    def _remove_sort(self, expr: str) -> str:
        """Strip the trailing sort clause from the expression."""
        return re.sub(self.SORT_PATTERN, "", expr, flags=re.IGNORECASE)

    def _parse_dates(self, expr: str) -> list[FilterCondition]:
        """Collect all quoted ISO-date comparisons on the `date` column."""
        matches = re.findall(self.DATE_PATTERN, expr, re.IGNORECASE)
        return [FilterCondition(column="date", operator=op, value=val) for op, val in matches]

    def _remove_dates(self, expr: str) -> str:
        """Strip date clauses, consuming an adjacent `and` connector when possible."""
        result = re.sub(r"\s+and\s+" + self.DATE_PATTERN, "", expr, flags=re.IGNORECASE)
        result = re.sub(self.DATE_PATTERN + r"\s+and\s+", "", result, flags=re.IGNORECASE)
        result = re.sub(self.DATE_PATTERN, "", result, flags=re.IGNORECASE)
        return result

    def _cleanup_expression(self, expr: str) -> str:
        """Remove dangling/duplicated `and` connectors left by clause removal."""
        result = re.sub(r"^\s*and\s+", "", expr, flags=re.IGNORECASE)
        # BUG FIX: this sub previously ran on `expr` instead of `result`,
        # silently discarding the leading-`and` removal above.
        result = re.sub(r"\s+and\s*$", "", result, flags=re.IGNORECASE)
        result = re.sub(r"\s+and\s+and\s+", " and ", result, flags=re.IGNORECASE)
        # Normalize connector casing (e.g. AND -> and) for downstream eval
        result = re.sub(r"\band\b", "and", result, flags=re.IGNORECASE)
        result = re.sub(r"\bor\b", "or", result, flags=re.IGNORECASE)
        return " ".join(result.split())

    def _extract_event_features(self, expr: str) -> tuple[str, list[tuple[str, int]], list[tuple[str, int]]]:
        """Extract event windows and event refs from expression.

        Returns (remaining_expression, windows, refs) where windows is
        [(condition, days)] from `<cond> in Nd` clauses and refs is
        [(column, offset)] from `$col[off]` / `$col` references.
        """
        windows = []
        refs = []
        remaining = expr
        # Split by OR
        or_parts = re.split(r"\s+or\s+", remaining, flags=re.IGNORECASE)
        remaining_or_parts = []
        for or_part in or_parts:
            # Look for "in Nd" pattern
            match = re.search(self.EVENT_WINDOW_PATTERN, or_part, re.IGNORECASE)
            if match:
                # Extract condition before "in Nd"
                condition = or_part[: match.start()].strip()
                days = int(match.group(1))
                # Remove leading/trailing parentheses
                condition = re.sub(r"^\((.+)\)$", r"\1", condition)
                if condition:
                    windows.append((condition, days))
                # Extract remaining after "in Nd"
                after_event = or_part[match.end() :].strip()
                if after_event.startswith("and "):
                    after_event = after_event[4:].strip()
                # Split by AND for additional conditions
                if after_event:
                    and_parts = re.split(r"\s+and\s+", after_event, flags=re.IGNORECASE)
                    for part in and_parts:
                        if part.strip():
                            remaining_or_parts.append(part.strip())
            else:
                remaining_or_parts.append(or_part)
        remaining = " or ".join(remaining_or_parts)
        # Extract event refs ($col[off] first, then bare $col as offset 0)
        for match in re.finditer(self.EVENT_REF_PATTERN, remaining, re.IGNORECASE):
            col, off = match.group(1).lower(), int(match.group(2))
            if (col, off) not in refs:
                refs.append((col, off))
        for match in re.finditer(self.EVENT_REF_SIMPLE, remaining, re.IGNORECASE):
            col = match.group(1).lower()
            if (col, 0) not in refs:
                refs.append((col, 0))
        # Clean up
        remaining = remaining.strip()
        remaining = re.sub(r"\s+", " ", remaining)
        return remaining, windows, refs

    def extract_lookback(self, expr_or_parsed: str | ParsedExpression) -> int:
        """Calculate required lookback days from expression or parsed object.

        The +5 at the end is a safety buffer on top of the deepest
        offset/lookback referenced by the expression.
        """
        parsed = self.parse(expr_or_parsed) if isinstance(expr_or_parsed, str) else expr_or_parsed
        lookback = 0
        if parsed.metrics:
            lookback = max(lookback, max(abs(off) for _, off in parsed.metrics))
        if parsed.aggregations:
            lookback = max(lookback, max(lb for _, _, lb in parsed.aggregations))
        if parsed.window_chained:
            lookback = max(lookback, max(lb + abs(off) for _, _, lb, off in parsed.window_chained))
        if parsed.chained_aggs:
            lookback = max(lookback, max(abs(off) for _, _, _, off in parsed.chained_aggs))
        return lookback + 5

    def compile_safe(self, expr_str: str):
        """Compile an expression string into a safe, callable Python function.

        Validates the expression using Python's AST to ensure only allowed
        nodes and names are used, preventing code injection.

        NOTE(review): the trivial cases return a `lambda ctx: ...` while the
        general case returns a compiled code object for eval() — callers must
        handle both; confirm before unifying the return type.
        """
        import ast

        if not expr_str:
            return lambda ctx: False
        # 1. Full normalization and cleanup using existing parser logic
        parsed = self.parse(expr_str)
        # Use the remaining_filter which has dates and sort removed
        normalized = parsed.remaining_filter
        # Remove $ from event refs for AST validation (e.g., $close -> close)
        normalized = normalized.replace("$", "")
        if not normalized or not normalized.strip():
            return lambda ctx: True
        # Handle some edge cases with 'and/or' and whitespace for Python AST
        normalized = re.sub(r"\band\b", " and ", normalized, flags=re.IGNORECASE)
        normalized = re.sub(r"\bor\b", " or ", normalized, flags=re.IGNORECASE)
        try:
            tree = ast.parse(normalized, mode="eval")
        except SyntaxError as e:
            raise ValueError(f"Invalid filter: Invalid expression syntax: {e}") from e
        # Allowed AST nodes for stock scanning and backtesting
        allowed_nodes = {
            ast.Expression,
            ast.BinOp,
            ast.UnaryOp,
            ast.Compare,
            ast.BoolOp,
            ast.Name,
            ast.Constant,
            ast.Subscript,
            ast.Slice,  # For open[-1]
            ast.Call,
            ast.Attribute,
            # Operators
            ast.Add,
            ast.Sub,
            ast.Mult,
            ast.Div,
            ast.Mod,
            ast.Pow,
            ast.And,
            ast.Or,
            ast.Not,
            ast.Eq,
            ast.NotEq,
            ast.Lt,
            ast.LtE,
            ast.Gt,
            ast.GtE,
            ast.In,
            ast.NotIn,
            ast.USub,
            ast.UAdd,
            ast.Load,
            ast.Index,  # Required for Python < 3.9
        }
        # Prohibited variable/function names (security blocklist)
        # Note: 'open' is excluded because it's a valid OHLC stock metric
        prohibited_names = {
            "eval", "exec", "compile", "__import__", "getattr", "setattr", "delattr",
            "hasattr", "globals", "locals", "vars", "dir", "input", "breakpoint",
            "exit", "quit", "help", "repr", "str", "int", "float", "list", "dict", "set",
            "tuple", "type", "object", "class", "def", "return", "yield", "raise", "assert",
            "import", "from", "global", "nonlocal", "try", "except", "finally", "with", "as",
            "if", "else", "elif", "for", "while", "pass", "continue", "break", "del",
        }
        # Allowed variable/function names (Extensive whitelist)
        allowed_names = {
            # Metrics
            "open",
            "high",
            "low",
            "close",
            "volume",
            "price",
            "time",
            "gap_pct",
            "run_pct",
            "change_pct",
            "range_pct",
            "rel_vol",
            "gap_dollar",
            "run_dollar",
            "change_dollar",
            "range_dollar",
            "volume_dollar",
            "streak_run_pct",
            "vol_ratio_52wk",
            "up_streak",
            "down_streak",
            "rs",
            # Metadata
            "sector",
            "industry",
            "market_cap",
            "country",
            "name",
            "symbol",
            "date",
            # Functions
            "max",
            "min",
            "avg",
            "ret",
            "entry_price",
            # Constants
            "true",
            "false",
            "none",
        }
        for node in ast.walk(tree):
            if type(node) not in allowed_nodes:
                raise ValueError(f"Invalid filter: Prohibited expression element: {type(node).__name__}")
            if isinstance(node, ast.Name):
                name_id = node.id.lower()
                if name_id in allowed_names or name_id.startswith("__str_"):
                    pass  # Whitelisted - skip further checks
                elif name_id in prohibited_names:
                    raise ValueError(f"Invalid filter: Prohibited expression element: {node.id}")
                else:
                    raise ValueError(f"Invalid filter: Prohibited variable name: {node.id}")
            if (
                isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id.lower() in prohibited_names
            ):
                raise ValueError(f"Invalid filter: Prohibited expression element: {node.func.id}")
            if (
                isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id.lower() not in allowed_names
            ):
                raise ValueError(f"Invalid filter: Prohibited function call: {node.func.id}")
            # Block attribute access on builtins that could be dangerous (e.g., open.__globals__)
            if isinstance(node, ast.Attribute):
                dangerous_attrs = {"__globals__", "__builtins__", "__class__", "__dict__",
                                   "__module__", "__doc__", "__code__", "__defaults__",
                                   "__kwdefaults__", "__annotations__", "__closure__"}
                if node.attr in dangerous_attrs:
                    raise ValueError(f"Invalid filter: Prohibited expression element: {node.attr}")
        # If we get here, it's safe to compile
        try:
            compiled = compile(tree, "<string>", "eval")
            return compiled
        except Exception as e:
            raise ValueError(f"Invalid filter: Compilation failed: {e}") from e