# light/app.py — isai-v4.2 jaso-level chat demo
# (Hugging Face Space source; author: aixk, commit 53663e6 "Update app.py")
import torch
import torch.nn as nn
import unicodedata
import os
import gradio as gr
from transformers import PreTrainedTokenizerFast, PretrainedConfig, PreTrainedModel
from tokenizers import decoders
# 1. Re-define the Architecture Classes (identical to the training/test phase)
class IsaiConfig(PretrainedConfig):
    """Configuration for the Isai causal LM.

    Mirrors a Llama-style hyperparameter set. Fix over the original: the
    ``hidden_act``, ``initializer_range`` and ``use_cache`` arguments were
    accepted but never stored (nor forwarded to ``super()``), so they were
    silently lost on ``save_pretrained``/round-trips; they are now persisted
    as attributes like the other fields.
    """
    model_type = "isai"

    def __init__(self, vocab_size=32000, hidden_size=1024, intermediate_size=2816,
                 num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16,
                 hidden_act="silu", max_position_embeddings=2048, initializer_range=0.02,
                 rms_norm_eps=1e-6, use_cache=True, pad_token_id=0, bos_token_id=1,
                 eos_token_id=2, **kwargs):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
class IsaiRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class IsaiForCausalLM(PreTrainedModel):
    """Minimal Llama-shaped causal LM used by this demo.

    NOTE(review): despite its name, ``self_attn`` here is a plain ``nn.Linear``
    token-wise mixer, not real attention — tokens do not interact and there is
    no positional encoding. The submodule names below are kept exactly as in
    the original so that saved checkpoints load via ``load_state_dict``.
    """
    config_class = IsaiConfig

    def __init__(self, config):
        super().__init__(config)
        width = config.hidden_size
        inner = config.intermediate_size

        def build_block():
            # One residual block: pre-norm "attention" stub + gated SiLU MLP.
            return nn.ModuleDict({
                "input_layernorm": IsaiRMSNorm(width, eps=config.rms_norm_eps),
                "post_attention_layernorm": IsaiRMSNorm(width, eps=config.rms_norm_eps),
                "self_attn": nn.Linear(width, width, bias=False),
                "mlp": nn.ModuleDict({
                    "gate_proj": nn.Linear(width, inner, bias=False),
                    "up_proj": nn.Linear(width, inner, bias=False),
                    "down_proj": nn.Linear(inner, width, bias=False),
                }),
            })

        self.model = nn.ModuleDict({
            "embed_tokens": nn.Embedding(config.vocab_size, width),
            "layers": nn.ModuleList(
                [build_block() for _ in range(config.num_hidden_layers)]
            ),
            "norm": IsaiRMSNorm(width, eps=config.rms_norm_eps),
        })
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids=None, **kwargs):
        """Return raw logits of shape (batch, seq, vocab) for ``input_ids``."""
        x = self.model.embed_tokens(input_ids)
        for block in self.model.layers:
            # Residual "attention" stub.
            x = x + block.self_attn(block.input_layernorm(x))
            # Residual SwiGLU-style MLP.
            h = block.post_attention_layernorm(x)
            gated = nn.functional.silu(block.mlp.gate_proj(h)) * block.mlp.up_proj(h)
            x = x + block.mlp.down_proj(gated)
        return self.lm_head(self.model.norm(x))
# 2. Load Model and Tokenizer
model_dir = "models/isai-v4.2"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = PreTrainedTokenizerFast.from_pretrained(model_dir)
# The ByteLevel decoder re-assembles byte-level tokens into valid UTF-8 text,
# which is what restores the decomposed Hangul jaso on decode.
tokenizer._tokenizer.decoder = decoders.ByteLevel()  # Critical for jaso restoration

config = IsaiConfig.from_pretrained(model_dir)
model = IsaiForCausalLM(config).to(device)

# Prefer safetensors (no pickle, safe to load); fall back to the legacy .bin.
weights_path = os.path.join(model_dir, "model.safetensors")
if os.path.exists(weights_path):
    from safetensors.torch import load_file
    model.load_state_dict(load_file(weights_path))
else:
    # weights_only=True keeps torch.load from executing arbitrary pickled
    # code if the checkpoint file were tampered with (torch >= 1.13).
    model.load_state_dict(
        torch.load(
            os.path.join(model_dir, "pytorch_model.bin"),
            map_location=device,
            weights_only=True,
        )
    )
model.eval()
# 3. Define the Prediction Logic with Jaso Processing
def predict(message, history):
    """Gradio chat handler: greedy-decode a reply at jaso (NFD) granularity.

    Args:
        message: The user's latest utterance (composed Hangul text).
        history: Prior turns supplied by gr.ChatInterface. Unused — the model
            is prompted with the current message only.

    Returns:
        The generated reply, recomposed from jaso back to NFC Hangul.
    """
    # A. NFD decomposition so tokenization sees jaso, matching training data.
    decomposed_input = unicodedata.normalize('NFD', message)
    input_ids = tokenizer.encode(decomposed_input, return_tensors="pt").to(device)

    current_ids = input_ids
    max_new_tokens = 50
    # B. Greedy generation. no_grad() is hoisted around the whole loop (the
    # original re-entered the context every iteration) so no autograd graph
    # is ever built during decoding.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(current_ids)
            # keepdim=True yields the (batch, 1) shape needed for torch.cat.
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            current_ids = torch.cat([current_ids, next_token], dim=-1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    # C. Decode only the newly generated tokens, then recompose to NFC.
    generated_tokens = current_ids[0][input_ids.shape[1]:]
    raw_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    final_response = unicodedata.normalize('NFC', raw_response)
    return final_response
# 4. Create and Launch Gradio Interface
# Build the chat UI configuration first, then instantiate the interface.
chat_ui_kwargs = dict(
    fn=predict,
    title="isai-v4.2 Jaso-Level Chat",
    description="자소 단위(NFD)로 소통하는 초소형 일상 대화 모델입니다. 입력은 자동으로 분해되고 출력은 다시 한글로 조합됩니다.",
    examples=["안녕? 반가워.", "오늘 날씨가 어때?", "너의 이름은 뭐야?"],
)
demo = gr.ChatInterface(**chat_ui_kwargs)

if __name__ == "__main__":
    # share=True exposes a public Gradio link in addition to the local server.
    demo.launch(share=True)