File size: 3,726 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import re
from typing import Callable, List

from langchain.document_loaders.parsers.language.code_segmenter import CodeSegmenter


class CobolSegmenter(CodeSegmenter):
    """Code segmenter for `COBOL`.

    ``extract_functions_classes`` splits the PROCEDURE DIVISION into one chunk
    per paragraph/section header. ``simplify_code`` produces an outline that
    keeps only division, section and paragraph header lines.

    All matching is line-based and case-insensitive. Headers are recognised
    only at the start of a (whitespace-stripped) line. Fixed-format sources
    whose sequence area (columns 1-6) holds numbers are therefore not
    recognised.
    """

    # Paragraph header: a name ending in "." (e.g. "A000-INIT-PARA."). It is
    # matched against the first whitespace-delimited token of a line only.
    PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
    # One of the four division headers, optionally indented and optionally
    # followed by more text (e.g. "PROCEDURE DIVISION USING WS-PARM.").
    DIVISION_PATTERN = re.compile(
        r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
    )
    # Section header such as "WORKING-STORAGE SECTION.". The final period is
    # escaped so that only a literal "." (not any character) ends the header.
    SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION\.$", re.IGNORECASE)

    def __init__(self, code: str):
        super().__init__(code)
        self.source_lines: List[str] = self.code.splitlines()

    def is_valid(self) -> bool:
        """Return True if at least one line is a COBOL division header."""
        return any(self.DIVISION_PATTERN.match(line) for line in self.source_lines)

    def _extract_code(self, start_idx: int, end_idx: int) -> str:
        """Return ``source_lines[start_idx:end_idx]`` joined with newlines.

        Trailing empty lines in the range would leave trailing newline
        characters after the join; those are stripped.
        """
        return "\n".join(self.source_lines[start_idx:end_idx]).rstrip("\n")

    def _is_relevant_code(self, line: str) -> bool:
        """Return True if ``line`` contains "PROCEDURE DIVISION" (any case)."""
        if "PROCEDURE DIVISION" in line.upper():
            return True
        # Add additional conditions for relevant sections if needed
        return False

    def _is_paragraph_or_section(self, line: str) -> bool:
        """Return True if ``line`` is a paragraph or section header.

        The paragraph check looks only at the first token of the line. The
        line is split on any whitespace (not just spaces), so tab-separated
        statements are not mistaken for paragraph headers.
        """
        stripped = line.strip()
        tokens = stripped.split(maxsplit=1)
        first_token = tokens[0] if tokens else ""
        return bool(
            self.PARAGRAPH_PATTERN.match(first_token)
            or self.SECTION_PATTERN.match(stripped)
        )

    def _process_lines(self, func: Callable[[List[str], int, int], None]) -> List[str]:
        """Walk the PROCEDURE DIVISION chunk by chunk.

        Once a line containing "PROCEDURE DIVISION" has been seen, every
        paragraph or section header starts a new chunk; earlier lines are
        never chunked. For each chunk, ``func(elements, start_idx, end_idx)``
        is called with the half-open line range ``[start_idx, end_idx)`` and
        is expected to append its result to ``elements``.

        Returns:
            The ``elements`` list after every chunk has been handed to ``func``.
        """
        elements: List[str] = []
        start_idx = None
        inside_relevant_section = False

        for i, line in enumerate(self.source_lines):
            if self._is_relevant_code(line):
                inside_relevant_section = True

            if inside_relevant_section and self._is_paragraph_or_section(line):
                # A new header closes the chunk that was open, if any.
                if start_idx is not None:
                    func(elements, start_idx, i)
                start_idx = i

        # The last chunk runs to the end of the source.
        if start_idx is not None:
            func(elements, start_idx, len(self.source_lines))

        return elements

    def extract_functions_classes(self) -> List[str]:
        """Return the source text of each PROCEDURE DIVISION paragraph/section.

        Each chunk spans from its header line up to (not including) the next
        header. The last chunk spans to the end of the source.
        """

        def extract_func(elements: List[str], start_idx: int, end_idx: int) -> None:
            elements.append(self._extract_code(start_idx, end_idx))

        return self._process_lines(extract_func)

    def simplify_code(self) -> str:
        """Return an outline of the program that keeps only header lines.

        Division, section and paragraph header lines are kept verbatim. Each
        run of non-header lines after a header (blank lines included)
        collapses into a single "* OMITTED CODE *" marker. Lines before the
        first header are dropped.
        """
        simplified_lines: List[str] = []
        inside_relevant_section = False
        # Whether the marker has already been emitted since the last header.
        omitted_code_added = False

        for line in self.source_lines:
            # DIVISION_PATTERN keeps division detection case-insensitive and
            # aligned with is_valid() (all four divisions, ENVIRONMENT included).
            # Unlike a substring test, it ignores comment lines that merely
            # mention a division name.
            is_header = bool(
                self.DIVISION_PATTERN.match(line)
                or self._is_paragraph_or_section(line)
            )

            if is_header:
                inside_relevant_section = True
                omitted_code_added = False
                simplified_lines.append(line)
            elif inside_relevant_section and not omitted_code_added:
                simplified_lines.append("* OMITTED CODE *")
                omitted_code_added = True

        return "\n".join(simplified_lines)