# MTP-4 / app.py — Hugging Face Space (author: teszenofficial, commit 9ebe5ea verified)
# -*- coding: utf-8 -*-
"""
MTP 4 API - ASISTENTE AVANZADO
- Modelo: d_model=384, n_layers=6 (25M parámetros)
- Temperatura 0.4
- Sistema anti-alucinaciones
"""
import os
import sys
import torch
import json
import time
import gc
import re
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from huggingface_hub import snapshot_download
import uvicorn
import math
import torch.nn as nn
import torch.nn.functional as F
import sentencepiece as spm
from enum import Enum
from typing import Tuple, Optional
# ======================
# OPTIMIZACIONES
# ======================
# Runtime/device configuration: prefer CUDA when present, otherwise tune CPU threading.
if torch.cuda.is_available():
    DEVICE = "cuda"
    # Let cuDNN autotune kernel choices (fine because input sizes are stable).
    torch.backends.cudnn.benchmark = True
    print("✅ GPU detectada. Modo rápido activado.")
else:
    DEVICE = "cpu"
    available_cores = os.cpu_count() or 2
    torch.set_num_threads(min(4, available_cores))
    torch.set_num_interop_threads(2)
    print("⚠️ Usando CPU optimizado.")

# Inference-only service — gradients are never needed.
torch.set_grad_enabled(False)

MODEL_REPO = "TeszenAI/MTP-4"  # Hugging Face Hub repo the model is pulled from
# ======================
# SISTEMA ANTI-ALUCINACIONES
# ======================
class AntiHallucination:
    """Heuristic filters that flag degenerate or off-topic model output.

    Two independent checks:
      * is_hallucinating -- structural problems (empty, repetitive, runaway
        length, or the model admitting it does not know the answer).
      * is_coherent      -- rough lexical overlap between answer and question.
    """

    def __init__(self):
        # Leading phrases that signal the model is admitting uncertainty.
        self.uncertainty_words = [
            'no se', 'no lo se', 'no tengo idea', 'no estoy seguro',
            'no puedo responder', 'no sé', 'desconozco'
        ]
        # Whole-string patterns for content-free replies.
        self.empty_patterns = [
            r'^[.,!?;:]+$', r'^[\s]+$', r'^[0-9]+$', r'^[a-zA-Z]{1,3}$',
        ]
        # Same word repeated 6+ times in a row, or one char 11+ times.
        self.repetition_patterns = [
            r'(\b\w+\b)(?:\s+\1){5,}', r'(.)\1{10,}',
        ]
        # Anything longer is treated as runaway generation.
        self.max_safe_chars = 500

    def is_hallucinating(self, text: str) -> Tuple[bool, str]:
        """Return (flagged, reason) for structurally bad output."""
        if not text:
            return True, "Respuesta vacía"
        if len(text) < 5:
            return True, "Respuesta demasiado corta"
        if any(re.match(p, text) for p in self.empty_patterns):
            return True, "Patrón vacío detectado"
        if any(re.search(p, text) for p in self.repetition_patterns):
            return True, "Repetición excesiva"
        # Only the first five words are scanned for uncertainty markers.
        head = ' '.join(text.lower().split()[:5])
        for marker in self.uncertainty_words:
            if marker in head:
                return True, f"Expresa incertidumbre: '{marker}'"
        if len(text) > self.max_safe_chars:
            return True, "Respuesta demasiado larga"
        return False, "OK"

    def is_coherent(self, text: str, question: str) -> Tuple[bool, str]:
        """Return (coherent, reason) based on word overlap with the question."""
        if not text or not question:
            return True, "Sin datos suficientes"
        answer_lower = text.lower()
        # Content words: 3+ letters, Spanish accents included.
        content_words = set(re.findall(r'\b[a-záéíóúüñ]{3,}\b', question.lower()))
        if content_words:
            hits = sum(1 for w in content_words if w in answer_lower)
            # Require at least 20% overlap for multi-word questions.
            if len(content_words) >= 2 and hits / len(content_words) < 0.2:
                return False, "No responde a la pregunta"
        return True, "OK"
# ======================
# SISTEMA DE PARADA INTELIGENTE
# ======================
class CompletionState(Enum):
    """Classification of a partially generated answer."""
    INCOMPLETE = "incomplete"
    COMPLETE = "complete"
    SHOULD_STOP = "should_stop"


