Spaces:
Runtime error
Runtime error
| import re | |
| from typing import Callable, List | |
| from langchain.document_loaders.parsers.language.code_segmenter import CodeSegmenter | |
class CobolSegmenter(CodeSegmenter):
    """Code segmenter for `COBOL`.

    The segmenter works line by line with regular-expression heuristics; it
    does not parse COBOL. A line is a *header* when it names a division, when
    it matches ``SECTION_PATTERN``, or when its first whitespace-delimited
    token matches ``PARAGRAPH_PATTERN`` (a name followed by a period).

    NOTE: because the paragraph heuristic only looks at the first token, a
    one-word statement terminated by a period (for example ``EXIT.``) is also
    treated as a paragraph header.
    """

    # Paragraph name: letters, digits and hyphens terminated by a period.
    # It is applied to the first token of a stripped line.
    PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
    # Any of the four division headers, optionally indented.
    DIVISION_PATTERN = re.compile(
        r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$",
        re.IGNORECASE,
    )
    # "<name> SECTION." -- the trailing period is escaped so that only a
    # literal period terminates the header.
    SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION\.$", re.IGNORECASE)

    # Division names kept as headers by ``simplify_code``; compared against
    # the upper-cased line so the check is case-insensitive.
    _DIVISION_HEADERS = (
        "IDENTIFICATION DIVISION",
        "ENVIRONMENT DIVISION",
        "DATA DIVISION",
        "PROCEDURE DIVISION",
    )

    def __init__(self, code: str):
        """Store the source and pre-split it into lines.

        Args:
            code: The COBOL source text to segment.
        """
        super().__init__(code)
        self.source_lines: List[str] = self.code.splitlines()

    def is_valid(self) -> bool:
        """Return True when at least one line is a division header."""
        return any(self.DIVISION_PATTERN.match(line) for line in self.source_lines)

    def _extract_code(self, start_idx: int, end_idx: int) -> str:
        """Join ``source_lines[start_idx:end_idx]``, dropping trailing newlines."""
        return "\n".join(self.source_lines[start_idx:end_idx]).rstrip("\n")

    def _is_relevant_code(self, line: str) -> bool:
        """Return True when the line mentions ``PROCEDURE DIVISION``.

        The comparison is case-insensitive.
        """
        # Add additional conditions for relevant sections if needed
        return "PROCEDURE DIVISION" in line.upper()

    def _is_division_header(self, line: str) -> bool:
        """Return True when the line mentions a division (case-insensitive)."""
        upper_line = line.upper()
        return any(name in upper_line for name in self._DIVISION_HEADERS)

    def _is_paragraph_or_section(self, line: str) -> bool:
        """Return True when the line starts a paragraph or a section."""
        stripped = line.strip()
        if not stripped:
            return False
        # Split on any whitespace so tab-separated source is tokenised the
        # same way as space-separated source.
        first_token = stripped.split(None, 1)[0]
        return bool(
            self.PARAGRAPH_PATTERN.match(first_token)
            or self.SECTION_PATTERN.match(stripped)
        )

    def _process_lines(
        self, func: Callable[[List[str], int, int], None]
    ) -> List[str]:
        """Walk the procedure division and hand each paragraph/section to ``func``.

        Args:
            func: Callback invoked as ``func(elements, start_idx, end_idx)`` for
                every chunk, where ``start_idx`` is the header line index and
                ``end_idx`` is the (exclusive) index of the next header or the
                end of the source. The callback appends to ``elements``.

        Returns:
            The ``elements`` list populated by ``func``.
        """
        elements: List[str] = []
        start_idx = None
        inside_relevant_section = False

        for i, line in enumerate(self.source_lines):
            if self._is_relevant_code(line):
                inside_relevant_section = True

            # Headers before the procedure division are ignored.
            if inside_relevant_section and self._is_paragraph_or_section(line):
                if start_idx is not None:
                    func(elements, start_idx, i)
                start_idx = i

        # Flush the chunk that runs to the end of the source, if any.
        if start_idx is not None:
            func(elements, start_idx, len(self.source_lines))

        return elements

    def extract_functions_classes(self) -> List[str]:
        """Return the text of each paragraph/section in the procedure division."""

        def extract_func(elements: List[str], start_idx: int, end_idx: int) -> None:
            elements.append(self._extract_code(start_idx, end_idx))

        return self._process_lines(extract_func)

    def simplify_code(self) -> str:
        """Return an outline: header lines kept, bodies replaced by a marker.

        Lines before the first header are dropped. After each header, the
        first non-header line is replaced by ``* OMITTED CODE *`` and any
        further non-header lines up to the next header are skipped.
        """
        simplified_lines: List[str] = []
        inside_relevant_section = False
        # Tracks whether the marker was already emitted since the last header.
        omitted_code_added = False

        for line in self.source_lines:
            is_header = self._is_division_header(
                line
            ) or self._is_paragraph_or_section(line)

            if is_header:
                inside_relevant_section = True
                # A new header starts a new body, so the marker may be
                # emitted again.
                omitted_code_added = False
                simplified_lines.append(line)
            elif inside_relevant_section and not omitted_code_added:
                simplified_lines.append("* OMITTED CODE *")
                omitted_code_added = True

        return "\n".join(simplified_lines)