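"""Tokenizer and dataset utilities for a MiniGPT-style chat model.

Defines ChatTokenizer, a small BPE-style tokenizer, and ChatDataset, which
turns a JSONL chat corpus into sliding-window (input, target) pairs for
next-token training.
"""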
import json
import re
from collections import Counter

import torch
from torch.utils.data import Dataset


class ChatTokenizer:
    """Word-level pre-tokenizer followed by learned BPE merges."""

    def __init__(self, vocab_size=1000):
        # `vocab_size` here is the number of BPE merges to learn; it is stored
        # under a different name because `vocab_size` is also exposed below as
        # a read-only property (the size of the final symbol vocabulary).
        self.num_merges = vocab_size
        self.token2id = {}
        self.id2token = {}
        self.bpe_ranks = {}

    def tokenize(self, text):
        # Lowercase, split into words and punctuation, then spell each word as
        # space-separated characters terminated by an end-of-word marker.
        words = re.findall(r"\w+|\S", text.lower())
        return [' '.join(list(word)) + ' </w>' for word in words]

    def get_stats(self, vocab):
        # Count adjacent symbol pairs, weighted by how often each word occurs
        # in the corpus (`vocab` is a Counter of tokenized words).
        pairs = Counter()
        for token, freq in vocab.items():
            symbols = token.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs

    def merge_pairs(self, tokens, pair):
        # Replace every free-standing occurrence of the pair inside each token.
        # Lookarounds are used instead of \b because symbols such as '</w>' end
        # in non-word characters, which breaks \b matching.
        pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
        replacement = ''.join(pair)
        return [pattern.sub(replacement, token) for token in tokens]

    def train(self, texts):
        # Build the initial character-level vocabulary over the whole corpus.
        tokens = []
        for text in texts:
            tokens.extend(self.tokenize(text))
        vocab = Counter(tokens)

        # Greedily learn merge rules, most frequent pair first.
        for rank in range(self.num_merges):
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = pairs.most_common(1)[0][0]
            vocab = Counter(self.merge_pairs(vocab.elements(), best))
            self.bpe_ranks[best] = rank

        # Collect the final symbol vocabulary plus reserved special tokens.
        final_tokens = set()
        for token in vocab:
            final_tokens.update(token.split())
        final_tokens.update(["<PAD>", "<UNK>", "<END>", "^user:", "minigpt:"])
        self.token2id = {tok: i for i, tok in enumerate(sorted(final_tokens))}
        self.id2token = {i: tok for tok, i in self.token2id.items()}

    def encode(self, text):
        # Apply the learned merges in the order they were learned, then map
        # each symbol to its id, falling back to <UNK> for unseen symbols.
        tokenized = self.tokenize(text)
        for pair, _ in sorted(self.bpe_ranks.items(), key=lambda x: x[1]):
            tokenized = self.merge_pairs(tokenized, pair)
        ids = []
        for token in tokenized:
            for part in token.split():
                ids.append(self.token2id.get(part, self.token2id["<UNK>"]))
        ids.append(self.token2id["<END>"])
        return ids

    def decode(self, token_ids):
        tokens = [self.id2token.get(tid, "<UNK>") for tid in token_ids]
        sentence = ""
        for tok in tokens:
            if tok == "<END>":
                break
            elif tok in {"<PAD>", "<UNK>"}:
                continue
            elif tok.endswith("</w>"):
                # An end-of-word marker (possibly merged into a larger symbol)
                # closes the current word.
                sentence += tok[:-len("</w>")] + " "
            else:
                sentence += tok
        return sentence.strip()

    def save(self, path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump({
                "token2id": self.token2id,
                "bpe_ranks": {f"{a} {b}": r for (a, b), r in self.bpe_ranks.items()}
            }, f)

    def load(self, path):
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.token2id = {k: int(v) for k, v in data["token2id"].items()}
        self.id2token = {v: k for k, v in self.token2id.items()}
        self.bpe_ranks = {tuple(k.split()): v for k, v in data["bpe_ranks"].items()}

    def __len__(self):
        return len(self.token2id)

    @property
    def stoi(self):
        return self.token2id

    @property
    def itos(self):
        return self.id2token

    @property
    def vocab_size(self):
        return len(self.token2id)


class ChatDataset(Dataset):
    """Sliding-window next-token dataset built from a JSONL chat file."""

    def __init__(self, file_path, tokenizer, block_size=64):
        self.samples = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                text = data["text"].strip()

                # Normalise each record into the "^User: ... MiniGPT:" format.
                if not text.lower().startswith("^user:"):
                    text = "^User: " + text
                if "MiniGPT:" not in text:
                    text += "\nMiniGPT:"

                tokens = tokenizer.encode(text)

                # Sliding window of (input, shifted target) pairs; records with
                # fewer than block_size + 1 tokens produce no samples.
                for i in range(0, len(tokens) - block_size):
                    x = tokens[i:i + block_size]
                    y = tokens[i + 1:i + block_size + 1]
                    self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return torch.tensor(x), torch.tensor(y)
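

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): trains the tokenizer on a tiny
    # in-memory corpus, round-trips a prompt through encode/decode, exercises
    # save/load, and builds a ChatDataset from a temporary JSONL file. The
    # sample texts, the merge budget of 50, and block_size=8 are arbitrary demo
    # values, not settings used elsewhere in the project.
    import os
    import tempfile

    corpus = [
        "^User: hello there\nMiniGPT: hi, how can I help you today?",
        "^User: what is bpe?\nMiniGPT: byte pair encoding merges frequent symbol pairs.",
    ]

    tokenizer = ChatTokenizer(vocab_size=50)
    tokenizer.train(corpus)

    ids = tokenizer.encode("^User: hello\nMiniGPT:")
    print("vocab size:", tokenizer.vocab_size)
    print("encoded:", ids)
    print("decoded:", tokenizer.decode(ids))

    # Write the toy corpus back out as JSONL so ChatDataset can consume it.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".jsonl", delete=False, encoding="utf-8"
    ) as tmp:
        for text in corpus:
            tmp.write(json.dumps({"text": text}) + "\n")
        tmp_path = tmp.name

    # Round-trip the learned vocabulary through save/load.
    vocab_path = tmp_path + ".vocab.json"
    tokenizer.save(vocab_path)
    restored = ChatTokenizer()
    restored.load(vocab_path)
    assert restored.encode("hello there") == tokenizer.encode("hello there")

    dataset = ChatDataset(tmp_path, tokenizer, block_size=8)
    print("dataset samples:", len(dataset))
    if len(dataset) > 0:
        x, y = dataset[0]
        print("x shape:", tuple(x.shape), "y shape:", tuple(y.shape))

    os.remove(vocab_path)
    os.remove(tmp_path)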