codellama-CodeLlama-7b-hf / Llama2-Code-Interpreter-main/OpenCodeInterpreter/demo/code_interpreter/OpenCodeInterpreter.py
import sys
import os

# Make the project root importable so the sibling packages (code_interpreter, utils) resolve.
prj_root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(prj_root_path)
from code_interpreter.BaseCodeInterpreter import BaseCodeInterpreter
from utils.const import *

from typing import List, Tuple, Dict
import re

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

import warnings

# Silence transformers UserWarnings and TensorFlow log noise.
warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"


class OpenCodeInterpreter(BaseCodeInterpreter):
    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
    ):
        # build tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            padding_side="right",
            trust_remote_code=True,
        )

        # Load the causal LM in fp16, optionally quantized to 8-bit or 4-bit.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model = self.model.eval()

        self.dialog = []
        self.MAX_CODE_OUTPUT_LENGTH = 1000

    def dialog_to_prompt(self, dialog: List[Dict]) -> str:
        # Render the dialog into a single prompt string via the tokenizer's chat template.
        full_str = self.tokenizer.apply_chat_template(dialog, tokenize=False)
        return full_str

    def extract_code_blocks(self, prompt: str) -> Tuple[bool, str]:
        # Find all ```python ... ``` fenced blocks in the model output.
        pattern = re.escape("```python") + r"(.*?)" + re.escape("```")
        matches = re.findall(pattern, prompt, re.DOTALL)
        if matches:
            # Return the last matched code block
            return True, matches[-1].strip()
        else:
            return False, ""

    def clean_code_output(self, output: str) -> str:
        # Keep only the head and tail of overly long execution output.
        if self.MAX_CODE_OUTPUT_LENGTH < len(output):
            return (
                output[: self.MAX_CODE_OUTPUT_LENGTH // 5]
                + "\n...(truncated due to length)...\n"
                + output[-self.MAX_CODE_OUTPUT_LENGTH // 5 :]
            )
        return output
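

# ---------------------------------------------------------------------------
# Illustrative usage sketch: how this class might be driven end to end.
# The checkpoint name "m-a-p/OpenCodeInterpreter-DS-6.7B", the example dialog,
# and the hard-coded reply below are assumptions for demonstration only;
# substitute your own model path and real generation output.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    interpreter = OpenCodeInterpreter(
        model_path="m-a-p/OpenCodeInterpreter-DS-6.7B",  # assumed checkpoint
        load_in_4bit=True,
    )

    dialog = [
        {"role": "user", "content": "Write Python code that prints the first 10 squares."},
    ]
    prompt = interpreter.dialog_to_prompt(dialog)
    print(prompt)

    # A real reply would come from interpreter.model.generate(); a hard-coded
    # string is used here just to exercise the extraction and cleanup helpers.
    reply = "Sure:\n```python\nfor i in range(1, 11):\n    print(i * i)\n```"
    has_code, code = interpreter.extract_code_blocks(reply)
    if has_code:
        print(interpreter.clean_code_output(code))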