import json
import torch
from torch.utils.data import Dataset
import re
from collections import Counter
class ChatTokenizer:
    def __init__(self, vocab_size=1000):
        # vocab_size caps the number of BPE merges learned in train(); it is
        # stored under a different name so it does not clash with the
        # read-only vocab_size property defined below.
        self.num_merges = vocab_size
        self.token2id = {}
        self.id2token = {}
        self.bpe_ranks = {}
    def tokenize(self, text):
        # Lowercase, split into words/punctuation, and spell each word as
        # space-separated characters followed by an end-of-word marker.
        words = re.findall(r"\w+|\S", text.lower())
        return [' '.join(list(word)) + ' </w>' for word in words]
    def get_stats(self, vocab):
        # Count adjacent symbol pairs, weighted by how often each word occurs.
        pairs = Counter()
        for token, freq in vocab.items():
            symbols = token.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs
    def merge_pairs(self, tokens, pair):
        # Replace every occurrence of the symbol pair with its merged form.
        # Whitespace lookarounds are used instead of \b because symbols such
        # as </w> contain non-word characters.
        pattern = re.escape(' '.join(pair))
        replacement = ''.join(pair)
        return [re.sub(rf'(?<!\S){pattern}(?!\S)', replacement, token) for token in tokens]
    def train(self, texts):
        # Collect character-level tokens for the whole corpus.
        tokens = []
        for text in texts:
            tokens.extend(self.tokenize(text))
        vocab = Counter(tokens)
        # Learn BPE merges: repeatedly merge the most frequent adjacent pair.
        for rank in range(self.num_merges):
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = pairs.most_common(1)[0][0]
            vocab = Counter(self.merge_pairs(vocab.elements(), best))
            self.bpe_ranks[best] = rank
        # Build the final vocabulary from the surviving symbols plus specials.
        final_tokens = set()
        for token in vocab:
            final_tokens.update(token.split())
        final_tokens.update(["<PAD>", "<UNK>", "<END>", "^user:", "minigpt:"])
        self.token2id = {tok: i for i, tok in enumerate(sorted(final_tokens))}
        self.id2token = {i: tok for tok, i in self.token2id.items()}
    def encode(self, text):
        # Apply the learned merges in the order they were learned, then map
        # each symbol to its id, falling back to <UNK> for unseen symbols.
        tokenized = self.tokenize(text)
        for pair, _ in sorted(self.bpe_ranks.items(), key=lambda x: x[1]):
            tokenized = self.merge_pairs(tokenized, pair)
        ids = []
        for token in tokenized:
            for part in token.split():
                ids.append(self.token2id.get(part, self.token2id["<UNK>"]))
        ids.append(self.token2id["<END>"])
        return ids
    def decode(self, token_ids):
        tokens = [self.id2token.get(tid, "<UNK>") for tid in token_ids]
        sentence = ""
        for tok in tokens:
            if tok == "<END>":
                break
            elif tok in {"<PAD>", "<UNK>"}:
                continue
            else:
                # </w> may appear as a standalone symbol or fused onto the end
                # of a merged token; either way it marks a word boundary.
                sentence += tok.replace("</w>", " ")
        return sentence.strip()
    def save(self, path):
        # Persist the vocabulary and merge table; id2token is rebuilt on load.
        with open(path, "w", encoding="utf-8") as f:
            json.dump({
                "token2id": self.token2id,
                "bpe_ranks": {f"{a} {b}": r for (a, b), r in self.bpe_ranks.items()}
            }, f)
    def load(self, path):
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.token2id = {k: int(v) for k, v in data["token2id"].items()}
        self.id2token = {v: k for k, v in self.token2id.items()}
        self.bpe_ranks = {tuple(k.split()): v for k, v in data["bpe_ranks"].items()}
    def __len__(self):
        return len(self.token2id)

    @property
    def stoi(self):
        return self.token2id

    @property
    def itos(self):
        return self.id2token

    @property
    def vocab_size(self):
        # Actual vocabulary size after training (see num_merges in __init__).
        return len(self.token2id)
class ChatDataset(Dataset):
    def __init__(self, file_path, tokenizer, block_size=64):
        self.samples = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                text = data["text"].strip()
                # Wrap in format: ^User: ... MiniGPT: ...
                if not text.lower().startswith("^user:"):
                    text = "^User: " + text
                if "MiniGPT:" not in text:
                    text += "\nMiniGPT:"
                tokens = tokenizer.encode(text)
                # Sliding window of (input, shifted-by-one target) pairs.
                for i in range(0, len(tokens) - block_size):
                    x = tokens[i:i + block_size]
                    y = tokens[i + 1:i + block_size + 1]
                    self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return torch.tensor(x), torch.tensor(y)
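
# Example usage (a minimal sketch): trains the BPE tokenizer on the same JSONL
# file the dataset reads from, then builds a DataLoader. The path
# "chat_data.jsonl" and the hyperparameters below are illustrative placeholders,
# not values prescribed by this module, and the batch fetch assumes each encoded
# text is longer than block_size tokens.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    data_path = "chat_data.jsonl"  # hypothetical file of {"text": ...} lines

    # Gather raw texts to train the tokenizer on.
    texts = []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                texts.append(json.loads(line)["text"])

    tokenizer = ChatTokenizer(vocab_size=500)
    tokenizer.train(texts)
    tokenizer.save("tokenizer.json")

    dataset = ChatDataset(data_path, tokenizer, block_size=64)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    x, y = next(iter(loader))
    print(f"vocab size: {tokenizer.vocab_size}, batch shapes: {x.shape}, {y.shape}")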