def cleanup_code( code: str, language_type: str = None, dataset: str = None, issft: bool = False, stop_words=[], ): """ Cleans up the generated code. """ if language_type.lower() == "python": if issft: code = _clean_python_code_for_sft(code) stop_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"] code = _truncate_code_at_stopwords(code, stop_words) elif language_type.lower() == "ts": code = _truncate_code_at_stopwords( code, stop_words + [ "\nexport", "\nimport", "\nexport default", "\nimport default", "\nconsole.log", ], ) else: code = _truncate_code_at_stopwords(code, stop_words) return code def _clean_python_code_for_sft(code): code = code.replace("\r", "") if "```python" in code: code_start_idx = code.index("```python") code = code[code_start_idx:].replace("```python", "").strip() end_idx = code.find("```") if "```" in code else len(code) code = code[:end_idx].strip() return code def _truncate_code_at_stopwords(code, stop_words): min_stop_idx = len(code) for stop_word in stop_words: stop_index = code.find(stop_word) if 0 <= stop_index < min_stop_idx: min_stop_idx = stop_index return code[:min_stop_idx]