|
import re |
|
|
|
# Per-language extraction settings, keyed by short language code.
# NOTE: the (misspelled) name is kept as-is because other code in this file
# looks the table up under exactly this identifier.
#
#   full_name - lower-cased and used as the Markdown fence tag when locating
#               the generated code block; also passed to get_function_name().
#   indent    - number of spaces assumed before the closing "}" when the
#               target function cannot be located in the generated code.
#   main      - optional marker of the program entry point; generated code
#               from this marker onwards is discarded.
languge_settings = {
    "python": {
        "full_name": "Python",
        "indent": 4,
    },
    "cpp": {
        "full_name": "cpp",
        "indent": 0,
        "main": "int main()",
    },
    "java": {
        "full_name": "Java",
        "indent": 4,
        "main": "public static void main",
    },
    "cs": {
        "full_name": "csharp",
        "indent": 0,
        "main": "public static void Main",
    },
    "php": {
        "full_name": "PHP",
        "indent": 0,
    },
    "ts": {
        "full_name": "TypeScript",
        "indent": 0,
    },
    "js": {
        "full_name": "JavaScript",
        "indent": 0,
    },
    "sh": {
        "full_name": "Bash",
        "indent": 0,
    },
}
|
|
|
def get_function_name(question: str, lang: str):
    """Split a prompt into the target function's signature head and its prefix.

    Blank lines are dropped from ``question`` before any processing.

    * For Python (``lang`` compared case-insensitively with ``"python"``) the
      target is the LAST line that starts with ``"def "`` at column 0.  The
      returned name is that line up to (not including) the first ``"("``,
      e.g. ``"def solve"``; the prefix is every non-blank line before it.
    * For every other language the target is the last non-blank line.  The
      returned name is that line up to the first ``"{"``; the prefix is every
      non-blank line before it.

    Args:
        question: The prompt text containing the function to be completed.
        lang: Language name; only ``"python"`` (any case) is special-cased.

    Returns:
        A ``(func_name, func_prefix)`` tuple of strings; ``func_prefix`` is
        the preceding lines joined with ``"\\n"``.

    Raises:
        IndexError: If no suitable line exists (no ``def`` line for Python,
            or no non-blank line at all for other languages).
    """
    func_lines = [line for line in question.strip().split('\n') if line.strip()]

    if lang.lower() == 'python':
        def_indices = [
            idx for idx, line in enumerate(func_lines) if line.startswith("def ")
        ]
        if not def_indices:
            # Same exception type as the implicit `[][-1]` failure, but with
            # a message that tells the caller what was actually missing.
            raise IndexError("no line starting with 'def ' found in question")
        func_idx = def_indices[-1]
        func_name = func_lines[func_idx].split('(')[0].strip()
        func_prefix = "\n".join(func_lines[:func_idx])
        return func_name, func_prefix

    if not func_lines:
        raise IndexError("question contains no non-blank lines")
    func_name = func_lines[-1].split('{')[0].strip()
    func_prefix = "\n".join(func_lines[:-1])
    return func_name, func_prefix
