Spaces:
Runtime error
Runtime error
File size: 3,726 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import re
from typing import Callable, List
from langchain.document_loaders.parsers.language.code_segmenter import CodeSegmenter
class CobolSegmenter(CodeSegmenter):
    """Code segmenter for `COBOL`.

    The segmenter works line by line on the source text. It recognises
    division headers, section headers and paragraph labels with regular
    expressions and uses them as boundaries when splitting the program
    into chunks or when producing a simplified outline.
    """

    # Applied to the first space-delimited token of a stripped line: a run of
    # letters, digits or hyphens terminated by a period (e.g. ``MAIN-PARA.``).
    PARAGRAPH_PATTERN = re.compile(r"^[A-Z0-9\-]+(\s+.*)?\.$", re.IGNORECASE)
    # A line that starts (after optional whitespace) with one of the four
    # division names followed by ``DIVISION``.
    DIVISION_PATTERN = re.compile(
        r"^\s*(IDENTIFICATION|DATA|PROCEDURE|ENVIRONMENT)\s+DIVISION.*$", re.IGNORECASE
    )
    # ``<name> SECTION.`` -- the trailing period is escaped so that only a
    # literal ``.`` terminates the header (an unescaped ``.`` would accept any
    # character, e.g. ``FOO SECTIONS``).
    SECTION_PATTERN = re.compile(r"^\s*[A-Z0-9\-]+\s+SECTION\.$", re.IGNORECASE)

    # Division headers that `simplify_code` keeps verbatim. Compared against
    # the upper-cased line so that detection is case-insensitive, consistent
    # with `_is_relevant_code` and the compiled patterns above.
    _SIMPLIFY_DIVISION_HEADERS = (
        "PROCEDURE DIVISION",
        "DATA DIVISION",
        "IDENTIFICATION DIVISION",
    )

    def __init__(self, code: str) -> None:
        """Store the source and pre-split it into lines.

        Args:
            code: The COBOL source text to segment.
        """
        super().__init__(code)
        # NOTE(review): relies on the base initializer exposing the source as
        # ``self.code`` -- confirm against CodeSegmenter.
        self.source_lines: List[str] = self.code.splitlines()

    def is_valid(self) -> bool:
        """Return True if at least one line looks like a division header."""
        # Identify presence of any division to validate COBOL code
        return any(self.DIVISION_PATTERN.match(line) for line in self.source_lines)

    def _extract_code(self, start_idx: int, end_idx: int) -> str:
        """Return source lines ``[start_idx, end_idx)`` joined by newlines.

        Trailing newlines (produced by trailing empty lines in the slice) are
        removed.
        """
        return "\n".join(self.source_lines[start_idx:end_idx]).rstrip("\n")

    def _is_relevant_code(self, line: str) -> bool:
        """Check if a line is part of the procedure division or a relevant section."""
        if "PROCEDURE DIVISION" in line.upper():
            return True
        # Add additional conditions for relevant sections if needed
        return False

    def _is_paragraph_or_section(self, line: str) -> bool:
        """Return True if ``line`` is a paragraph label or a section header.

        The paragraph check looks only at the first space-delimited token of
        the stripped line; the section check looks at the whole stripped line.
        """
        stripped = line.strip()
        # ``split(" ")`` (not ``split()``) so an empty line yields ``[""]``
        # rather than an empty list.
        first_token = stripped.split(" ")[0]
        return bool(
            self.PARAGRAPH_PATTERN.match(first_token)
            or self.SECTION_PATTERN.match(stripped)
        )

    def _process_lines(
        self, func: Callable[[List[str], int, int], None]
    ) -> List[str]:
        """A generic function to process COBOL lines based on provided func.

        Scanning starts being effective at the ``PROCEDURE DIVISION`` line.
        From there on, every paragraph label or section header opens a new
        element and closes the previous one.

        Args:
            func: Callback invoked as ``func(elements, start_idx, end_idx)``
                for each element, where the element spans source lines
                ``[start_idx, end_idx)``. It is expected to append to
                ``elements``.

        Returns:
            The ``elements`` list as populated by ``func``.
        """
        elements: List[str] = []
        start_idx = None
        inside_relevant_section = False

        for i, line in enumerate(self.source_lines):
            if self._is_relevant_code(line):
                inside_relevant_section = True

            if inside_relevant_section and self._is_paragraph_or_section(line):
                # A new header closes the element that was open so far.
                if start_idx is not None:
                    func(elements, start_idx, i)
                start_idx = i

        # Handle the last element if exists
        if start_idx is not None:
            func(elements, start_idx, len(self.source_lines))

        return elements

    def extract_functions_classes(self) -> List[str]:
        """Return the source text of each paragraph/section found after
        ``PROCEDURE DIVISION``, one string per element."""

        def extract_func(elements: List[str], start_idx: int, end_idx: int) -> None:
            elements.append(self._extract_code(start_idx, end_idx))

        return self._process_lines(extract_func)

    def simplify_code(self) -> str:
        """Return an outline of the program.

        Header lines (the listed division headers, section headers and
        paragraph labels) are kept verbatim. Each run of non-header lines
        that follows a header is collapsed into a single
        ``* OMITTED CODE *`` line. Lines before the first header are dropped.
        Division headers are matched case-insensitively.
        """
        simplified_lines: List[str] = []
        inside_relevant_section = False
        # Tracks whether "* OMITTED CODE *" was already emitted after the
        # most recent header, so each run of body lines yields one marker.
        omitted_code_added = False

        for line in self.source_lines:
            upper_line = line.upper()
            is_header = any(
                header in upper_line for header in self._SIMPLIFY_DIVISION_HEADERS
            ) or self._is_paragraph_or_section(line)

            if is_header:
                # Entering a new division/section/paragraph: keep the header
                # and allow one new omission marker after it.
                inside_relevant_section = True
                omitted_code_added = False
                simplified_lines.append(line)
            elif inside_relevant_section and not omitted_code_added:
                simplified_lines.append("* OMITTED CODE *")
                omitted_code_added = True

        return "\n".join(simplified_lines)
|