import torch
import tiktoken
import json
from typing import Dict, Optional, Tuple


class Config:
    """Hyperparameters for the small GPT-style language model."""

    def __init__(self):
        self.vocab_size = 100283  # cl100k_base vocabulary plus the six chat special tokens defined below
        self.max_position_embeddings = 1024
        self.hidden_size = 768
        self.num_layers = 6
        self.num_heads = 12
        self.intermediate_size = 3072
        self.dropout = 0.1


class AttentionHead(torch.nn.Module):
    """A single scaled dot-product attention head."""

    def __init__(self, config: Config):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_heads
        self.query = torch.nn.Linear(config.hidden_size, self.head_dim)
        self.key = torch.nn.Linear(config.hidden_size, self.head_dim)
        self.value = torch.nn.Linear(config.hidden_size, self.head_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        scale = Q.size(-1) ** 0.5
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        if mask is not None:
            # Positions where mask == 0 are excluded from the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention = torch.nn.functional.softmax(scores, dim=-1)
        return torch.matmul(attention, V)


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.heads = torch.nn.ModuleList([AttentionHead(config) for _ in range(config.num_heads)])
        self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = torch.nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        heads = [head(x, mask) for head in self.heads]
        multihead = torch.cat(heads, dim=-1)
        return self.dropout(self.linear(multihead))


class TransformerBlock(torch.nn.Module):
    """Post-norm transformer block: residual connections around attention and the feed-forward network."""

    def __init__(self, config: Config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.norm1 = torch.nn.LayerNorm(config.hidden_size)
        self.norm2 = torch.nn.LayerNorm(config.hidden_size)
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.intermediate_size),
            torch.nn.GELU(),
            torch.nn.Linear(config.intermediate_size, config.hidden_size),
            torch.nn.Dropout(config.dropout)
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        attended = self.attention(x, mask)
        x = self.norm1(x + attended)
        fed_forward = self.feed_forward(x)
        return self.norm2(x + fed_forward)


class SmallLanguageModel(torch.nn.Module):
    """GPT-style decoder-only language model with learned positional embeddings."""

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.token_embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.transformer_blocks = torch.nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.dropout = torch.nn.Dropout(config.dropout)
        self.ln_f = torch.nn.LayerNorm(config.hidden_size)
        self.head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # GPT-2-style initialization: N(0, 0.02) for linear and embedding weights.
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, torch.nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_causal_mask(self, size: int) -> torch.Tensor:
        # Lower-triangular mask: True where attention is allowed, False for future positions.
        mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
        return ~mask

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        b, t = input_ids.size()
        positions = torch.arange(0, t, dtype=torch.long, device=input_ids.device)
        mask = self.get_causal_mask(t).to(input_ids.device)

        # Token embeddings plus broadcast position embeddings, then dropout.
        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(positions)
        x = self.dropout(token_embeddings + position_embeddings)

        for block in self.transformer_blocks:
            x = block(x, mask)

        x = self.ln_f(x)
        logits = self.head(x)
        return logits


class TextGenerator:
    """Wraps a trained SmallLanguageModel for autoregressive sampling."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.model.eval()
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_k: int = 50,
        top_p: float = 0.9
    ) -> Dict[str, str]:
        try:
            input_ids = torch.tensor(self.tokenizer.encode(
                prompt,
                allowed_special={'<user>', '</user>', '<assistant>', '</assistant>', '<system>', '</system>'}
            )).unsqueeze(0).to(self.device)

            for _ in range(max_length):
                # Keep only the most recent tokens if the context exceeds the model's limit.
                max_positions = self.model.config.max_position_embeddings
                if input_ids.size(1) > max_positions:
                    input_ids = input_ids[:, -max_positions:]

                logits = self.model(input_ids)
                next_token_logits = logits[:, -1, :] / temperature

                # Top-k filtering: drop everything below the k-th highest logit.
                if top_k > 0:
                    values, _ = torch.topk(next_token_logits, top_k)
                    min_value = values[:, -1].unsqueeze(-1)
                    next_token_logits = torch.where(
                        next_token_logits < min_value,
                        torch.tensor(float('-inf')).to(self.device),
                        next_token_logits
                    )

                # Top-p (nucleus) filtering: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token past the threshold is still kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))

                # Sample the next token and append it to the running sequence.
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat((input_ids, next_token), dim=1)

            generated_text = self.tokenizer.decode(input_ids[0].tolist())
            return {
                "status": "success",
                "generated_text": generated_text,
                "prompt": prompt,
                "max_length": max_length,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p
            }

        except Exception as e:
            return {
                "status": "error",
                "error_message": str(e),
                "prompt": prompt
            }


def load_model_and_tokenizer(
    checkpoint_path: str,
    device: Optional[torch.device] = None
) -> Tuple[SmallLanguageModel, tiktoken.Encoding]:
    """Builds the extended cl100k tokenizer and restores model weights from a checkpoint."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = Config()

    # Extend cl100k_base with XML-style chat special tokens.
    cl100k_base = tiktoken.get_encoding("cl100k_base")
    tokenizer = tiktoken.Encoding(
        name="cl100k_xml",
        pat_str=cl100k_base._pat_str,
        mergeable_ranks=cl100k_base._mergeable_ranks,
        special_tokens={
            **cl100k_base._special_tokens,
            "<user>": 100277, "</user>": 100278,
            "<assistant>": 100279, "</assistant>": 100280,
            "<system>": 100281, "</system>": 100282
        }
    )
    config.vocab_size = tokenizer.n_vocab

    model = SmallLanguageModel(config)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    return model, tokenizer


def generate(
    checkpoint_path: str,
    prompt: str,
    max_length: int = 100,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.9
) -> Dict[str, str]:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model, tokenizer = load_model_and_tokenizer(checkpoint_path, device)

    generator = TextGenerator(model, tokenizer)
    result = generator.generate(
        prompt=prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )

    return result
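

# Example usage: a minimal sketch. The checkpoint path and prompt below are
# placeholders (assumptions), not part of the original code; the checkpoint is
# expected to contain a 'model_state_dict' entry saved from SmallLanguageModel,
# and the chat tags follow the special tokens defined in load_model_and_tokenizer.
if __name__ == "__main__":
    result = generate(
        checkpoint_path="checkpoints/model.pt",  # hypothetical path
        prompt="<user>Write a haiku about autumn.</user><assistant>",
        max_length=64,
        temperature=0.7,
        top_k=50,
        top_p=0.9
    )
    print(json.dumps(result, indent=2))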