|
|
|
def extract_generation_code(example: dict, lang_code: str, verbose: bool = False):
    """Extract the completed function from a model response and store it.

    The response text is read from ``example['output']`` (falling back to
    ``example['gpt_completion']``).  The first Markdown code fence tagged with
    the language's full name (matched case-insensitively) is taken as the
    generated code, then:

    1. If the language defines a ``main`` marker and it occurs in the code,
       everything from that marker onwards is dropped.
    2. The signature head of the prompt's target function (as returned by
       ``get_function_name``) is searched for case-insensitively.  When it is
       found, the code is cut to start there and the run of spaces directly
       before it is taken as the function's indentation; otherwise the code
       is kept from its start and the language's default indent is used.
    3. The code is cut just before the LAST line consisting of that
       indentation followed by ``}`` (if there is one).  For ``php``, ``ts``
       and ``js`` that closing-brace line is appended back.
    4. ``example['generation']`` is set to the prompt prefix, a newline, the
       extracted body and a trailing newline.

    If anything in these steps raises, a diagnostic is printed and
    ``example['generation']`` falls back to the prompt followed by the raw
    output.

    Args:
        example: Mapping with at least ``'task_id'`` and ``'prompt'``, plus
            ``'output'`` or ``'gpt_completion'``.  It is modified in place.
        lang_code: Key into ``languge_settings`` (e.g. ``'python'``, ``'cpp'``).
        verbose: When true, print the extracted code block.

    Returns:
        The same ``example`` object, with ``'generation'`` set.

    Raises:
        KeyError: If ``'task_id'`` or ``'prompt'`` is missing from ``example``
            or ``lang_code`` is not a known language (these lookups happen
            before the guarded section).
    """
    task_id = example['task_id']
    output = example.get('output', example.get("gpt_completion"))
    question = example["prompt"].strip()
    setting = languge_settings[lang_code]
    lang = setting['full_name']
    indent = setting['indent']

    try:
        # re.escape keeps the fence tag literal should a language name ever
        # contain regex metacharacters; it is a no-op for the current names.
        code_block: str = re.findall(
            f'```{re.escape(lang.lower())}\n(.*?)```',
            output,
            re.DOTALL | re.IGNORECASE,
        )[0]
        if verbose:
            print(">>> Task: {}\n{}".format(task_id, code_block))

        main_marker = setting.get('main', None)
        if main_marker and main_marker in code_block:
            code_block = code_block[:code_block.index(main_marker)]

        func_name, func_prefix = get_function_name(question, lang)

        try:
            start = code_block.lower().index(func_name.lower())
        except ValueError:
            # Signature not present in the generated code: keep the block
            # from its beginning and rely on the language's default indent.
            start = 0
        else:
            # Count the spaces immediately preceding the signature.  The
            # bound is checked on the index actually read so that position
            # -1 (which would wrap around to the last character) is never
            # inspected.
            indent = 0
            while start - indent - 1 >= 0 and code_block[start - indent - 1] == ' ':
                indent += 1

        closing_line = '\n' + ' ' * indent + '}'
        end = code_block.rfind(closing_line)
        if end == -1:
            end = len(code_block)

        body = code_block[start:end]

        if lang_code.lower() in ['php', 'ts', 'js']:
            body += closing_line

        generation = func_prefix + '\n' + body + '\n'
        example['generation'] = generation

    except Exception as ex:
        print("Failed to extract code block with error `{}`:\n>>> Task: {}\n>>> Output:\n{}".format(
            ex, task_id, output
        ))
        # `output` is None when the example carries neither response key;
        # the fallback itself must not raise in that case.
        example['generation'] = example['prompt'] + '\n' + (output or '')

    return example
|
|
|
def cleanup_code(
    code: str,
    language_type: str = None,
    dataset: str = None,
    issft: bool = False,
    stop_words=None,
):
    """
    Cleans up the generated code.

    The code is truncated at the earliest occurrence of any stop word, where
    the stop-word list depends on the language:

    * ``"python"`` (any case): a fixed list (``"\\ndef"``, ``"\\nclass"``,
      ``"\\nif"``, ``"\\n#"``, ``"\\nprint"``) is used and the ``stop_words``
      argument is ignored.  When ``issft`` is true the code is first passed
      through ``_clean_python_code_for_sft``.
    * ``"ts"`` (any case): ``stop_words`` extended with TypeScript-specific
      markers (``export``/``import``/``console.log`` at line start).
    * anything else, including ``None``: ``stop_words`` as given.

    Args:
        code: The generated code to clean.
        language_type: Language identifier selecting the rules above.
        dataset: Accepted for interface compatibility; not used here.
        issft: Whether to apply the extra Python clean-up step first.
        stop_words: Caller-supplied stop words; ``None`` means none.

    Returns:
        The truncated code string.
    """
    # `language_type` defaults to None, so guard before calling .lower().
    language = (language_type or "").lower()
    # Copy into a fresh list: avoids a shared mutable default and lets
    # callers pass any iterable of strings.
    base_stop_words = list(stop_words) if stop_words else []

    if language == "python":
        if issft:
            code = _clean_python_code_for_sft(code)
        python_stop_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
        return _truncate_code_at_stopwords(code, python_stop_words)

    if language == "ts":
        ts_stop_words = base_stop_words + [
            "\nexport",
            "\nimport",
            "\nexport default",
            "\nimport default",
            "\nconsole.log",
        ]
        return _truncate_code_at_stopwords(code, ts_stop_words)

    return _truncate_code_at_stopwords(code, base_stop_words)
|
|
|
def _clean_python_code_for_sft(code): |
|
code = code.replace("\r", "") |
|
if "```python" in code: |
|
code_start_idx = code.index("```python") |
|
code = code[code_start_idx:].replace("```python", "").strip() |
|
end_idx = code.find("```") if "```" in code else len(code) |
|
code = code[:end_idx].strip() |
|
|
|
return code |
|
|
|
def _truncate_code_at_stopwords(code, stop_words): |
|
min_stop_idx = len(code) |
|
for stop_word in stop_words: |
|
stop_index = code.find(stop_word) |
|
if 0 <= stop_index < min_stop_idx: |
|
min_stop_idx = stop_index |
|
return code[:min_stop_idx] |