# Bora-1 / inference.py
# Uploaded by brandonbaek ("Upload inference.py", commit c4e9c17).
import json
from typing import Dict, Optional, Tuple

import tiktoken
import torch
# Model Architecture Classes
class Config:
    """Hyper-parameters for ``SmallLanguageModel``.

    Calling ``Config()`` with no arguments reproduces the values this file
    originally hard-coded; any field may be overridden by keyword. Attributes
    remain plain instance attributes, so callers may also mutate them after
    construction (e.g. ``config.vocab_size = tokenizer.n_vocab``).

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``hidden_size`` (the per-head outputs are concatenated back to
            ``hidden_size`` features, so the split must be exact).
    """

    def __init__(
        self,
        *,
        vocab_size: int = 100283,
        max_position_embeddings: int = 1024,
        hidden_size: int = 768,
        num_layers: int = 6,
        num_heads: int = 12,
        intermediate_size: int = 3072,
        dropout: float = 0.1,
    ):
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        self.vocab_size = vocab_size
        # Longest token sequence the learned position table can index.
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        # Width of the feed-forward layer inside each transformer block.
        self.intermediate_size = intermediate_size
        self.dropout = dropout
class AttentionHead(torch.nn.Module):
    """One scaled dot-product attention head.

    The input is projected to query, key and value vectors of width
    ``hidden_size // num_heads``; the output is the softmax-weighted sum of
    the values. Input is presumably ``(batch, seq, hidden_size)`` — TODO
    confirm against callers.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_heads
        in_features = config.hidden_size
        # Attribute names (and registration order) form the state_dict keys.
        self.query = torch.nn.Linear(in_features, self.head_dim)
        self.key = torch.nn.Linear(in_features, self.head_dim)
        self.value = torch.nn.Linear(in_features, self.head_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        queries, keys, values = self.query(x), self.key(x), self.value(x)
        # Scale by sqrt(head width) to keep the softmax inputs well-conditioned.
        logits = (queries @ keys.transpose(-2, -1)) / (queries.size(-1) ** 0.5)
        if mask is not None:
            # Entries where the mask equals 0/False can never be attended to.
            logits = logits.masked_fill(mask == 0, float('-inf'))
        return logits.softmax(dim=-1) @ values
class MultiHeadAttention(torch.nn.Module):
    """``num_heads`` independent attention heads followed by an output projection.

    Each head sees the full input; their outputs are concatenated on the last
    axis, projected back to ``hidden_size`` and passed through dropout.
    """

    def __init__(self, config: Config):
        super().__init__()
        head_list = [AttentionHead(config) for _ in range(config.num_heads)]
        self.heads = torch.nn.ModuleList(head_list)
        self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = torch.nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        combined = torch.cat([head(x, mask) for head in self.heads], dim=-1)
        projected = self.linear(combined)
        return self.dropout(projected)
class TransformerBlock(torch.nn.Module):
    """Post-norm transformer layer: attention and feed-forward sub-layers.

    Each sub-layer's output is added to its input and the sum is then
    normalised (LayerNorm after the residual add).
    """

    def __init__(self, config: Config):
        super().__init__()
        # Registration order and the Sequential indices below are part of
        # the state_dict keys (e.g. ``feed_forward.0.weight``).
        self.attention = MultiHeadAttention(config)
        self.norm1 = torch.nn.LayerNorm(config.hidden_size)
        self.norm2 = torch.nn.LayerNorm(config.hidden_size)
        expand = torch.nn.Linear(config.hidden_size, config.intermediate_size)
        contract = torch.nn.Linear(config.intermediate_size, config.hidden_size)
        self.feed_forward = torch.nn.Sequential(
            expand,
            torch.nn.GELU(),
            contract,
            torch.nn.Dropout(config.dropout),
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.norm1(x + self.attention(x, mask))
        return self.norm2(x + self.feed_forward(x))
class SmallLanguageModel(torch.nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned absolute position embeddings feed a stack of
    ``TransformerBlock``s under a causal mask. A final LayerNorm and an untied
    linear head then map to vocabulary logits. Module names are kept stable
    because they define the checkpoint's state_dict keys.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.token_embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.transformer_blocks = torch.nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.dropout = torch.nn.Dropout(config.dropout)
        self.ln_f = torch.nn.LayerNorm(config.hidden_size)
        self.head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~N(0, 0.02), zero biases, unit LayerNorm."""
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, torch.nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_causal_mask(self, size: int, device: Optional[torch.device] = None) -> torch.Tensor:
        """Return a ``(size, size)`` bool mask, True where attention is allowed.

        Row ``i`` may attend to columns ``0..i``. ``device`` (default: torch's
        default device) lets callers build the mask where it will be used
        instead of creating it on CPU and copying it over.
        """
        future = torch.triu(torch.ones(size, size, device=device), diagonal=1).bool()
        return ~future

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return next-token logits for every position.

        Args:
            input_ids: ``(batch, seq)`` integer token ids.

        Returns:
            ``(batch, seq, vocab_size)`` logits.

        Raises:
            ValueError: if ``seq`` exceeds ``config.max_position_embeddings``.
                Checking here gives a clear error instead of an out-of-range
                embedding lookup (a device-side assert on CUDA).
        """
        _, t = input_ids.size()
        if t > self.config.max_position_embeddings:
            raise ValueError(
                f"sequence length {t} exceeds max_position_embeddings "
                f"({self.config.max_position_embeddings})"
            )
        positions = torch.arange(0, t, dtype=torch.long, device=input_ids.device)
        mask = self.get_causal_mask(t, device=input_ids.device)
        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(positions)
        x = self.dropout(token_embeddings + position_embeddings)
        for block in self.transformer_blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        return self.head(x)
# Text Generator Class
class TextGenerator:
    """Autoregressive sampler around a causal language model and a tokenizer.

    The model is put into eval mode on construction. The device and the
    context-window length are read from the model itself (its parameters and
    its ``config.max_position_embeddings``). The generator therefore does not
    depend on module-level ``device``/``config`` globals being set.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.model.eval()
        self.tokenizer = tokenizer

    def _device(self) -> torch.device:
        """Device holding the model's first parameter (CPU if it has none)."""
        param = next(self.model.parameters(), None)
        return param.device if param is not None else torch.device("cpu")

    def _context_size(self) -> Optional[int]:
        """Model's maximum input length, or None if the model exposes no config."""
        return getattr(getattr(self.model, "config", None), "max_position_embeddings", None)

    @staticmethod
    def _sample(next_token_logits: torch.Tensor, temperature: float, top_k: int, top_p: float) -> torch.Tensor:
        """Pick one token id per row from ``(batch, vocab)`` logits; returns ``(batch, 1)``."""
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")
        if temperature == 0:
            # Greedy decoding; dividing by zero would yield inf/NaN probabilities.
            return torch.argmax(next_token_logits, dim=-1, keepdim=True)
        next_token_logits = next_token_logits / temperature
        if top_k > 0:
            # topk raises if k exceeds the vocabulary, so clamp it.
            k = min(top_k, next_token_logits.size(-1))
            kth_largest = torch.topk(next_token_logits, k).values[:, -1:]
            next_token_logits = next_token_logits.masked_fill(
                next_token_logits < kth_largest, float('-inf')
            )
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept,
            # and always keep the most likely token.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
            next_token_logits = next_token_logits.masked_fill(to_remove, float('-inf'))
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_k: int = 50,
        top_p: float = 0.9
    ) -> Dict[str, str]:
        """Sample ``max_length`` new tokens after ``prompt``.

        Args:
            prompt: Text to continue. The chat markers ``<user>``, ``</user>``,
                ``<assistant>``, ``</assistant>``, ``<system>`` and ``</system>``
                are allowed to encode as special tokens.
            max_length: Number of tokens to sample. Every step runs; there is
                no early stop.
            temperature: Logit divisor. ``0`` selects greedy (argmax)
                decoding; negative values are reported as an error.
            top_k: Keep only the ``top_k`` highest logits, clamped to the
                vocabulary size. ``<= 0`` disables the filter.
            top_p: Nucleus-sampling threshold. ``>= 1.0`` disables the filter.

        Returns:
            On success, a dict with ``status="success"``, ``generated_text``
            (the decoded prompt plus continuation) and the sampling arguments.
            If anything raises, the dict has ``status="error"``,
            ``error_message`` and ``prompt``. Errors are reported in the dict,
            not raised.
        """
        try:
            device = self._device()
            context_size = self._context_size()
            token_ids = self.tokenizer.encode(
                prompt,
                allowed_special={'<user>', '</user>', '<assistant>', '</assistant>', '<system>', '</system>'}
            )
            if not token_ids:
                raise ValueError("prompt must encode to at least one token")
            input_ids = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)
            for _ in range(max_length):
                # Only the trailing window goes to the model. The full sequence
                # is kept so the returned text still starts with the prompt.
                window = input_ids[:, -context_size:] if context_size else input_ids
                logits = self.model(window)
                next_token = self._sample(logits[:, -1, :], temperature, top_k, top_p)
                input_ids = torch.cat((input_ids, next_token), dim=1)
            generated_text = self.tokenizer.decode(input_ids[0].tolist())
            return {
                "status": "success",
                "generated_text": generated_text,
                "prompt": prompt,
                "max_length": max_length,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p
            }
        except Exception as e:
            return {
                "status": "error",
                "error_message": str(e),
                "prompt": prompt
            }