class IntelligentStopper:
    """Decides whether a streamed answer looks finished.

    ``analyze`` returns ``(state, reason)`` where *state* is a CompletionState.
    """

    def __init__(self):
        # Endings that read as a finished sentence.
        # BUGFIX: the second pattern was r'\!?\s*$' — the '!' was optional, so
        # the pattern matched EVERY string and analyze() declared almost any
        # text above min_length COMPLETE. It must require a literal '!'.
        self.completion_patterns = [r'\.\s*$', r'!\s*$', r'\?\s*$', r'\.\.\.\s*$']
        # Endings that grammatically demand continuation.
        self.continuation_patterns = [r'[,;:]\s*$', r' y $', r' o $', r' pero $', r' porque $']
        # Sign-off phrases that indicate a closing remark.
        self.completion_phrases = [
            'gracias', 'saludos', 'adios', 'hasta luego',
            'espero haberte ayudado', 'cualquier otra pregunta',
            'que tengas un buen dia', 'nos vemos'
        ]

    def analyze(self, text: str, min_length: int = 40) -> Tuple[CompletionState, str]:
        """Classify *text*; answers shorter than *min_length* are never complete."""
        if not text or len(text) < min_length:
            return CompletionState.INCOMPLETE, "Demasiado corto"
        text = text.strip()
        # Continuation markers win over everything else.
        for pattern in self.continuation_patterns:
            if re.search(pattern, text, re.IGNORECASE):
                return CompletionState.INCOMPLETE, "Indica continuación"
        # Sign-off phrase anywhere in the last 80 characters.
        text_lower = text.lower()
        for phrase in self.completion_phrases:
            if phrase in text_lower[-80:]:
                return CompletionState.COMPLETE, "Frase de finalización"
        # Terminal punctuation => naturally finished sentence.
        for pattern in self.completion_patterns:
            if re.search(pattern, text):
                if len(text) > min_length:
                    return CompletionState.COMPLETE, "Termina naturalmente"
        # Very long answers are cut off regardless of punctuation.
        if len(text) > 350:
            return CompletionState.COMPLETE, "Longitud suficiente"
        return CompletionState.INCOMPLETE, "Puede continuar"
