import sys
import os

prj_root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(prj_root_path)

from code_interpreter.JuypyterClient import JupyterNotebook
from code_interpreter.BaseCodeInterpreter import BaseCodeInterpreter
from utils.const import *

from typing import List, Literal, Optional, Tuple, TypedDict, Dict
from colorama import init, Fore, Style
import copy
import re

import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from finetuning.conversation_template import msg_to_code_result_tok_temp
from utils.special_tok_llama2 import (
    B_CODE,
    E_CODE,
    B_RESULT,
    E_RESULT,
    B_INST,
    E_INST,
    B_SYS,
    E_SYS,
    DEFAULT_PAD_TOKEN,
    DEFAULT_BOS_TOKEN,
    DEFAULT_EOS_TOKEN,
    DEFAULT_UNK_TOKEN,
    IGNORE_INDEX,
)

import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

class LlamaCodeInterpreter(BaseCodeInterpreter):
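    """Llama-2 based code interpreter.

    Wraps an (optionally PEFT/LoRA-adapted) Llama causal LM together with a
    Jupyter notebook kernel, so that generated code blocks can be executed
    and their outputs fed back into the conversation.
    """
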
    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        peft_model: Optional[str] = None,
    ):
        # Tokenizer with right-side padding and the slow (SentencePiece) implementation.
        self.tokenizer = LlamaTokenizer.from_pretrained(
            model_path,
            padding_side="right",
            use_fast=False,
        )

        # Make sure the basic special tokens exist before adding the
        # code-interpreter marker tokens.
        special_tokens_dict = dict()
        if self.tokenizer.pad_token is None:
            special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
        if self.tokenizer.eos_token is None:
            special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
        if self.tokenizer.bos_token is None:
            special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
        if self.tokenizer.unk_token is None:
            special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.tokenizer.add_tokens(
            [B_CODE, E_CODE, B_RESULT, E_RESULT, B_INST, E_INST, B_SYS, E_SYS],
            special_tokens=True,
        )

        # Load the base model (optionally quantized) and grow the embedding
        # matrix to cover the newly added tokens.
        self.model = LlamaForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
        )
        self.model.resize_token_embeddings(len(self.tokenizer))

        if peft_model is not None:
            # Attach the PEFT/LoRA adapter weights on top of the base model.
            self.model = PeftModel.from_pretrained(self.model, peft_model)

        self.model = self.model.eval()

        self.dialog = [
            {
                "role": "system",
                "content": CODE_INTERPRETER_SYSTEM_PROMPT + "\nUse code to answer",
            },
        ]

        # Jupyter kernel used to execute generated code; preload the helper tools.
        self.nb = JupyterNotebook()
        self.MAX_CODE_OUTPUT_LENGTH = 3000
        out = self.nb.add_and_run(TOOLS_CODE)
        print(out)

    def dialog_to_prompt(self, dialog: List[Dict]) -> str:
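        """Render the dialog into a single prompt string using the
        fine-tuning conversation template (``msg_to_code_result_tok_temp``)."""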
        full_str = msg_to_code_result_tok_temp(dialog)
        return full_str

    @torch.inference_mode()
    def generate(
        self,
        prompt: str = "[INST]\n###User : hi\n###Assistant :",
        max_new_tokens: int = 512,
        do_sample: bool = True,
        use_cache: bool = True,
        top_p: float = 0.95,
        temperature: float = 0.1,
        top_k: int = 50,
        repetition_penalty: float = 1.0,
    ) -> str:
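        """Generate a continuation for ``prompt``.

        Generation stops at either the EOS token or the end-of-code token
        (``E_CODE``), so the caller can execute the emitted code block before
        asking the model to continue.
        """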
        input_prompt = copy.deepcopy(prompt)
        # Tokenize and move the inputs onto the model's device.
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
        input_tokens_shape = inputs["input_ids"].shape[-1]

        eos_token_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_EOS_TOKEN)
        e_code_token_id = self.tokenizer.convert_tokens_to_ids(E_CODE)

        output = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            use_cache=use_cache,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            # Stop at either end-of-sequence or the end-of-code marker.
            eos_token_id=[
                eos_token_id,
                e_code_token_id,
            ],
        )[0]

        # Decode only the newly generated tokens; special tokens are kept
        # because the caller looks for the code markers in the text.
        generated_tokens = output[input_tokens_shape:]
        generated_text = self.tokenizer.decode(generated_tokens)

        return generated_text

    def extract_code_blocks(self, prompt: str) -> Tuple[bool, str]:
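        """Return ``(has_code, code)`` for the last ``B_CODE ... E_CODE`` block
        in ``prompt``; ``(False, "")`` if no code block is present."""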
        pattern = re.escape(B_CODE) + r"(.*?)" + re.escape(E_CODE)
        matches = re.findall(pattern, prompt, re.DOTALL)

        if matches:
            return True, matches[-1].strip()
        else:
            return False, ""

    def clean_code_output(self, output: str) -> str:
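        """Truncate overly long execution output, keeping only the head and tail."""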
        if self.MAX_CODE_OUTPUT_LENGTH < len(output):
            return (
                output[: self.MAX_CODE_OUTPUT_LENGTH // 5]
                + "...(skip)..."
                + output[-self.MAX_CODE_OUTPUT_LENGTH // 5 :]
            )
        return output

    def chat(self, user_message: str, VERBOSE: bool = False, MAX_TRY: int = 5):
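        """Run one user turn: generate, execute any emitted code, feed the
        result back, and repeat until no code block is produced or ``MAX_TRY``
        attempts are exhausted. Returns the final assistant message
        (also appended to ``self.dialog``)."""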
        self.dialog.append({"role": "user", "content": user_message})
        if VERBOSE:
            print(
                "###User : " + Fore.BLUE + Style.BRIGHT + user_message + Style.RESET_ALL
            )
            print("\n###Assistant : ")

        # Build the prompt from the running dialog and close the instruction block.
        HAS_CODE = False
        INST_END_TOK_FLAG = False
        full_generated_text = ""
        prompt = self.dialog_to_prompt(dialog=self.dialog)
        start_prompt = copy.deepcopy(prompt)
        prompt = f"{prompt} {E_INST}"

        generated_text = self.generate(prompt)
        full_generated_text += generated_text
        HAS_CODE, generated_code_block = self.extract_code_blocks(generated_text)

        # Execute-and-continue loop: run each emitted code block, append its
        # output between B_RESULT/E_RESULT, and let the model keep generating.
        attempt = 1
        while HAS_CODE:
            if attempt > MAX_TRY:
                break

            # Strip stray <unk> artifacts the tokenizer may leave in the code.
            generated_code_block = generated_code_block.replace("<unk>_", "").replace(
                "<unk>", ""
            )

            code_block_output, error_flag = self.execute_code_and_return_output(
                f"{generated_code_block}"
            )
            code_block_output = self.clean_code_output(code_block_output)
            generated_text = (
                f"{generated_text}\n{B_RESULT}\n{code_block_output}\n{E_RESULT}\n"
            )
            full_generated_text += f"\n{B_RESULT}\n{code_block_output}\n{E_RESULT}\n"

            first_code_block_pos = (
                generated_text.find(generated_code_block)
                if generated_code_block
                else -1
            )
            text_before_first_code_block = (
                generated_text
                if first_code_block_pos == -1
                else generated_text[:first_code_block_pos]
            )
            if VERBOSE:
                print(Fore.GREEN + text_before_first_code_block + Style.RESET_ALL)
                print(Fore.GREEN + generated_code_block + Style.RESET_ALL)
                print(
                    Fore.YELLOW
                    + f"\n{B_RESULT}\n{code_block_output}\n{E_RESULT}\n"
                    + Style.RESET_ALL
                )

            # Feed the executed result back and continue generating.
            prompt = f"{prompt}{generated_text}"
            generated_text = self.generate(prompt)
            HAS_CODE, generated_code_block = self.extract_code_blocks(generated_text)

            full_generated_text += generated_text

            attempt += 1

        if VERBOSE:
            print(Fore.GREEN + generated_text + Style.RESET_ALL)

        # Store the assistant turn with tokenizer artifacts stripped out.
        self.dialog.append(
            {
                "role": "assistant",
                "content": full_generated_text.replace("<unk>_", "")
                .replace("<unk>", "")
                .replace("</s>", ""),
            }
        )

        return self.dialog[-1]
if __name__ == "__main__": |
    import random

    # Alternative base checkpoints (not used below):
    #   "./ckpt/llama-2-13b-chat"
    #   "meta-llama/Llama-2-70b-chat-hf"
    LLAMA2_FINETUNED_PATH = "./output/llama-2-7b-chat-ci"

    interpreter = LlamaCodeInterpreter(
        model_path=LLAMA2_FINETUNED_PATH, load_in_4bit=True
    )
    output = interpreter.chat(
        user_message=random.choice(
            [
                "what is second largest city in japan?",
            ]
        ),
        VERBOSE=True,
    )

    while True:
        input_char = input("Press 'q' to quit the dialog: ")
        if input_char.lower() == "q":
            break
        else:
            output = interpreter.chat(user_message=input_char, VERBOSE=True)