|
def cleanup_code(
    code: str,
    language_type: str = None,
    dataset: str = None,
    issft: bool = False,
    stop_words=None,
):
    """Clean up generated code by truncating it at the earliest stop word.

    Args:
        code: The raw generated code.
        language_type: Language tag, compared case-insensitively. ``"python"``
            and ``"ts"`` get language-specific handling; any other value,
            including ``None``, only truncates at ``stop_words``.
        dataset: Not used by this function; accepted for interface
            compatibility with existing callers.
        issft: Python only. When true, the text is first passed through
            ``_clean_python_code_for_sft`` (drops ``\\r`` and extracts the
            body of the first ```python fenced block).
        stop_words: Optional iterable of stop strings. Ignored for Python,
            which always uses its own fixed list; extended with TypeScript
            defaults for ``"ts"``; used as-is for every other language.

    Returns:
        ``code`` cut just before the first occurrence of any stop word, or
        unchanged if none occurs.
    """
    # Copy into a fresh list: avoids a shared mutable default and lets
    # callers pass any iterable (e.g. a tuple) without breaking `+` below.
    stop_words = list(stop_words) if stop_words is not None else []
    # `language_type` defaults to None; treat that like an unknown language
    # instead of crashing on `.lower()`.
    language = (language_type or "").lower()

    if language == "python":
        if issft:
            code = _clean_python_code_for_sft(code)
        # NOTE(review): caller-supplied stop words are replaced, not extended,
        # for Python. This preserves the original behavior; confirm it is
        # intentional before merging them like the "ts" branch does.
        stop_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
    elif language == "ts":
        stop_words = stop_words + [
            "\nexport",
            "\nimport",
            "\nexport default",
            "\nimport default",
            "\nconsole.log",
        ]

    return _truncate_code_at_stopwords(code, stop_words)
|
|
|
|
|
def _clean_python_code_for_sft(code): |
|
code = code.replace("\r", "") |
|
if "```python" in code: |
|
code_start_idx = code.index("```python") |
|
code = code[code_start_idx:].replace("```python", "").strip() |
|
end_idx = code.find("```") if "```" in code else len(code) |
|
code = code[:end_idx].strip() |
|
|
|
return code |
|
|
|
|
|
def _truncate_code_at_stopwords(code, stop_words): |
|
min_stop_idx = len(code) |
|
for stop_word in stop_words: |
|
stop_index = code.find(stop_word) |
|
if 0 <= stop_index < min_stop_idx: |
|
min_stop_idx = stop_index |
|
return code[:min_stop_idx] |
|
|