| | import os |
| | import sys |
| | import pickle |
| | import torch |
| | import gradio as gr |
| | from huggingface_hub import snapshot_download |
| |
|
| | |
| | |
| | |
# Hugging Face Hub repo that hosts the model checkpoint and tokenizer.
REPO_ID = "teszenofficial/MTP7"
# Pickled checkpoint inside the repo (config + state_dict + vocab size).
MODEL_FILE = "mtp_mini.pkl"
# Tokenizer model file — presumably SentencePiece format; TODO confirm.
TOKENIZER_FILE = "mtp_tokenizer.model"
# Local directory where snapshot_download materializes the repo contents.
LOCAL_DIR = "mtptz_repo"
| |
|
| | |
| | |
| | |
| |
|
def load_resources():
    """Download the model repo from the Hugging Face Hub and build the model.

    Downloads ``REPO_ID`` into ``LOCAL_DIR``, imports the repo-provided
    ``model.py`` / ``tokenizer.py`` modules, loads the pickled checkpoint,
    and reconstructs the network on the best available device.

    Returns:
        tuple: ``(model, tokenizer, device)`` — an ``MTPMiniModel`` in eval
        mode already moved to ``device``, and the loaded ``MTPTokenizer``.

    Raises:
        ImportError: if 'model.py' or 'tokenizer.py' is missing from the repo.
        FileNotFoundError: if the checkpoint or tokenizer file is absent.
    """
    print(f"📦 Descargando modelo desde {REPO_ID}...")

    repo_path = snapshot_download(
        repo_id=REPO_ID,
        local_dir=LOCAL_DIR,
    )
    print(f"✅ Modelo descargado en: {repo_path}")

    # Make the repo's model.py / tokenizer.py importable.  Guard against
    # duplicate sys.path entries if this function ever runs more than once.
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)

    try:
        from model import MTPMiniModel
        from tokenizer import MTPTokenizer
    except ImportError as e:
        print("❌ ERROR: No se pudieron importar 'model' o 'tokenizer'.")
        print(f"   Asegúrate de que subiste 'model.py' y 'tokenizer.py' al repo '{REPO_ID}'.")
        raise e

    model_path = os.path.join(repo_path, MODEL_FILE)
    tokenizer_path = os.path.join(repo_path, TOKENIZER_FILE)

    # Fail fast with a clear message instead of an opaque open() error.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"No se encontró {MODEL_FILE} en el repo.")
    if not os.path.exists(tokenizer_path):
        raise FileNotFoundError(f"No se encontró {TOKENIZER_FILE} en el repo.")

    tokenizer = MTPTokenizer(tokenizer_path)
    print(f"✅ Tokenizer cargado. Vocab size: {tokenizer.vocab_size()}")

    # NOTE(security): pickle.load executes arbitrary code embedded in the
    # file.  This is only acceptable because the repo is assumed trusted;
    # a safetensors checkpoint would be the safe alternative.
    print("🧠 Cargando tensores...")
    with open(model_path, 'rb') as f:
        model_data = pickle.load(f)

    config = model_data['config']
    state_dict = model_data['model_state_dict']
    vocab_size = model_data['vocab_size']

    # Rebuild the architecture from the saved config.  Dropout is forced to
    # 0.0 because the model is used for inference only.
    use_swiglu = config['model'].get('use_swiglu', False)
    model = MTPMiniModel(
        vocab_size=vocab_size,
        d_model=config['model']['d_model'],
        n_layers=config['model']['n_layers'],
        n_heads=config['model']['n_heads'],
        d_ff=config['model']['d_ff'],
        max_seq_len=config['model']['max_seq_len'],
        dropout=0.0,
        use_swiglu=use_swiglu,
    )

    model.load_state_dict(state_dict)
    model.eval()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(f"✅ Modelo cargado en {device}")

    return model, tokenizer, device
| |
|
| | |
| | model, tokenizer, DEVICE = load_resources() |
| |
|
| | |
| | |
| | |
def generate_response(message, history, temperature, max_tokens, top_p):
    """Produce one chat reply for the Gradio ChatInterface.

    `history` is part of the ChatInterface contract but is not used here:
    only the current `message` is fed to the model.
    """
    # Alpaca-style instruction template the model was trained on.
    prompt = f"### Instrucción:\n{message}\n\n### Respuesta:\n"

    prompt_tokens = [tokenizer.bos_id()] + tokenizer.encode(prompt)
    input_ids = torch.tensor([prompt_tokens], device=DEVICE)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_k=40,
            top_p=float(top_p),
            repetition_penalty=1.15,
            min_length=10,
            eos_token_id=tokenizer.eos_id()
        )

    # Keep only the newly generated tail: stop at EOS, silently drop any
    # out-of-range token ids before decoding.
    eos_id = tokenizer.eos_id()
    n_vocab = tokenizer.vocab_size()
    kept = []
    for tok in output_ids[0, len(prompt_tokens):].tolist():
        if tok == eos_id:
            break
        if 0 <= tok < n_vocab:
            kept.append(tok)

    response = tokenizer.decode(kept).strip()

    # The model sometimes drifts into writing a new instruction; cut there.
    marker = "### Instrucción:"
    if marker in response:
        response = response.split(marker)[0].strip()

    return response
| |
|
| | |
| | |
| | |
# Gradio UI: a Blocks layout wrapping a ChatInterface.  The sliders are
# exposed as additional inputs so sampling can be tuned per message.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 MTP-7 Chat (Demo)")
    gr.Markdown(f"Modelo cargado desde `teszenofficial/MTP7` en **{DEVICE}**.")

    chat_interface = gr.ChatInterface(
        fn=generate_response,
        # Slider order must match generate_response's parameters after the
        # implicit (message, history) pair: temperature, max_tokens, top_p.
        additional_inputs=[
            gr.Slider(0.1, 2.0, value=0.7, label="Temperatura (Creatividad)"),
            gr.Slider(50, 300, value=150, label="Máximos Tokens"),
            gr.Slider(0.1, 1.0, value=0.92, label="Top-p (Nucleus)"),
        ],
        # Each example row supplies the message plus the three slider values.
        examples=[
            ["¿Cuál es la capital de Francia?", 0.7, 150, 0.92],
            ["Explica qué es la relatividad.", 0.7, 150, 0.92]
        ],
        cache_examples=False
    )
| |
|
# Launch the web app only when executed as a script (not when imported).
if __name__ == "__main__":
    demo.launch()