# Helper Function to Load Model and Tokenizer
def load_model_and_tokenizer(
    checkpoint_path: str,
    device: Optional[torch.device] = None,
) -> Tuple[SmallLanguageModel, tiktoken.Encoding]:
    """Build the chat tokenizer and the model, then load weights from a checkpoint.

    Args:
        checkpoint_path: File readable by ``torch.load`` whose contents include
            a ``'model_state_dict'`` entry.
        device: Where to map the weights and place the model. Defaults to CUDA
            when available, otherwise CPU. Before, this was read from a module
            global that only ``generate()`` set, so calling this function
            directly raised NameError.

    Returns:
        ``(model, tokenizer)``. The model's vocabulary size is set from the
        tokenizer's ``n_vocab``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = Config()
    cl100k_base = tiktoken.get_encoding("cl100k_base")
    # cl100k_base extended with XML-style chat markers. The ids are presumably
    # those used at training time and must match the checkpoint — do not renumber.
    tokenizer = tiktoken.Encoding(
        name="cl100k_xml",
        pat_str=cl100k_base._pat_str,
        mergeable_ranks=cl100k_base._mergeable_ranks,
        special_tokens={
            **cl100k_base._special_tokens,
            "<user>": 100277, "</user>": 100278,
            "<assistant>": 100279, "</assistant>": 100280,
            "<system>": 100281, "</system>": 100282
        }
    )
    config.vocab_size = tokenizer.n_vocab
    model = SmallLanguageModel(config)
    # SECURITY NOTE(review): torch.load unpickles the file, so only load
    # trusted checkpoints. Consider weights_only=True if the checkpoint holds
    # only tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    return model, tokenizer
# Main Function for Inference
def generate(
    checkpoint_path: str,
    prompt: str,
    max_length: int = 100,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.9
) -> Dict[str, str]:
    """Load ``checkpoint_path`` and sample a continuation of ``prompt``.

    The model and tokenizer are rebuilt and the checkpoint reloaded on every
    call. Failures while loading propagate as exceptions. Failures during
    sampling are returned in the ``TextGenerator.generate`` result dict.

    Side effects:
        Sets the module globals ``device`` (CUDA if available, else CPU) and
        ``config`` (the loaded model's ``Config``) for code that reads them.
    """
    global device, config
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(checkpoint_path)
    # ``config`` was declared global but never assigned, so any reader of the
    # module-level name hit NameError. Publish the loaded model's config.
    config = model.config
    # Generate text
    generator = TextGenerator(model, tokenizer)
    return generator.generate(
        prompt=prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )