File size: 2,482 Bytes
2633ee9
46466a5
 
7e16d4f
 
 
 
 
 
46466a5
 
7e16d4f
 
 
1f626ee
2633ee9
1f626ee
2633ee9
 
1f626ee
2633ee9
46466a5
7e16d4f
46466a5
 
 
7e16d4f
46466a5
 
 
7e16d4f
46466a5
 
7e16d4f
 
 
46466a5
7e16d4f
 
46466a5
7e16d4f
 
46466a5
7e16d4f
 
 
3caf047
 
46466a5
 
3caf047
46466a5
 
 
 
 
 
 
 
 
 
 
3caf047
 
7e16d4f
 
46466a5
7e16d4f
3caf047
 
46466a5
7e16d4f
 
 
 
 
 
 
46466a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from typing import Optional, Union

import regex as re
import weave
from pydantic import BaseModel


class RegexResult(BaseModel):
    """Outcome of checking one text against a set of named regex patterns."""

    # True when no named pattern produced a match (as set by RegexModel.check).
    passed: bool
    # Pattern name -> list of strings extracted from that pattern's matches.
    matched_patterns: dict[str, list[str]]
    # Names of the patterns that produced no match at all.
    failed_patterns: list[str]


class RegexModel(weave.Model):
    """
    Weave model that checks text against a set of named regex patterns.

    Every pattern name maps to one or more regular expressions. The
    expressions are compiled once, at construction time, and reused by each
    call to `check()`.

    Args:
        patterns: Dictionary where the key is the pattern name and the value
            is either a single regex string or a list of regex strings.
            `None` (the default) is treated as "no patterns".
    """

    patterns: Optional[Union[dict[str, str], dict[str, list[str]]]] = None

    def __init__(
        self, patterns: Optional[Union[dict[str, str], dict[str, list[str]]]] = None
    ) -> None:
        super().__init__(patterns=patterns)
        # `patterns` defaults to None, so fall back to an empty mapping here;
        # otherwise the default constructor would fail on `None.items()`.
        # A bare string value is wrapped in a list so that every name maps to
        # a list of compiled expressions.
        self._compiled_patterns = {
            name: [
                re.compile(expr)
                for expr in (value if isinstance(value, list) else [value])
            ]
            for name, value in (patterns or {}).items()
        }

    @weave.op()
    def check(self, text: str) -> RegexResult:
        """
        Check text against all patterns and return detailed results.

        For each match, the recorded string is the capture groups that took
        part in the match joined with "-"; when the expression has no capture
        groups, or none of them took part, the full matched text is recorded.

        Args:
            text: Input text to check against patterns

        Returns:
            RegexResult where `matched_patterns` maps each pattern name that
            matched to its recorded strings, `failed_patterns` lists the names
            that did not match, and `passed` is True only when no pattern
            matched at all.
        """
        matched_patterns = {}
        failed_patterns = []

        for pattern_name, compiled in self._compiled_patterns.items():
            matches = []
            for pattern in compiled:
                for match in pattern.finditer(text):
                    # Groups that did not participate are reported as None.
                    captured = [str(g) for g in match.groups() if g is not None]
                    # Fall back to the whole match so an optional group that
                    # was skipped does not yield an empty, uninformative entry.
                    matches.append(
                        "-".join(captured) if captured else match.group(0)
                    )

            if matches:
                matched_patterns[pattern_name] = matches
            else:
                failed_patterns.append(pattern_name)

        return RegexResult(
            matched_patterns=matched_patterns,
            failed_patterns=failed_patterns,
            # The check passes only when nothing matched.
            passed=not matched_patterns,
        )

    @weave.op()
    def predict(self, text: str) -> RegexResult:
        """
        Alias for check() to maintain consistency with other models.
        """
        return self.check(text)