| | import torch |
| | import torch.nn as nn |
| | import unicodedata |
| | import os |
| | import gradio as gr |
| | from transformers import PreTrainedTokenizerFast, PretrainedConfig, PreTrainedModel |
| | from tokenizers import decoders |
| |
|
| | |
class IsaiConfig(PretrainedConfig):
    """Configuration for the toy Isai causal LM.

    Mirrors a LLaMA-style hyperparameter set. Only a subset is consumed by
    the model defined in this file (vocab_size, hidden_size,
    intermediate_size, num_hidden_layers, rms_norm_eps); the rest is kept
    for checkpoint compatibility.
    """
    model_type = "isai"

    def __init__(self, vocab_size=32000, hidden_size=1024, intermediate_size=2816,
                 num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16,
                 hidden_act="silu", max_position_embeddings=2048, initializer_range=0.02,
                 rms_norm_eps=1e-6, use_cache=True, pad_token_id=0, bos_token_id=1,
                 eos_token_id=2, **kwargs):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        # Bug fix: these three were accepted but never stored, so a
        # save_pretrained/from_pretrained round trip silently dropped them.
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
| |
|
class IsaiRMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction, no bias).

    The statistic is computed in float32 for numerical stability, the
    normalized activations are cast back to the input dtype, and then the
    learned per-channel scale is applied.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x = hidden_states.float()
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)
| |
|
class IsaiForCausalLM(PreTrainedModel):
    """Toy causal LM with a LLaMA-like state_dict layout.

    Each layer is a residual linear projection (named ``self_attn`` — it is a
    plain per-token Linear, with no cross-token mixing and no positional
    information) followed by a residual SwiGLU MLP. Module names/keys are
    chosen to match the checkpoint's state_dict and must not be renamed.
    """
    config_class = IsaiConfig

    def __init__(self, config):
        super().__init__(config)

        def _build_layer():
            # One transformer-style block; key names mirror the checkpoint.
            return nn.ModuleDict({
                "input_layernorm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
                "post_attention_layernorm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
                "self_attn": nn.Linear(config.hidden_size, config.hidden_size, bias=False),
                "mlp": nn.ModuleDict({
                    "gate_proj": nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                    "up_proj": nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                    "down_proj": nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
                }),
            })

        self.model = nn.ModuleDict({
            "embed_tokens": nn.Embedding(config.vocab_size, config.hidden_size),
            "layers": nn.ModuleList([_build_layer() for _ in range(config.num_hidden_layers)]),
            "norm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
        })
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids=None, **kwargs):
        """Return raw logits of shape (batch, seq_len, vocab_size)."""
        x = self.model.embed_tokens(input_ids)
        for block in self.model.layers:
            # Pre-norm residual "attention" (really a per-token Linear).
            x = x + block.self_attn(block.input_layernorm(x))
            # Pre-norm residual SwiGLU MLP.
            normed = block.post_attention_layernorm(x)
            gated = nn.functional.silu(block.mlp.gate_proj(normed)) * block.mlp.up_proj(normed)
            x = x + block.mlp.down_proj(gated)
        return self.lm_head(self.model.norm(x))
| |
|
| | |
# --- Model / tokenizer setup -------------------------------------------------
model_dir = "models/isai-v4.2"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = PreTrainedTokenizerFast.from_pretrained(model_dir)
# Force a ByteLevel decoder so decoded byte pieces reassemble into UTF-8 text.
# Use the public `backend_tokenizer` accessor rather than the private
# `_tokenizer` attribute the original code reached into.
tokenizer.backend_tokenizer.decoder = decoders.ByteLevel()

config = IsaiConfig.from_pretrained(model_dir)
model = IsaiForCausalLM(config).to(device)

# Prefer the safetensors checkpoint; fall back to the legacy pickle file.
weights_path = os.path.join(model_dir, "model.safetensors")
if os.path.exists(weights_path):
    from safetensors.torch import load_file
    model.load_state_dict(load_file(weights_path))
else:
    # weights_only=True prevents arbitrary-code execution from a tampered
    # pickle checkpoint; a state_dict contains only tensors, so it is safe.
    model.load_state_dict(
        torch.load(os.path.join(model_dir, "pytorch_model.bin"),
                   map_location=device, weights_only=True)
    )
model.eval()
| |
|
| | |
def predict(message, history):
    """Gradio chat callback: greedy-decode a reply from the jamo-level model.

    Args:
        message: The current user utterance (any text, typically Hangul).
        history: Prior chat turns supplied by gr.ChatInterface. Unused —
            the model is conditioned on the current message only.

    Returns:
        The generated reply as an NFC-recomposed string.
    """
    # Decompose precomposed Hangul syllables into jamo (NFD) so the input
    # matches the unit the tokenizer was trained on.
    decomposed_input = unicodedata.normalize("NFD", message)
    input_ids = tokenizer.encode(decomposed_input, return_tensors="pt").to(device)

    current_ids = input_ids
    max_new_tokens = 50

    # Greedy decoding. Fix: enter no_grad once around the whole loop instead
    # of re-entering the context manager on every step.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(current_ids)
            # argmax over the last position's vocab -> shape (1,); unsqueeze
            # to (1, 1) so it concatenates along the sequence dimension.
            next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)
            current_ids = torch.cat([current_ids, next_token], dim=-1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Strip the prompt, decode the continuation, and recompose jamo back
    # into full syllables (NFC) for display.
    generated_tokens = current_ids[0][input_ids.shape[1]:]
    raw_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return unicodedata.normalize("NFC", raw_response)
| |
|
| | |
# NOTE(review): the Korean UI strings below were reconstructed from
# mojibake-corrupted originals (UTF-8 bytes decoded as a legacy single-byte
# codepage). The byte pattern pins most syllables, but please verify the
# exact wording against the intended copy.
demo = gr.ChatInterface(
    fn=predict,
    title="isai-v4.2 Jaso-Level Chat",
    description="자소 단위(NFD)로 소통하는 이상한 일상 대화 모델입니다. "
                "입력은 자동으로 분해되고 출력은 다시 한글로 조합됩니다.",
    examples=["안녕? 반가워.", "오늘 날씨가 어때?", "너의 이름은 뭐야?"],
)

if __name__ == "__main__":
    demo.launch(share=True)