import json
import ast
import astunparse
from transformers import PreTrainedTokenizer
from torch.utils.data import Dataset
from copy import deepcopy
from typing import Dict, List

# text constants
FUNCTION_CALL_NAME = 'tool_call'
FUNCTION_CALL_PREFIX = '```python\n'
FUNCTION_CALL_POSTFIX = '\n```'
TOOL_DEFINITION_PREFIX = 'Answer the following questions as best as you can. You have access to the following tools:\n'
CONVERSATION_KEY = 'conversations'
TOOL_DESC_KEY = 'tools'


def format_function_call(function_name: str, parameters: Dict[str, str]):
    """Render `parameters` as a keyword-argument call to `function_name`."""
    function_name = ast.Name(id=function_name)
    keywords = [
        ast.keyword(arg=arg_name, value=ast.Constant(arg_value))
        for arg_name, arg_value in parameters.items()
    ]
    func_call = ast.Call(func=function_name, args=[], keywords=keywords)
    return astunparse.unparse(func_call).strip()
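

# Illustrative only (hypothetical tool name and arguments):
#   format_function_call("get_weather", {"city": "Beijing"})
#   -> "get_weather(city='Beijing')"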


def format_conversation(item, tokenizer, conversation_key: str, tool_key: str):
    """Tokenize one multi-turn sample into parallel (tokens, loss_masks) lists."""
    conversations = deepcopy(item[conversation_key])

    # Note: `loss_mask` here means whether *the prediction* of the token should take loss
    tokens, loss_masks = [tokenizer.get_command("[gMASK]"), tokenizer.get_command("sop")], [0, 0]

    def _update(_tokens: List[int], value: int = 1):
        value = int(value)
        tokens.extend(_tokens)
        loss_masks.extend([value] * len(_tokens))

    # insert system prompt for tools
    if tool_key in item:
        conversations.insert(
            0,
            {
                "role": "system",
                "content": TOOL_DEFINITION_PREFIX + json.dumps(item[tool_key], indent=4, ensure_ascii=False)
            }
        )

    for conv in conversations:
        loss = conv.get("loss", True)
        if conv['role'] in {'system', 'user'}:
            loss = False
        if conv['role'] == 'tool':
            # function call, rendered as a fenced Python code block
            value = FUNCTION_CALL_PREFIX + format_function_call(FUNCTION_CALL_NAME, conv["parameters"]) + FUNCTION_CALL_POSTFIX
            text = tokenizer.build_single_message("assistant", conv["name"], value)
            _update(text, loss)

            # function call result, never supervised
            value = conv.get('observation', None)
            if not isinstance(value, str):
                value = json.dumps(value, ensure_ascii=False)
            text = tokenizer.build_single_message("observation", "", value)
            _update(text, False)
        else:
            text = tokenizer.build_single_message(conv['role'], "", conv["content"])
            _update(text, loss)

    _update([tokenizer.eos_token_id], False)

    assert len(tokens) == len(loss_masks), f"length mismatch: {len(tokens)} vs {len(loss_masks)}"
    return tokens, loss_masks
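

# Sketch of the expected sample layout (field names follow the keys read above;
# the concrete values are illustrative):
# {
#     "tools": [...],  # optional; JSON-dumped into a leading system prompt
#     "conversations": [
#         {"role": "user", "content": "..."},
#         {"role": "tool", "name": "get_weather", "parameters": {"city": "..."}, "observation": {...}},
#         {"role": "assistant", "content": "..."}
#     ]
# }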


def sanity_check(tokens: List[int], target: List[int], tokenizer: PreTrainedTokenizer):
    print("Sanity Check >>>>>>>>>>>>>")
    for t, m in zip(tokens, target):
        decoded = tokenizer.tokenizer.index_special_tokens[t] \
            if t in tokenizer.tokenizer.index_special_tokens \
            else tokenizer.decode([t])
        print("%20s: %6d -> %6d" % (repr(decoded), t, m))
    print("<<<<<<<<<<<<< Sanity Check")

    assert len(tokens) == len(target), f"length mismatch: {len(tokens)} vs {len(target)}"
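

# Example (hypothetical `item` and `tokenizer`): print one formatted sample to
# eyeball which token positions are supervised:
#   tokens, masks = format_conversation(item, tokenizer, CONVERSATION_KEY, TOOL_DESC_KEY)
#   sanity_check(tokens, masks, tokenizer)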


class MultiTurnDataset(Dataset):
    def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_seq_length: int):
        super(MultiTurnDataset, self).__init__()
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i) -> dict:
        data_item = self.data[i]
        tokens, loss_masks = format_conversation(data_item, self.tokenizer, CONVERSATION_KEY, TOOL_DESC_KEY)

        # labels are shifted inside the model; shift the loss mask accordingly
        target_based_loss_mask = [False] + loss_masks[:-1]
        labels = [(t if m else -100) for t, m in zip(tokens, target_based_loss_mask)]

        # truncate, then right-pad tokens with pad_token_id and labels with -100
        tokens = tokens[:self.max_seq_length]
        labels = labels[:self.max_seq_length]
        tokens += [self.tokenizer.pad_token_id] * (self.max_seq_length - len(tokens))
        labels += [-100] * (self.max_seq_length - len(labels))

        assert len(tokens) == len(labels), f"length mismatch: {len(tokens)} vs {len(labels)}"
        return {
            "input_ids": tokens,
            "labels": labels
        }
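

# Minimal usage sketch (assumes a ChatGLM3-style tokenizer that exposes
# get_command / build_single_message; model path and data file are placeholders):
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
#   with open("train.jsonl") as f:
#       data = [json.loads(line) for line in f]
#   dataset = MultiTurnDataset(data, tokenizer, max_seq_length=2048)
#   sample = dataset[0]  # {"input_ids": [...], "labels": [...]}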


class InputOutputDataset(Dataset):
    def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_source_length: int, max_target_length: int):
        super(InputOutputDataset, self).__init__()
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        # +1 leaves room for the trailing eos token
        self.max_seq_length = max_source_length + max_target_length + 1
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i) -> dict:
        data_item = self.data[i]
        a_ids = self.tokenizer.encode(text=data_item['prompt'], add_special_tokens=True, truncation=True,
                                      max_length=self.max_source_length)
        b_ids = self.tokenizer.encode(text=data_item['response'], add_special_tokens=False, truncation=True,
                                      max_length=self.max_target_length)

        context_length = len(a_ids)
        input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]
        # supervise only the response (and eos); prompt positions are masked below
        labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]

        pad_len = self.max_seq_length - len(input_ids)
        input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
        labels = labels + [self.tokenizer.pad_token_id] * pad_len
        labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]

        assert len(input_ids) == len(labels), f"length mismatch: {len(input_ids)} vs {len(labels)}"
        return {
            "input_ids": input_ids,
            "labels": labels
        }
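

# Usage sketch for single-turn prompt/response pairs (the 'prompt' and
# 'response' field names match the keys read above; data is illustrative):
#   data = [{"prompt": "Translate to French: hello", "response": "bonjour"}]
#   dataset = InputOutputDataset(data, tokenizer, max_source_length=512, max_target_length=512)
#   sample = dataset[0]  # prompt and padding positions in "labels" are -100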