# ======================
# ARQUITECTURA MTP 4 (IDÉNTICA AL ENTRENADOR)
# ======================
class LayerNorm(nn.Module):
    """Layer normalization with learnable scale and shift.

    NOTE(review): uses torch.std (Bessel-corrected) and adds eps to the std
    rather than the variance, unlike nn.LayerNorm. Kept exactly as-is because
    the released checkpoints were trained with this formulation.
    """

    def __init__(self, d_model, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.bias = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        return self.weight * (x - x.mean(-1, keepdim=True)) / (x.std(-1, keepdim=True) + self.eps) + self.bias


class MultiHeadAttention(nn.Module):
    """Standard multi-head scaled dot-product self-attention."""

    def __init__(self, d_model, n_heads, dropout=0.2):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, x, mask=None):
        """x: (batch, seq, d_model); mask: broadcastable to (b, h, s, s), 0 = masked."""
        b, s, _ = x.shape
        Q = self.w_q(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(b, s, self.n_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(b, s, self.d_model)
        return self.w_o(out)


class FeedForward(nn.Module):
    """Position-wise two-layer MLP with GELU activation."""

    def __init__(self, d_model, d_ff, dropout=0.2):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.gelu(self.linear1(x))))


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP sublayers with residuals."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.2):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.dropout1(self.attn(self.norm1(x), mask))
        x = x + self.dropout2(self.ff(self.norm2(x)))
        return x


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding, added to token embeddings."""

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).float().unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        # Buffer (not a Parameter): moves with .to(device), excluded from optimizer.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class MTP4Model(nn.Module):
    """Decoder-only transformer LM (default: d_model=384, 6 layers, ~25M params)."""

    def __init__(self, vocab_size, d_model=384, n_heads=8, n_layers=6, d_ff=1536, dropout=0.2, max_len=512):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.norm = LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        # Xavier init on every matrix-shaped parameter (embeddings included).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x):
        """x: (batch, seq) token ids -> (batch, seq, vocab_size) logits."""
        seq_len = x.size(1)
        # Causal mask: position i may only attend to positions <= i.
        mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0).to(x.device)
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x, mask)
        return self.lm_head(self.norm(x))

    @torch.no_grad()
    def generate(self, input_ids, max_new=120, temperature=0.4, top_k=30, top_p=0.85,
                 repetition_penalty=1.3, stopper=None):
        """Autoregressively sample up to *max_new* tokens after *input_ids* (batch 1).

        Fixes over the previous version:
          * prompt length is input_ids.size(1); len(input_ids) is the batch
            dimension (always 1 here), so the stopper used to analyze text
            that still contained almost the entire prompt;
          * the tokenizer was looked up with ``'sp' in dir()``, which inspects
            the method's *local* names and never found the module-level
            SentencePiece processor — the intelligent stopper was dead code;
          * repetition penalty now multiplies negative logits instead of
            dividing them (dividing a negative logit makes the token MORE
            likely — the standard CTRL-style correction).
        """
        generated = input_ids
        eos_id = 3  # SentencePiece </s>; id 0 is also treated as a stop (pad/unk)
        prompt_len = input_ids.size(1)
        last_tokens = []
        for step in range(max_new):
            # Keep the context within the positional-encoding window.
            if generated.size(1) > self.max_len:
                context = generated[:, -self.max_len:]
            else:
                context = generated
            logits = self(context)
            next_logits = logits[0, -1, :].clone() / temperature
            if repetition_penalty != 1.0:
                # Penalize every token already present in the sequence.
                for token_id in set(generated[0].tolist()):
                    if next_logits[token_id] < 0:
                        next_logits[token_id] *= repetition_penalty
                    else:
                        next_logits[token_id] /= repetition_penalty
            if top_k > 0:
                indices = next_logits < torch.topk(next_logits, top_k)[0][..., -1, None]
                next_logits[indices] = float('-inf')
            if top_p < 1.0:
                # Nucleus filtering: drop tokens beyond cumulative mass top_p.
                sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                remove = cum_probs > top_p
                remove[..., 1:] = remove[..., :-1].clone()
                remove[..., 0] = 0
                indices = sorted_indices[remove]
                next_logits[indices] = float('-inf')
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, 1).item()
            last_tokens.append(next_token)
            # Abort if everything sampled so far collapses to <= 2 distinct tokens.
            if len(last_tokens) > 6 and len(set(last_tokens)) <= 2:
                break
            if next_token == eos_id or next_token == 0:
                break
            generated = torch.cat([generated, torch.tensor([[next_token]], device=generated.device)], dim=1)
            # Every 5 steps (after warm-up), let the stopper inspect the decoded text.
            if stopper and step > 20 and step % 5 == 0:
                gen_tokens = generated[0, prompt_len:].tolist()
                gen_tokens = [t for t in gen_tokens if t not in [0, 1, 2, 3]]
                tokenizer = globals().get('sp')  # module-level SentencePiece processor
                if gen_tokens and tokenizer is not None:
                    current_text = tokenizer.decode(gen_tokens)
                    if current_text and len(current_text) > 50:
                        state, _ = stopper.analyze(current_text, min_length=40)
                        if state == CompletionState.COMPLETE:
                            break
        return generated
# ======================
# LIMPIEZA DE RESPUESTAS
# ======================
def clean_response(text: str, question: str = "") -> str:
    """Post-process raw model output into a short, presentable answer.

    Collapses immediately repeated words, trims replies to greetings down to
    one sentence, caps the length near 400 characters at a sentence boundary
    when possible, and capitalizes the first letter.
    """
    if not text:
        return ""

    # Drop consecutive duplicate words (case-insensitive comparison).
    deduped = []
    previous = ""
    for word in text.split():
        if word.lower() != previous.lower():
            deduped.append(word)
        previous = word
    text = re.sub(r'\s+', ' ', " ".join(deduped)).strip()

    # Replies to a bare greeting are trimmed to a single sentence.
    greetings = ["hola", "buenos dias", "buenas tardes", "buenas noches", "hey"]
    if question.lower().strip() in greetings:
        if '.' in text:
            text = text.split('.')[0] + '.'
        elif len(text) > 100:
            text = text[:100] + '...'

    # Hard cap: prefer cutting at the last sentence boundary before 400 chars.
    if len(text) > 400:
        cut = text[:400].rfind('.')
        text = text[:cut + 1] if cut > 50 else text[:400] + "..."

    if len(text) < 3:
        return "Lo siento, no pude generar una respuesta clara."
    if text and text[0].islower():
        text = text[0].upper() + text[1:]
    return text
# ======================
# CARGA DEL MODELO
# ======================
# ----------------------------------------------------------------------
# Model download & initialization (runs once at import time).
# Populates the module globals used by the API handlers:
#   sp, VOCAB_SIZE, config, model, param_count, stopper, anti_hallucination
# ----------------------------------------------------------------------
print(f"📦 Descargando MTP 4 desde {MODEL_REPO}...")
repo_path = snapshot_download(repo_id=MODEL_REPO, repo_type="model", local_dir="mtp_repo")

with open(os.path.join(repo_path, "config.json"), "r") as f:
    config = json.load(f)
print(f"📋 Configuración encontrada:")
for _key in ("d_model", "n_layers", "n_heads", "d_ff"):
    print(f" → {_key}: {config.get(_key, 'No especificado')}")

# SentencePiece tokenizer shipped alongside the checkpoint.
sp = spm.SentencePieceProcessor()
sp.load(os.path.join(repo_path, "mtp_tokenizer.model"))
VOCAB_SIZE = sp.get_piece_size()
config["vocab_size"] = VOCAB_SIZE  # the embedding size must match the tokenizer

print(f"🧠 Inicializando MTP 4...")
print(f" → Vocabulario: {VOCAB_SIZE}")
print(f" → Dispositivo: {DEVICE.upper()}")

# Build the model with the exact configuration stored in the repo.
model = MTP4Model(**config)
model.to(DEVICE)

model_path = os.path.join(repo_path, "mtp_model.pt")
if os.path.exists(model_path):
    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # consider logging any mismatch. torch.load without weights_only=True
    # unpickles arbitrary objects; acceptable for a trusted first-party repo.
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict, strict=False)
    print("✅ Pesos del modelo cargados")
model.eval()

param_count = sum(p.numel() for p in model.parameters())
print(f"✅ MTP 4 listo: {param_count:,} parámetros ({param_count/1e6:.2f}M)")

stopper = IntelligentStopper()
anti_hallucination = AntiHallucination()
# ======================
# API
# ======================
# FastAPI application with fully permissive CORS (public demo endpoint).
app = FastAPI(title="MTP 4 API", description="Asistente IA Avanzado", version="4.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])


class PromptRequest(BaseModel):
    # User prompt; capped at 2000 chars to keep the tokenized context bounded.
    text: str = Field(..., max_length=2000)
def build_prompt(user_input: str) -> str:
    """Wrap the user text in the instruction template used during fine-tuning."""
    return "### Instrucción:\n" + user_input + "\n\n### Respuesta:\n"
ACTIVE_REQUESTS = 0  # in-flight request counter (diagnostic only)


@app.post("/generate")
async def generate(req: PromptRequest):
    """Generate an assistant reply for ``req.text``.

    Pipeline: template prompt -> tokenize (truncated to 350 tokens) ->
    sample -> anti-hallucination filtering -> coherence check -> cleanup.

    Fix over the previous version: ACTIVE_REQUESTS was incremented before
    the try block, so any exception raised while building/encoding the
    prompt leaked the counter (and surfaced as an unhandled 500). All work
    now lives inside try/finally.
    """
    global ACTIVE_REQUESTS
    ACTIVE_REQUESTS += 1
    try:
        user_input = req.text.strip()
        if not user_input:
            return {"reply": ""}

        # Tokenize the templated prompt, leaving room for generation.
        tokens = sp.encode(build_prompt(user_input))
        if len(tokens) > 350:
            tokens = tokens[:350]
        input_ids = torch.tensor([tokens], device=DEVICE)

        start = time.time()
        output_ids = model.generate(
            input_ids,
            max_new=100,
            temperature=0.4,
            top_k=30,
            top_p=0.85,
            repetition_penalty=1.3,
            stopper=stopper
        )
        elapsed = time.time() - start

        # Keep only newly generated, in-vocabulary, non-pad tokens.
        gen_tokens = output_ids[0, len(tokens):].tolist()
        safe_tokens = [t for t in gen_tokens if 0 <= t < VOCAB_SIZE and t != 0]
        response = sp.decode(safe_tokens).strip() if safe_tokens else ""

        # Anti-hallucination: on detection, retry with a truncated prefix.
        is_hallucinating, reason = anti_hallucination.is_hallucinating(response)
        if is_hallucinating:
            print(f"⚠️ Alucinación detectada: {reason}")
            if safe_tokens and len(safe_tokens) > 20:
                safe_tokens = safe_tokens[:20]
                response = sp.decode(safe_tokens).strip()
                is_hallucinating, _ = anti_hallucination.is_hallucinating(response)
            if is_hallucinating:
                response = ""

        # If the reply doesn't address the question, keep only the first sentence.
        is_coherent, _ = anti_hallucination.is_coherent(response, user_input)
        if not is_coherent and len(response) > 20:
            first_sentence = response.split('.')[0] if '.' in response else response[:100]
            if len(first_sentence) > 10:
                response = first_sentence + '.'

        response = clean_response(response, user_input)
        if len(response) < 3:
            response = "Lo siento, no pude generar una respuesta clara."

        return {
            "reply": response,
            "tokens_generated": len(safe_tokens),
            "time": round(elapsed, 2),
            "model": "MTP-4"
        }
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        return {"reply": "Lo siento, ocurrió un error."}
    finally:
        ACTIVE_REQUESTS -= 1
        if DEVICE == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
@app.get("/health")
def health():
    """Lightweight liveness probe for the Space."""
    payload = {"status": "ok", "model": "MTP-4", "device": DEVICE}
    return payload
@app.get("/info")
def info():
    """Static metadata about the loaded model instance."""
    return {
        "model": "MTP-4",
        "version": "4.0",
        "parameters": param_count,
        "parameters_millions": round(param_count / 1e6, 2),
        "device": DEVICE,
        "vocab_size": VOCAB_SIZE,
    }
# ======================
# INTERFAZ WEB
# ======================
@app.get("/", response_class=HTMLResponse)
def chat_ui():
    """Serve the single-page chat UI (inline HTML/CSS/JS, no external assets).

    The page posts JSON ``{"text": ...}`` to ``/generate`` and renders the
    ``reply``/``time`` fields of the response; all user text is HTML-escaped
    client-side before insertion.
    """
    return """
<!DOCTYPE html>
<html lang="es">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MTP 4 - Asistente IA</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
background: linear-gradient(135deg, #0a0a0a 0%, #1a1a2e 100%);
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
height: 100vh;
display: flex;
flex-direction: column;
}
.header {
padding: 16px 20px;
background: rgba(0,0,0,0.5);
backdrop-filter: blur(10px);
border-bottom: 1px solid rgba(255,255,255,0.1);
}
.header h1 { color: white; font-size: 1.2rem; }
.header p { color: #888; font-size: 0.7rem; margin-top: 4px; }
.messages {
flex: 1;
overflow-y: auto;
padding: 20px;
display: flex;
flex-direction: column;
gap: 12px;
}
.message {
max-width: 80%;
padding: 10px 16px;
border-radius: 18px;
font-size: 0.9rem;
line-height: 1.4;
animation: fadeIn 0.2s ease;
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(8px); }
to { opacity: 1; transform: translateY(0); }
}
.user {
background: linear-gradient(135deg, #4a9eff, #3a7ecc);
color: white;
align-self: flex-end;
border-radius: 18px 4px 18px 18px;
}
.bot {
background: rgba(30, 31, 40, 0.95);
color: #e0e0e0;
align-self: flex-start;
border-radius: 4px 18px 18px 18px;
border: 1px solid rgba(255,255,255,0.05);
}
.input-area {
padding: 16px 20px;
background: rgba(0,0,0,0.5);
backdrop-filter: blur(10px);
border-top: 1px solid rgba(255,255,255,0.1);
display: flex;
gap: 12px;
}
input {
flex: 1;
padding: 12px 16px;
background: rgba(255,255,255,0.1);
border: 1px solid rgba(255,255,255,0.2);
border-radius: 24px;
color: white;
font-size: 0.9rem;
outline: none;
}
input:focus { border-color: #4a9eff; }
input::placeholder { color: #666; }
button {
padding: 12px 24px;
background: linear-gradient(135deg, #4a9eff, #3a7ecc);
border: none;
border-radius: 24px;
color: white;
font-weight: 500;
cursor: pointer;
}
button:hover { opacity: 0.9; }
button:disabled { opacity: 0.5; cursor: not-allowed; }
.typing {
background: rgba(30, 31, 40, 0.95);
padding: 10px 16px;
border-radius: 18px;
align-self: flex-start;
display: flex;
gap: 4px;
}
.typing span {
width: 8px;
height: 8px;
background: #888;
border-radius: 50%;
animation: bounce 1.4s infinite;
}
.typing span:nth-child(1) { animation-delay: -0.32s; }
.typing span:nth-child(2) { animation-delay: -0.16s; }
@keyframes bounce {
0%, 80%, 100% { transform: scale(0); }
40% { transform: scale(1); }
}
.suggestions {
display: flex;
gap: 8px;
padding: 10px 20px;
overflow-x: auto;
background: rgba(0,0,0,0.3);
}
.suggestion {
padding: 5px 12px;
background: rgba(255,255,255,0.1);
border-radius: 20px;
color: #aaa;
font-size: 0.75rem;
cursor: pointer;
white-space: nowrap;
}
.suggestion:hover {
background: linear-gradient(135deg, #4a9eff, #3a7ecc);
color: white;
}
.badge {
position: fixed;
bottom: 8px;
right: 8px;
font-size: 0.6rem;
color: #555;
background: rgba(0,0,0,0.5);
padding: 2px 8px;
border-radius: 12px;
}
@media (max-width: 600px) {
.message { max-width: 95%; }
.suggestions { display: none; }
}
</style>
</head>
<body>
<div class="header">
<h1>🤖 MTP 4 - Asistente IA</h1>
<p>✨ Temperatura 0.4 | Anti-alucinaciones | Respuestas precisas</p>
</div>
<div class="suggestions">
<div class="suggestion">Hola</div>
<div class="suggestion">¿Quién eres?</div>
<div class="suggestion">¿Qué puedes hacer?</div>
<div class="suggestion">Explícame la IA</div>
<div class="suggestion">Háblame de BTS</div>
<div class="suggestion">¿Qué es un agujero negro?</div>
<div class="suggestion">Dime un chiste</div>
<div class="suggestion">Adiós</div>
</div>
<div class="messages" id="messages">
<div class="message bot">✨ Hola, soy MTP 4. Estoy optimizado para dar respuestas coherentes y evitar alucinaciones. ¿En qué puedo ayudarte?</div>
</div>
<div class="input-area">
<input type="text" id="input" placeholder="Escribe tu pregunta..." autocomplete="off">
<button id="send">Enviar</button>
</div>
<div class="badge">⚡ MTP 4 | 🌡️ 0.4 | 🛡️ Anti-alucinaciones</div>
<script>
const messages = document.getElementById('messages');
const input = document.getElementById('input');
const sendBtn = document.getElementById('send');
let loading = false;
function addMessage(text, isUser, time = null) {
const div = document.createElement('div');
div.className = `message ${isUser ? 'user' : 'bot'}`;
div.innerHTML = `<div>${escapeHtml(text)}</div>${time ? `<div style="font-size:0.6rem;color:#666;margin-top:6px;">⚡ ${time}s</div>` : ''}`;
messages.appendChild(div);
messages.scrollTop = messages.scrollHeight;
}
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
function showTyping() {
const div = document.createElement('div');
div.className = 'typing';
div.id = 'typing';
div.innerHTML = '<span></span><span></span><span></span>';
messages.appendChild(div);
messages.scrollTop = messages.scrollHeight;
}
function hideTyping() {
const el = document.getElementById('typing');
if (el) el.remove();
}
async function sendMessage() {
const text = input.value.trim();
if (!text || loading) return;
input.value = '';
addMessage(text, true);
loading = true;
sendBtn.disabled = true;
showTyping();
try {
const response = await fetch('/generate', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ text: text })
});
const data = await response.json();
hideTyping();
addMessage(data.reply, false, data.time);
} catch (error) {
hideTyping();
addMessage('⚠️ Error de conexión. Intenta de nuevo.', false);
} finally {
loading = false;
sendBtn.disabled = false;
input.focus();
}
}
input.addEventListener('keypress', (e) => { if (e.key === 'Enter') sendMessage(); });
sendBtn.addEventListener('click', sendMessage);
input.focus();
</script>
</body>
</html>
"""
if __name__ == "__main__":
    # Hugging Face Spaces injects PORT; default to 7860 for local runs.
    port = int(os.environ.get("PORT", 7860))
    banner = "=" * 60
    print("\n" + banner)
    print(f"🚀 MTP 4 en http://0.0.0.0:{port}")
    print(f"🌡️ Temperatura: 0.4 | 🔁 Repetition penalty: 1.3")
    print(banner)
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")