Spaces:
Running
Running
import os | |
import gc | |
import json | |
import random | |
import torch | |
import asyncio | |
import logging | |
import time | |
from typing import List, Dict, Any, Optional, Union, AsyncGenerator, Tuple | |
from fastapi import FastAPI, HTTPException, Query, Request, Depends, status | |
from fastapi.responses import StreamingResponse, PlainTextResponse, HTMLResponse, JSONResponse | |
from fastapi.security import APIKeyHeader | |
from pydantic import BaseModel, Field, ValidationError, validator | |
from transformers import ( | |
AutoConfig, AutoModelForCausalLM, AutoTokenizer, | |
GenerationConfig, LogitsProcessorList, | |
MinLengthLogitsProcessor, MaxLengthCriteria, | |
StoppingCriteriaList, StoppingCriteria | |
) | |
import uvicorn | |
from concurrent.futures import ThreadPoolExecutor | |
import math | |
import torch.nn.functional as F | |
import copy | |
# FastAPI application instance; title/version are the API's published metadata.
app = FastAPI(title="Chatbot Profesional API", version="1.0.0")
class StopSequenceCriteria(StoppingCriteria):
    """Stopping criterion that fires for each row that has produced a stop string.

    Every row of ``input_ids`` is checked in two ways:

    1. Exact token match: the row ends with the ids ``tokenizer.encode``
       produces for one of the stop strings.
    2. Text match: the decoded tail of the row contains one of the stop
       strings, which also catches stop strings the model emitted with a
       different tokenization.

    Args:
        stop_sequences: Stop strings. Empty strings, and strings that decode to
            nothing, are ignored.
        tokenizer: Tokenizer providing ``encode`` / ``decode``.
        prompt_length: Optional number of leading (prompt) tokens to ignore.
            When None (the default) the decoded tail may include prompt
            tokens, so a stop string at the very end of the prompt can trigger
            an immediate stop.
    """

    def __init__(self, stop_sequences: List[str], tokenizer: AutoTokenizer, prompt_length: Optional[int] = None):
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length
        self.stop_sequences_text: List[str] = []
        self.stop_sequence_ids: List[List[int]] = []
        for seq in stop_sequences:
            if not seq:
                continue
            encoded_ids = tokenizer.encode(seq, add_special_tokens=False)
            # Round-trip through the tokenizer so the text check compares
            # against exactly what decode() will produce for generated ids.
            decoded_text = tokenizer.decode(encoded_ids, skip_special_tokens=True)
            if decoded_text:
                self.stop_sequences_text.append(decoded_text)
                self.stop_sequence_ids.append(encoded_ids)
        # Number of trailing tokens decoded for the text check: enough for the
        # longest stop sequence plus slack, and never fewer than 50.
        longest = max((len(ids) for ids in self.stop_sequence_ids), default=0)
        self.check_tail_len = max(50, longest + 10)

    def _row_matches(self, row_ids: List[int]) -> bool:
        """Return True if a single row of token ids has produced a stop sequence."""
        if self.prompt_length:
            row_ids = row_ids[self.prompt_length:]
        for stop_ids in self.stop_sequence_ids:
            stop_len = len(stop_ids)
            if len(row_ids) >= stop_len and row_ids[-stop_len:] == stop_ids:
                return True
        tail_text = self.tokenizer.decode(row_ids[-self.check_tail_len:], skip_special_tokens=True)
        return any(stop_text in tail_text for stop_text in self.stop_sequences_text)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        """Return a bool tensor with one entry per row of ``input_ids`` (True = stop that row)."""
        batch_size = input_ids.shape[0]
        if not self.stop_sequence_ids:
            return torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
        matches = [self._row_matches(row) for row in input_ids.tolist()]
        return torch.tensor(matches, dtype=torch.bool, device=input_ids.device)
# Silence uvicorn's loggers entirely: drop their handlers, stop propagation to
# the root logger, and let nothing below CRITICAL through.
for _uvicorn_logger_name in ("uvicorn", "uvicorn.error", "uvicorn.access"):
    _uvicorn_logger = logging.getLogger(_uvicorn_logger_name)
    _uvicorn_logger.handlers.clear()
    _uvicorn_logger.propagate = False
    _uvicorn_logger.setLevel(logging.CRITICAL)
del _uvicorn_logger_name, _uvicorn_logger

# Library loggers: only CRITICAL records are emitted.
for _library_logger_name in ("fastapi", "transformers"):
    logging.getLogger(_library_logger_name).setLevel(logging.CRITICAL)
del _library_logger_name

# Replace the root logger's handlers with a NullHandler so records that reach
# the root logger are discarded.
_root_logger = logging.getLogger()
_root_logger.handlers.clear()
_root_logger.addHandler(logging.NullHandler())
del _root_logger
DEFAULT_MODEL_NAME = "jnjj/gemma-3-1b-it-qat-int4-quantized-less-restricted-filtered-sf"
# Model checkpoint and default system prompt, overridable via the environment.
MODEL_NAME = os.environ.get("MODEL_NAME", DEFAULT_MODEL_NAME)
SYSTEM_PROMPT = os.environ.get("SYSTEM_PROMPT", "Eres un asistente profesional y servicial.")


def env_positive_int(name: str, default: int) -> int:
    """Read a strictly positive integer from environment variable *name*.

    Returns *default* when the variable is unset, is not a valid integer
    literal (``int()`` rejects it), or parses to a value <= 0.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        value = int(raw)
    except (TypeError, ValueError):
        return default
    return value if value > 0 else default


# Token budget for the prompt/context.
MAX_CONTEXT_TOKENS = env_positive_int("MAX_CONTEXT_TOKENS", 1024)
# Default (and request-model default) number of new tokens to generate.
MAX_GENERATION_TOKENS = env_positive_int("MAX_GENERATION_TOKENS", 512)
# Size of the worker pool / semaphore that bound concurrent generations.
MAX_CONCURRENT_GENERATIONS = env_positive_int("MAX_CONCURRENT_GENERATIONS", 4)
# Remote (hub-provided) model code is only trusted for the pinned default checkpoint.
TRUST_REMOTE_CODE = MODEL_NAME == DEFAULT_MODEL_NAME
TORCH_DTYPE = torch.float32
# When API_KEY is unset, get_api_key() lets every request through.
API_KEY = os.environ.get("API_KEY")

# Filled in by load_model(); None / empty until a model has been loaded.
global_model = global_tokenizer = None
global_tokens: Dict[str, Optional[int]] = {}

# Both concurrency limits are sized from MAX_CONCURRENT_GENERATIONS.
executor, generation_semaphore = (
    ThreadPoolExecutor(max_workers=MAX_CONCURRENT_GENERATIONS),
    asyncio.Semaphore(MAX_CONCURRENT_GENERATIONS),
)
# auto_error=False: a missing X-API-Key header is passed on as None so that
# get_api_key() makes the accept/reject decision.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
async def get_api_key(api_key: Optional[str] = Depends(api_key_header)):
    """FastAPI dependency enforcing the optional ``X-API-Key`` header.

    Returns:
        None when no API_KEY is configured (authentication disabled),
        otherwise the validated key.

    Raises:
        HTTPException: 401 when the header is missing or does not match.
    """
    import hmac  # constant-time comparison for secrets

    if API_KEY is None:
        return None
    # compare_digest avoids leaking, via response timing, how much of the key
    # matched; compare bytes so non-ASCII header values cannot raise TypeError.
    if api_key is None or not hmac.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or missing API Key")
    return api_key
class GenerateRequest(BaseModel):
    """Request body for text generation.

    Sampling, length and output options mirror Hugging Face ``generate``
    parameters, plus a few API-specific switches (streaming, post-processing,
    text-only output). Field-level bounds are enforced by ``Field``; the
    validators below add the structural checks that types alone do not.
    """

    input_text: str = Field(...)
    # Prior turns as {"role": ..., "content": ...} dicts.
    history: Optional[List[Dict[str, str]]] = Field(None)
    stream: bool = Field(True)
    temperature: float = Field(1.0, ge=0.0, le=2.0)
    top_k: int = Field(50, ge=0)
    top_p: float = Field(1.0, ge=0.0, le=1.0)
    repetition_penalty: float = Field(1.0, ge=0.0)
    frequency_penalty: float = Field(0.0, ge=0.0)
    presence_penalty: float = Field(0.0, ge=0.0)
    num_beams: int = Field(1, ge=1)
    length_penalty: float = Field(1.0, ge=0.0)
    no_repeat_ngram_size: int = Field(0, ge=0)
    early_stopping: bool = Field(False)
    do_sample: bool = Field(True)
    use_mirostat: bool = Field(False)
    mirostat_tau: float = Field(5.0, ge=0.0)
    mirostat_eta: float = Field(0.1, ge=0.0)
    max_new_tokens: int = Field(MAX_GENERATION_TOKENS, ge=1)
    # None means "use the server's SYSTEM_PROMPT".
    system_prompt: Optional[str] = Field(None)
    seed: Optional[int] = Field(None)
    stop_sequences: Optional[List[str]] = Field(None)
    tokenize_only: bool = Field(False)
    strip_trailing_whitespace: bool = Field(False)
    remove_incomplete_sentences: bool = Field(False)
    num_return_sequences: int = Field(1, ge=1, le=5)
    bad_words_ids: Optional[List[List[int]]] = Field(None)
    forced_bos_token_id: Optional[int] = Field(None)
    forced_eos_token_id: Optional[int] = Field(None)
    renormalize_logits: Optional[bool] = Field(None)
    suppress_tokens: Optional[List[int]] = Field(None)
    begin_suppress_tokens: Optional[List[int]] = Field(None)
    end_suppress_tokens: Optional[List[int]] = Field(None)
    encoder_no_repeat_ngram_size: int = Field(0, ge=0)
    min_length: int = Field(0, ge=0)
    max_length: Optional[int] = Field(None)
    exponential_decay_length_penalty: Optional[Tuple[float, int, float]] = Field(None)
    use_cache: bool = Field(True)
    typical_p: float = Field(1.0, ge=0.0, le=1.0)
    epsilon_cutoff: float = Field(0.0, ge=0.0)
    eta_cutoff: float = Field(0.0, ge=0.0)
    temperature_cutoff: Optional[float] = Field(None, ge=0.0)
    encoder_repetition_penalty: float = Field(1.0, ge=0.0)
    max_time: Optional[float] = Field(None, ge=0.0)
    output_watermark: bool = Field(False)
    remove_input_from_output: bool = Field(False)
    eos_token_id_override: Optional[int] = Field(None)
    pad_token_id_override: Optional[int] = Field(None)
    bos_token_id_override: Optional[int] = Field(None)
    repetition_penalty_range: Optional[int] = Field(None, ge=0)
    diversity_penalty: float = Field(0.0, ge=0.0)
    num_beam_groups: int = Field(1, ge=1)
    return_dict_in_generate: bool = Field(False)
    output_attentions: bool = Field(False)
    output_hidden_states: bool = Field(False)
    output_scores: bool = Field(False)
    return_token_logprobs: bool = Field(False)
    return_text_from_sequence: bool = Field(True)
    length_normalization_factor: Optional[float] = Field(None)
    min_new_tokens: int = Field(0, ge=0)
    do_normalize_logits: bool = Field(False)
    return_generation_inputs: bool = Field(False)
    return_unused_generate_parameters: bool = Field(False)
    use_fast_tokenizer: bool = Field(True)
    model_kwargs: Optional[Dict[str, Any]] = Field(None)
    tokenizer_kwargs: Optional[Dict[str, Any]] = Field(None)
    # When True, streaming yields bare text chunks instead of JSON lines.
    return_only_text: bool = Field(False)

    # These checks only run when registered with @validator; without the
    # decorator pydantic never calls them.
    @validator("stop_sequences")
    def validate_stop_sequences(cls, v):
        """Require every stop sequence to be a string."""
        if v is not None:
            if not all(isinstance(seq, str) for seq in v):
                raise ValueError('Each stop sequence must be a string')
        return v

    @validator("bad_words_ids")
    def validate_bad_words_ids(cls, v):
        """Require a list of lists of integer token ids."""
        if v is not None:
            if not all(isinstance(word_id_list, list) and all(isinstance(token_id, int) for token_id in word_id_list) for word_id_list in v):
                raise ValueError('bad_words_ids must be a list of lists of integers')
        return v

    @validator("exponential_decay_length_penalty")
    def validate_exponential_decay_length_penalty(cls, v):
        """Require (decay_factor > 0, start_index >= 0, threshold) when provided."""
        if v is not None:
            if not (isinstance(v, (list, tuple)) and len(v) == 3 and
                    isinstance(v[0], (int, float)) and v[0] > 0 and
                    isinstance(v[1], int) and v[1] >= 0 and
                    isinstance(v[2], (int, float))):
                raise ValueError('exponential_decay_length_penalty must be a tuple/list of 3 numbers (decay_factor, start_index, threshold)')
        return v
def format_conversation(input_text: str, history: Optional[List[Dict[str, str]]], system_prompt: Optional[str]) -> str:
    """Build the prompt string for the model from the conversation so far.

    A system message (``system_prompt`` or, if None, ``SYSTEM_PROMPT``) is
    prepended unless ``history`` already starts with that exact system
    message. ``input_text`` is appended as a user turn unless it is already the
    last user message in ``history``.

    The tokenizer's chat template is used when one is available. Otherwise, or
    if the template raises, a plain "Usuario:/Bot:" transcript is produced that
    ends with "Bot:" so the model continues as the assistant.
    """
    used_system_prompt = system_prompt if system_prompt is not None else SYSTEM_PROMPT

    full_history: List[Dict[str, str]] = []
    first = history[0] if history else None
    if first is None or first.get("role") != "system" or first.get("content") != used_system_prompt:
        full_history.append({"role": "system", "content": used_system_prompt})
    if history:
        full_history.extend(history)
    last = full_history[-1]
    if last.get("role") != "user" or last.get("content") != input_text:
        full_history.append({"role": "user", "content": input_text})

    if global_tokenizer and hasattr(global_tokenizer, "apply_chat_template") and getattr(global_tokenizer, "chat_template", None):
        try:
            return global_tokenizer.apply_chat_template(full_history, tokenize=False, add_generation_prompt=True)
        except Exception:
            # Deliberate best effort: a template may reject this history
            # (e.g. an extra system turn); fall back to the plain format below.
            pass

    role_prefixes = {"user": "Usuario: ", "assistant": "Bot: "}
    parts: List[str] = []
    for index, message in enumerate(full_history):
        role = message.get("role")
        # If we prepended our system prompt but the client history starts with
        # its own system message, keep only the client's.
        if index == 0 and role == "system" and len(full_history) > 1 and full_history[1].get("role") == "system":
            continue
        # History comes from the client: tolerate missing/empty content.
        content = (message.get("content") or "").strip()
        if role == "system":
            parts.append(f"{content}\n\n")
        elif role in role_prefixes:
            parts.append(f"{role_prefixes[role]}{content}\n")
    # Every part ends with a newline, so the assistant cue is always appended.
    return ("".join(parts) + "Bot:").strip()
def truncate_encoded_ids(input_ids: torch.Tensor, max_length: int) -> torch.Tensor:
    """Keep only the last ``max_length`` token ids along the final dimension.

    Works for tensors of any rank (e.g. ``(seq,)`` or ``(batch, seq)``). The
    input tensor itself is returned when it already fits.

    A non-positive ``max_length`` means "no limit" and returns the input
    unchanged. Slicing with ``-max_length`` would otherwise keep everything for
    0 and drop tokens from the *front* for negative values.
    """
    if max_length <= 0 or input_ids.shape[-1] <= max_length:
        return input_ids
    return input_ids[..., -max_length:]
def apply_seed(seed: Optional[int]):
    """Seed torch's global RNG and Python's ``random`` module.

    Does nothing when ``seed`` is None.
    """
    if seed is None:
        return
    torch.manual_seed(seed)
    random.seed(seed)
def get_stopping_criteria(req: GenerateRequest, initial_ids: torch.Tensor, tokenizer: AutoTokenizer) -> StoppingCriteriaList:
    """Build the stopping criteria for one generation request.

    Contents:
      * ``MaxLengthCriteria`` capping the total sequence length at
        ``req.max_length`` when positive; otherwise at prompt length +
        ``req.max_new_tokens`` (or + ``MAX_GENERATION_TOKENS`` as a fallback).
      * ``StopSequenceCriteria`` when ``req.stop_sequences`` is non-empty.

    ``req.min_length`` is intentionally not handled here. A minimum length is
    enforced by suppressing EOS in the logits (``MinLengthLogitsProcessor``).
    That is a logits processor that returns modified scores, not a stop/continue
    flag, so placing it in a StoppingCriteriaList makes the list's ``__call__``
    fail. It must be applied by the caller's logits processing instead.
    """
    criteria = StoppingCriteriaList()
    prompt_len = initial_ids.shape[-1]

    if req.max_length is not None and req.max_length > 0:
        max_total_len = req.max_length
    elif req.max_new_tokens is not None and req.max_new_tokens > 0:
        max_total_len = prompt_len + req.max_new_tokens
    else:
        max_total_len = prompt_len + MAX_GENERATION_TOKENS
    criteria.append(MaxLengthCriteria(max_total_len))

    if req.stop_sequences:
        criteria.append(StopSequenceCriteria(req.stop_sequences, tokenizer))
    return criteria
_NEG_INF = float("-inf")


def _remove_tokens(logits: torch.Tensor, remove_mask: torch.Tensor) -> torch.Tensor:
    """Set logits flagged in ``remove_mask`` to -inf, never removing a row's best token."""
    best = logits.argmax(dim=-1, keepdim=True)
    keep = torch.zeros_like(remove_mask).scatter(-1, best, torch.ones_like(best, dtype=torch.bool))
    return logits.masked_fill(remove_mask & ~keep, _NEG_INF)


def _filter_logits(logits: torch.Tensor, gen_cfg: GenerationConfig) -> torch.Tensor:
    """Apply temperature-cutoff, top-k, top-p, typical-p, epsilon and eta filtering, in that order.

    Every filter is out-of-place and keeps at least one token per row, so the
    resulting distribution can always be sampled.
    """
    temperature_cutoff = getattr(gen_cfg, "temperature_cutoff", None)
    if temperature_cutoff is not None and temperature_cutoff > 0:
        logits = _remove_tokens(logits, logits < temperature_cutoff)

    top_k = int(gen_cfg.top_k or 0)
    if top_k > 0:
        # Clamp so top_k larger than the vocabulary does not make topk() fail.
        top_k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_best, _NEG_INF)

    top_p = gen_cfg.top_p
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        sorted_remove = cum_probs > top_p
        # Shift right so the token crossing the threshold is kept; always keep the best.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        logits = logits.masked_fill(sorted_remove.scatter(-1, sorted_idx, sorted_remove), _NEG_INF)

    typical_p = gen_cfg.typical_p
    if typical_p is not None and typical_p < 1.0:
        # Locally typical sampling: keep the tokens whose surprisal is closest
        # to the distribution's entropy until their mass reaches typical_p.
        log_probs = torch.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        entropy = -(log_probs * probs).nansum(dim=-1, keepdim=True)
        surprisal_gap = torch.abs(-log_probs - entropy)
        sorted_gap, sorted_idx = torch.sort(surprisal_gap, dim=-1)
        cum_probs = probs.gather(-1, sorted_idx).cumsum(dim=-1)
        last_kept = (cum_probs < typical_p).sum(dim=-1, keepdim=True).clamp(max=sorted_gap.size(-1) - 1)
        sorted_remove = sorted_gap > sorted_gap.gather(-1, last_kept)
        sorted_remove[..., 0] = False
        logits = logits.masked_fill(sorted_remove.scatter(-1, sorted_idx, sorted_remove), _NEG_INF)

    epsilon = gen_cfg.epsilon_cutoff
    if epsilon is not None and epsilon > 0:
        logits = _remove_tokens(logits, torch.softmax(logits, dim=-1) < epsilon)

    eta = gen_cfg.eta_cutoff
    if eta is not None and eta > 0:
        # Eta sampling threshold: min(eta, sqrt(eta) * exp(-entropy)).
        log_probs = torch.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        entropy = -(log_probs * probs).nansum(dim=-1)
        threshold = torch.clamp(math.sqrt(eta) * torch.exp(-entropy), max=eta).unsqueeze(-1)
        logits = _remove_tokens(logits, probs < threshold)

    return logits


def generate_next_token_sync(
    input_ids,
    past_key_values,
    gen_cfg: GenerationConfig,
    device: str
) -> Tuple[torch.Tensor, Any, Optional[float], Optional[torch.Tensor], Any, Any]:
    """Run one forward pass of ``global_model`` and choose the next token.

    Args:
        input_ids: Token ids to feed (the full sequence, or only the newest
            token when ``past_key_values`` is given).
        past_key_values: KV cache from the previous step, or None.
        gen_cfg: Generation settings. Standard fields used: ``use_cache``,
            ``do_sample``, ``temperature``, ``top_k``, ``top_p``,
            ``typical_p``, ``epsilon_cutoff``, ``eta_cutoff``,
            ``output_scores``, ``output_attentions``, ``output_hidden_states``.
            Optional custom attributes (read with defaults):
            ``do_normalize_logits``, ``use_mirostat_mode``, ``mirostat_tau``,
            ``mirostat_eta``, ``temperature_cutoff``.
        device: Device the mirostat sampler's token is moved to.

    Returns:
        ``(token, past_key_values, token_logprob, raw_logits, attentions,
        hidden_states)``. ``raw_logits`` is a copy of the unmodified
        last-position logits. ``token_logprob`` is only computed when
        ``gen_cfg.output_scores`` is set.
    """
    output_attentions = bool(gen_cfg.output_attentions)
    output_hidden_states = bool(gen_cfg.output_hidden_states)
    with torch.no_grad():
        # A model forward (unlike generate()) has no output_scores argument,
        # so it is not passed here.
        outputs = global_model(
            input_ids,
            past_key_values=past_key_values,
            use_cache=gen_cfg.use_cache,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        # Copy so the full (batch, seq, vocab) logits tensor is not kept alive.
        raw_logits = outputs.logits[:, -1, :].clone()

        logits = raw_logits
        if getattr(gen_cfg, "do_normalize_logits", False):
            logits = F.log_softmax(logits, dim=-1)

        temperature = gen_cfg.temperature if gen_cfg.temperature is not None else 1.0
        use_mirostat = getattr(gen_cfg, "use_mirostat_mode", 0) == 1 and hasattr(global_model, "mirostat_sample_logits")
        if gen_cfg.do_sample and use_mirostat:
            token = global_model.mirostat_sample_logits(
                logits=logits,
                temperature=gen_cfg.temperature,
                mirostat_tau=getattr(gen_cfg, "mirostat_tau", 5.0),
                mirostat_eta=getattr(gen_cfg, "mirostat_eta", 0.1),
            ).unsqueeze(0).to(device)
        elif gen_cfg.do_sample and temperature > 0:
            logits = _filter_logits(logits / temperature, gen_cfg)
            token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
        else:
            # Greedy decoding; also used for temperature <= 0, where dividing
            # by the temperature would produce inf/NaN probabilities.
            token = torch.argmax(logits, dim=-1, keepdim=True)

        token_logprob = None
        if gen_cfg.output_scores:
            token_id = int(token.reshape(-1)[0].item())
            log_probs = F.log_softmax(raw_logits, dim=-1)
            if 0 <= token_id < log_probs.shape[-1]:
                token_logprob = float(log_probs[0, token_id].item())

        attentions = outputs.attentions if output_attentions else None
        hidden_states = outputs.hidden_states if output_hidden_states else None
    return token, outputs.past_key_values, token_logprob, raw_logits, attentions, hidden_states
def post_process_text(text: str, strip_trailing_whitespace: bool, remove_incomplete_sentences: bool) -> str:
    """Optionally strip trailing whitespace and drop a trailing incomplete sentence.

    With ``remove_incomplete_sentences`` the text is cut just after the *last*
    sentence terminator ('.', '!', '?' or newline), whichever occurs latest.
    Text without any terminator is returned unchanged. When
    ``strip_trailing_whitespace`` is also set, whitespace left at the end by the
    cut (e.g. a kept newline) is stripped as well.
    """
    terminators = (".", "!", "?", "\n")
    if strip_trailing_whitespace:
        text = text.rstrip()
    if remove_incomplete_sentences:
        last_end = max(text.rfind(terminator) for terminator in terminators)
        if last_end != -1:
            text = text[:last_end + 1]
            if strip_trailing_whitespace:
                text = text.rstrip()
    return text
def _stream_stop_token_ids(req: GenerateRequest, gen_cfg: GenerationConfig) -> set:
    """Token ids that end a streamed generation as soon as they are produced.

    Contents: the EOS id (request override, else the tokenizer's), plus the pad
    id when it differs. Without an EOS override, any ``gen_cfg.eos_token_id``
    values (int or list) are added as well, so extra end-of-turn ids carried
    by the generation config also stop the stream.
    """
    eos_token_id = req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id")
    pad_token_id = req.pad_token_id_override if req.pad_token_id_override is not None else global_tokens.get("pad_token_id", eos_token_id)
    stop_ids = {eos_token_id} if eos_token_id is not None else set()
    if pad_token_id is not None and pad_token_id != eos_token_id:
        stop_ids.add(pad_token_id)
    if req.eos_token_id_override is None:
        cfg_eos = getattr(gen_cfg, "eos_token_id", None)
        if isinstance(cfg_eos, int):
            stop_ids.add(cfg_eos)
        elif isinstance(cfg_eos, (list, tuple)):
            stop_ids.update(i for i in cfg_eos if isinstance(i, int))
    return stop_ids


def _stream_finish_reason(req: GenerateRequest, total_ids_list: List[int], prompt_len: int) -> str:
    """Classify why the stopping criteria fired: stop sequence, token limits, or other."""
    generated_text = global_tokenizer.decode(total_ids_list[prompt_len:], skip_special_tokens=True)
    if req.stop_sequences and any(seq and seq in generated_text for seq in req.stop_sequences):
        return "stop_sequence"
    current_len = len(total_ids_list)
    if req.max_new_tokens is not None and current_len >= prompt_len + req.max_new_tokens:
        return "max_new_tokens"
    if req.max_length is not None and current_len >= req.max_length:
        return "max_length"
    return "stopping_criteria"


async def stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, device: str) -> AsyncGenerator[Union[str, Tuple[Dict[str, Any], str]], None]:
    """Generate one token at a time and stream the result.

    Each step runs ``generate_next_token_sync`` in a worker thread.

    Yields:
        * With ``req.return_only_text``: raw text pieces, and on failure a
          single ``"Error: ..."`` line.
        * Otherwise newline-delimited JSON: one ``{"type": "token", ...}``
          object per generated token, then a ``{"type": "done", ...}``
          summary, or ``{"type": "error", ...}`` on failure.

    Text is decoded incrementally from all generated ids (not token by token),
    so multi-token characters and word spacing come out correctly. Text ending
    in an incomplete character (U+FFFD) is held back until it completes.
    ``cleanup()`` is always awaited at the end.
    """
    stop_token_ids = _stream_stop_token_ids(req, gen_cfg)
    prompt_len = initial_ids.shape[-1]
    total_ids_list = initial_ids.tolist()[0]
    generated_tokens_count = 0
    finish_reason = "unknown"
    past = None
    next_input = initial_ids
    emitted_text = ""
    start_time = time.time()
    try:
        stopping_criteria = get_stopping_criteria(req, initial_ids, global_tokenizer)
        while True:
            if generated_tokens_count >= req.max_new_tokens:
                finish_reason = "max_new_tokens"
                break
            if req.max_time is not None and (time.time() - start_time) > req.max_time:
                finish_reason = "time"
                break

            token, past, token_logprob, step_logits, _attentions, _hidden_states = await asyncio.to_thread(
                generate_next_token_sync, next_input, past, gen_cfg, device
            )
            generated_token_id = int(token.reshape(-1)[0].item())
            total_ids_list.append(generated_token_id)
            generated_tokens_count += 1

            # With a KV cache only the new token is fed next; without one
            # (use_cache disabled) the model needs the whole sequence again.
            if past is not None:
                next_input = torch.tensor([[generated_token_id]], dtype=initial_ids.dtype, device=device)
            else:
                next_input = torch.tensor([total_ids_list], dtype=initial_ids.dtype, device=device)

            decoded = global_tokenizer.decode(total_ids_list[prompt_len:], skip_special_tokens=True)
            if decoded.endswith("\ufffd"):
                text = ""  # incomplete character: wait for the following token(s)
            else:
                text = decoded[len(emitted_text):]
                emitted_text = decoded

            if req.return_only_text:
                if text:
                    yield text
            else:
                chunk_payload: Dict[str, Any] = {
                    "type": "token",
                    "text": text,
                    "token_id": generated_token_id,
                    "generated_tokens_count": generated_tokens_count,
                }
                if req.return_token_logprobs and token_logprob is not None:
                    chunk_payload["logprob"] = token_logprob
                yield json.dumps(chunk_payload) + "\n"

            if generated_token_id in stop_token_ids:
                finish_reason = "eos_token"
                break
            current_full_ids_tensor = torch.tensor([total_ids_list], device=device)
            if stopping_criteria(current_full_ids_tensor, step_logits):
                finish_reason = _stream_finish_reason(req, total_ids_list, prompt_len)
                break

        final_text_raw = global_tokenizer.decode(total_ids_list[prompt_len:], skip_special_tokens=True)
        if req.return_only_text:
            # Flush any text that was held back waiting for a complete character.
            leftover = final_text_raw[len(emitted_text):]
            if leftover:
                yield leftover
            return

        if req.stop_sequences and finish_reason == "stop_sequence":
            # Cut at the earliest stop sequence occurrence, whichever string it is.
            positions = [final_text_raw.find(seq) for seq in req.stop_sequences if seq]
            positions = [pos for pos in positions if pos != -1]
            if positions:
                final_text_raw = final_text_raw[:min(positions)]
        final_text_processed = post_process_text(final_text_raw, req.strip_trailing_whitespace, req.remove_incomplete_sentences)
        final_payload: Dict[str, Any] = {
            "type": "done",
            "total_prompt_tokens": prompt_len,
            "total_generated_tokens": generated_tokens_count,
            "total_sequence_tokens": len(total_ids_list),
            "final_text": final_text_processed,
            "finish_reason": finish_reason
        }
        yield json.dumps(final_payload) + "\n"
    except Exception as e:
        if req.return_only_text:
            yield f"Error: {e}\n"
        else:
            error_payload = {"type": "error", "message": str(e)}
            yield json.dumps(error_payload) + "\n"
    finally:
        await cleanup()
def _run_generate_blocking(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, stopping_criteria: StoppingCriteriaList):
    """Call ``global_model.generate`` synchronously; meant to run in a worker thread."""
    # Grad mode is thread-local, so no_grad is entered inside the worker thread.
    with torch.no_grad():
        return global_model.generate(
            input_ids=initial_ids,
            generation_config=gen_cfg,
            return_dict_in_generate=True,
            output_scores=req.output_scores,
            output_attentions=req.output_attentions,
            output_hidden_states=req.output_hidden_states,
            num_return_sequences=req.num_return_sequences,
            bad_words_ids=req.bad_words_ids,
            suppress_tokens=req.suppress_tokens,
            begin_suppress_tokens=req.begin_suppress_tokens,
            end_suppress_tokens=req.end_suppress_tokens,
            stopping_criteria=stopping_criteria if stopping_criteria else None,
        )


def _non_stream_eos_ids(req: GenerateRequest, gen_cfg: GenerationConfig) -> set:
    """EOS ids that mark a finished sequence: the request override, else tokenizer + gen_cfg ids."""
    if req.eos_token_id_override is not None:
        return {req.eos_token_id_override}
    eos_ids = set()
    tokenizer_eos = global_tokens.get("eos_token_id")
    if tokenizer_eos is not None:
        eos_ids.add(tokenizer_eos)
    cfg_eos = getattr(gen_cfg, "eos_token_id", None)
    if isinstance(cfg_eos, int):
        eos_ids.add(cfg_eos)
    elif isinstance(cfg_eos, (list, tuple)):
        eos_ids.update(i for i in cfg_eos if isinstance(i, int))
    return eos_ids


def _trim_generated_ids(gen_ids: List[int], eos_ids: set, pad_token_id: Optional[int]) -> Tuple[List[int], bool]:
    """Cut generated ids just after the first EOS, else strip trailing padding.

    Returns ``(ids, hit_eos)``. In batched output, finished sequences are
    followed by padding, which must not count as generated tokens.
    """
    for index, token_id in enumerate(gen_ids):
        if token_id in eos_ids:
            return gen_ids[:index + 1], True
    end = len(gen_ids)
    if pad_token_id is not None:
        while end > 0 and gen_ids[end - 1] == pad_token_id:
            end -= 1
    return gen_ids[:end], False


def _cut_at_stop_sequence(text: str, stop_sequences: Optional[List[str]]) -> Tuple[str, bool]:
    """Cut ``text`` at the earliest occurrence of any non-empty stop sequence.

    Returns ``(text, found)``.
    """
    positions = [text.find(seq) for seq in (stop_sequences or []) if seq]
    positions = [pos for pos in positions if pos != -1]
    if not positions:
        return text, False
    return text[:min(positions)], True


async def non_stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, device: str) -> Dict[str, Any]:
    """Run a complete ``generate`` call and return all requested sequences.

    The blocking generation runs in a worker thread so the event loop stays
    responsive.

    Returns:
        ``{"prompt_tokens", "generated_sequences", ["total_tokens"],
        ["raw_generate_output"]}``. Each sequence entry has ``text``,
        ``token_ids``, ``generated_tokens_count`` and ``finish_reason``.

    Raises:
        HTTPException: 500 wrapping any error raised during generation.
    """
    try:
        # Only real StoppingCriteria objects may go into the list passed to generate().
        stopping_criteria_list = StoppingCriteriaList(
            c for c in get_stopping_criteria(req, initial_ids, global_tokenizer) if isinstance(c, StoppingCriteria)
        )
        out = await asyncio.to_thread(_run_generate_blocking, req, initial_ids, gen_cfg, stopping_criteria_list)

        prompt_len = initial_ids.shape[-1]
        eos_ids = _non_stream_eos_ids(req, gen_cfg)
        pad_token_id = getattr(gen_cfg, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = req.pad_token_id_override if req.pad_token_id_override is not None else global_tokens.get("pad_token_id")
        max_new_tokens = getattr(gen_cfg, "max_new_tokens", None)

        generated_data = []
        for i, sequence in enumerate(out.sequences[:req.num_return_sequences]):
            sequence_ids = sequence.tolist()
            generated_ids, hit_eos = _trim_generated_ids(sequence_ids[prompt_len:], eos_ids, pad_token_id)
            full_sequence_ids = sequence_ids[:prompt_len] + generated_ids

            raw_text = global_tokenizer.decode(generated_ids, skip_special_tokens=True)
            text, hit_stop_sequence = _cut_at_stop_sequence(raw_text, req.stop_sequences)
            text = post_process_text(text, req.strip_trailing_whitespace, req.remove_incomplete_sentences)

            if hit_eos:
                finish_reason = "eos_token"
            elif max_new_tokens is not None and len(generated_ids) >= max_new_tokens:
                finish_reason = "max_new_tokens"
            elif req.max_length is not None and len(full_sequence_ids) >= req.max_length:
                finish_reason = "max_length"
            elif getattr(out, "max_time_exceeded", False):
                finish_reason = "time"
            elif hit_stop_sequence:
                # Checked on generated text only, so prompt text cannot trigger it.
                finish_reason = "stop_sequence"
            else:
                finish_reason = "length"

            item_data: Dict[str, Any] = {
                "text": text if req.return_text_from_sequence else None,
                "token_ids": generated_ids,
                "generated_tokens_count": len(generated_ids),
                "finish_reason": finish_reason
            }
            if not req.remove_input_from_output:
                item_data["full_sequence_token_ids"] = full_sequence_ids
            if req.output_scores and hasattr(out, 'scores') and out.scores is not None:
                item_data["scores"] = "Scores output needs custom handling (complex structure)."
            if req.return_token_logprobs:
                item_data["token_logprobs"] = "Token logprobs require parsing scores output which is complex for batched/beamed generation."
            if req.output_attentions and hasattr(out, 'attentions') and out.attentions is not None:
                item_data["attentions"] = "Attentions output needs custom handling (too large)."
            if req.output_hidden_states and hasattr(out, 'hidden_states') and out.hidden_states is not None:
                item_data["hidden_states"] = "Hidden states output needs custom handling (too large)."
            if hasattr(out, 'watermark') and out.watermark is not None:
                item_data["watermark"] = out.watermark[i] if isinstance(out.watermark, list) and len(out.watermark) > i else out.watermark
            generated_data.append(item_data)

        response_payload: Dict[str, Any] = {
            "prompt_tokens": prompt_len,
            "generated_sequences": generated_data,
        }
        if req.num_return_sequences == 1 and generated_data:
            response_payload["total_tokens"] = response_payload["prompt_tokens"] + generated_data[0]["generated_tokens_count"]
        if req.return_dict_in_generate:
            # Large/unserializable outputs are excluded; remaining tensors become lists.
            excluded = {'sequences', 'scores', 'attentions', 'hidden_states', 'past_key_values', 'watermark', 'sequences_scores'}
            raw_out_dict = {}
            for key in out.keys():
                if key in excluded:
                    continue
                value = out[key]
                raw_out_dict[key] = value.tolist() if isinstance(value, torch.Tensor) else value
            response_payload["raw_generate_output"] = raw_out_dict
        return response_payload
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Generation error: {e}") from e
async def cleanup() -> None:
    """Force a full garbage-collection pass.

    It is awaited from the streaming generator's ``finally`` clause.
    """
    # The number of unreachable objects found is not needed by callers.
    _ = gc.collect()
    return None
async def load_model():
    """Load the tokenizer and causal LM named by ``MODEL_NAME`` onto the CPU.

    Populates ``global_model``, ``global_tokenizer`` and ``global_tokens``
    (``eos_token_id`` / ``pad_token_id``). On any failure the error is logged
    and the globals are reset to ``None`` / ``None`` / ``{}`` instead of
    raising, so callers must check ``global_model`` before use.
    """
    global global_model, global_tokenizer, global_tokens
    cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
    torch.set_num_threads(max(1, cpu_count // 2))
    try:
        torch.set_num_interop_threads(max(1, cpu_count // 4))
    except RuntimeError:
        # torch allows this only once per process (and not after parallel work
        # has started); keep the existing setting on a repeated load.
        pass
    device = "cpu"
    try:
        config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=TRUST_REMOTE_CODE)
        # The tokenizer is given the unmodified config.
        original_config = copy.deepcopy(config)

        # Cap every context-length field this config exposes at MAX_CONTEXT_TOKENS.
        # NOTE(review): for architectures with learned position embeddings a
        # value smaller than the checkpoint's may cause a size mismatch on load.
        for attr in ("max_position_embeddings", "n_positions", "seq_len", "ctx", "n_ctx",
                     "max_seq_length", "max_sequence_length", "max_length", "block_size"):
            if hasattr(config, attr):
                setattr(config, attr, MAX_CONTEXT_TOKENS)
        # Default the KV cache and extra outputs off; forward calls can still opt in explicitly.
        for attr in ("use_cache", "output_attentions", "output_hidden_states"):
            if hasattr(config, attr):
                setattr(config, attr, False)
        # Special-token ids (bos/eos) and weight tying are left as the checkpoint
        # defines them. Overriding them with fixed values is model-specific and
        # breaks checkpoints that use other ids or an untied output layer.

        global_tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME, config=original_config, trust_remote_code=TRUST_REMOTE_CODE
        )
        global_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME, config=config, torch_dtype=TORCH_DTYPE, trust_remote_code=TRUST_REMOTE_CODE
        )
        global_model.to(device)
        global_model.eval()

        eos_token_id = global_tokenizer.eos_token_id
        pad_token_id = global_tokenizer.pad_token_id
        if pad_token_id is None:
            # Fall back to EOS for padding when the tokenizer defines no pad token.
            pad_token_id = eos_token_id
        global_tokens["eos_token_id"] = eos_token_id
        global_tokens["pad_token_id"] = pad_token_id
        if global_model.config.pad_token_id is None and pad_token_id is not None:
            global_model.config.pad_token_id = pad_token_id
    except Exception:
        logging.getLogger(__name__).exception("Failed to load model %s", MODEL_NAME)
        global_model = None
        global_tokenizer = None
        global_tokens = {}
html_code = """ | |
<!DOCTYPE html> | |
<html lang="es"> | |
<head> | |
<meta charset="UTF-8" /> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> | |
<title>Chatbot Profesional</title> | |
<style> | |
body { font-family: Arial, sans-serif; margin: 20px; } | |
#chatbox { width: 100%; height: 400px; border: 1px solid #ccc; padding: 10px; overflow-y: scroll; margin-bottom: 10px; } | |
#user-input { width: calc(100% - 100px); padding: 8px; box-sizing: border-box;} | |
#send-btn { width: 90px; padding: 8px 0;} | |
#input-area { display: flex;} | |
</style> | |
</head> | |
<body> | |
<h1>Chatbot Profesional (POST API)</h1> | |
<div id="chatbox"></div> | |
<div id="input-area"> | |
<input type="text" id="user-input" placeholder="Escribe tu mensaje aquí..." autocomplete="off"/> | |
<button id="send-btn">Enviar</button> | |
</div> | |
<script> | |
const chatbox = document.getElementById('chatbox'); | |
const userInput = document.getElementById('user-input'); | |
const sendBtn = document.getElementById('send-btn'); | |
let conversationHistory = []; | |
const DEFAULT_SYSTEM_PROMPT = "Eres un asistente profesional y servicial."; | |
let currentSystemPrompt = DEFAULT_SYSTEM_PROMPT; | |
let botMessageElement = null; | |
function appendMessage(sender, text, isStreaming = false) { | |
let msg; | |
if (isStreaming && botMessageElement) { | |
botMessageElement.textContent += text; | |
} else { | |
msg = document.createElement('p'); | |
msg.innerHTML = `<strong>${sender}:</strong> `; | |
const textNode = document.createTextNode(text); | |
msg.appendChild(textNode); | |
chatbox.appendChild(msg); | |
if (sender === 'Bot' && isStreaming) { | |
botMessageElement = textNode; | |
} else { | |
botMessageElement = null; | |
} | |
} | |
chatbox.scrollTop = chatbox.scrollHeight; | |
} | |
function updateHistory(role, content) { | |
conversationHistory.push({ "role": role, "content": content }); | |
const maxHistorySize = 10; | |
if (conversationHistory.length > maxHistorySize * 2) { | |
conversationHistory = conversationHistory.slice(-(maxHistorySize * 2)); | |
} | |
} | |
async function sendMessage() { | |
const text = userInput.value; | |
if (!text) { | |
return; | |
} | |
appendMessage('Usuario', text); | |
updateHistory("user", text); | |
userInput.value = ''; | |
sendBtn.disabled = true; | |
botMessageElement = null; | |
const messagePayload = { | |
input_text: text, | |
history: conversationHistory, | |
system_prompt: currentSystemPrompt, | |
stream: true, | |
temperature: 1.0, | |
top_k: 50, | |
top_p: 1.0, | |
repetition_penalty: 1.0, | |
frequency_penalty: 0.0, | |
presence_penalty: 0.0, | |
num_beams: 1, | |
length_penalty: 1.0, | |
no_repeat_ngram_size: 0, | |
early_stopping: false, | |
do_sample: true, | |
use_mirostat: false, | |
mirostat_tau: 5.0, | |
mirostat_eta: 0.1, | |
max_new_tokens: 512, | |
num_return_sequences: 1, | |
return_token_logprobs: true | |
}; | |
try { | |
const response = await fetch('/generate', { | |
method: 'POST', | |
headers: { | |
'Content-Type': 'application/json', | |
}, | |
body: JSON.stringify(messagePayload), | |
}); | |
if (!response.ok) { | |
const errorData = await response.json(); | |
throw new Error(`API Error: ${response.status} ${response.statusText} - ${errorData.detail || errorData.error}`); | |
} | |
const reader = response.body.getReader(); | |
const decoder = new TextDecoder(); | |
let buffer = ''; | |
let currentBotResponse = ""; | |
while (true) { | |
const { value, done } = await reader.read(); | |
if (done) break; | |
buffer += decoder.decode(value, { stream: true }); | |
const lines = buffer.split('\n'); | |
buffer = lines.pop(); | |
for (const line of lines) { | |
if (line.trim() === '') continue; | |
try { | |
const data = JSON.parse(line); | |
if (data.type === 'token') { | |
currentBotResponse += data.text; | |
appendMessage('Bot', data.text, true); | |
} else if (data.type === 'done') { | |
if (data.total_tokens !== undefined) { | |
appendMessage('System', `Generated ${data.total_tokens} tokens. Finish reason: ${data.finish_reason}`); | |
} | |
if (data.final_text !== undefined) { | |
updateHistory("assistant", data.final_text); | |
} else if (currentBotResponse) { | |
updateHistory("assistant", currentBotResponse); | |
} | |
} else if (data.type === 'error') { | |
appendMessage('Error', data.message); | |
currentBotResponse = ""; | |
} | |
} catch (e) { | |
appendMessage('Error', 'Failed to process stream.'); | |
currentBotResponse = ""; | |
reader.cancel(); | |
return; | |
} | |
} | |
} | |
if (buffer.trim() !== '') { | |
try { | |
const data = JSON.parse(buffer); | |
if (data.type === 'token') { | |
currentBotResponse += data.text; | |
appendMessage('Bot', data.text, true); | |
} else if (data.type === 'done') { | |
if (data.total_tokens !== undefined) { | |
appendMessage('System', `Generated ${data.total_tokens} tokens. Finish reason: ${data.finish_reason}`); | |
} | |
if (data.final_text !== undefined) { | |
updateHistory("assistant", data.final_text); | |
} else if (currentBotResponse) { | |
updateHistory("assistant", currentBotResponse); | |
} | |
} else if (data.type === 'error') { | |
appendMessage('Error', data.message); | |
currentBotResponse = ""; | |
} | |
} catch (e) { | |
appendMessage('Error', 'Failed to process remaining stream data.'); | |
currentBotResponse = ""; | |
} | |
} | |
if (currentBotResponse && !botMessageElement) { | |
updateHistory("assistant", currentBotResponse); | |
} | |
botMessageElement = null; | |
currentBotResponse = ""; | |
} catch (error) { | |
appendMessage('Error', error.message || 'An unknown error occurred.'); | |
botMessageElement = null; | |
currentBotResponse = ""; | |
} finally { | |
sendBtn.disabled = false; | |
} | |
} | |
sendBtn.onclick = sendMessage; | |
userInput.addEventListener('keypress', function(event) { | |
if (event.key === 'Enter') { | |
event.preventDefault(); | |
sendMessage(); | |
} | |
}); | |
</script> | |
</body> | |
</html> | |
""" | |
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the embedded single-page chat UI.

    The route decorator is required: without it FastAPI never registers the
    handler and ``GET /`` returns 404.

    Returns:
        HTMLResponse: the module-level ``html_code`` page.
    """
    return HTMLResponse(content=html_code)
@app.post("/generate")
async def generate_endpoint(req: GenerateRequest):
    """Generate a model completion for one chat turn, streamed or not.

    Steps: format the conversation into a prompt, tokenize it and truncate it
    to ``MAX_CONTEXT_TOKENS``, optionally return tokenization info only, then
    build a ``GenerationConfig`` from the request and run generation while
    holding ``generation_semaphore``.

    Args:
        req: Validated request body.

    Returns:
        - ``JSONResponse`` with prompt/truncation details when
          ``req.tokenize_only`` is set.
        - ``StreamingResponse`` (``text/plain`` if ``req.return_only_text``,
          otherwise ``application/json``) when ``req.stream`` is set.
        - Otherwise a ``PlainTextResponse`` (single text), a ``JSONResponse``
          list of texts (``return_only_text``), or the full JSON payload.

    Raises:
        HTTPException: 503 if the model/tokenizer is not loaded; 400 if the
            conversation cannot be formatted, or if a non-streaming request
            asks for more than ``MAX_CONTEXT_TOKENS + MAX_GENERATION_TOKENS``
            tokens; 500 on tokenizer or generation errors. An HTTPException
            raised by the generation helpers is propagated unchanged.
    """
    if global_model is None or global_tokenizer is None:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Model is not loaded.")
    device = "cpu"
    apply_seed(req.seed)

    try:
        initial_prompt_text = format_conversation(req.input_text, req.history, req.system_prompt)
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Error formatting conversation: {e}") from e

    try:
        tokenizer_encoding_kwargs = req.tokenizer_kwargs or {}
        encoded = global_tokenizer(initial_prompt_text, return_tensors="pt", add_special_tokens=False, **tokenizer_encoding_kwargs).to(device)
        initial_ids_before_trunc = encoded.input_ids
        initial_prompt_tokens_count_before_trunc = initial_ids_before_trunc.shape[-1]
        ids = truncate_encoded_ids(initial_ids_before_trunc, MAX_CONTEXT_TOKENS)
        current_prompt_tokens_count = ids.shape[-1]
    except Exception as e:
        await cleanup()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Tokenizer encoding error: {e}") from e

    if req.tokenize_only:
        await cleanup()
        return JSONResponse({
            "prompt_tokens_count": initial_prompt_tokens_count_before_trunc,
            "max_context_tokens": MAX_CONTEXT_TOKENS,
            "truncated": initial_prompt_tokens_count_before_trunc > MAX_CONTEXT_TOKENS,
            "input_text_processed": initial_prompt_text,
            "input_ids_truncated": ids.tolist()[0]
        })

    total_capacity = MAX_CONTEXT_TOKENS + MAX_GENERATION_TOKENS
    total_requested_seq_len = current_prompt_tokens_count + req.max_new_tokens
    # Only non-streaming requests are rejected up front; streaming ones are
    # allowed through as-is.
    if not req.stream and total_requested_seq_len > total_capacity:
        await cleanup()
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Requested sequence length ({total_requested_seq_len} tokens = {current_prompt_tokens_count} prompt + {req.max_new_tokens} new) exceeds model capacity ({total_capacity} tokens) and non-streaming is requested. Consider enabling streaming or reducing max_new_tokens."
        )

    # NOTE(review): for streaming requests this semaphore is released as soon
    # as the StreamingResponse is returned, i.e. before generation actually
    # runs, so it only bounds streaming concurrency if stream_generation_logic
    # enforces that itself — confirm before relying on it.
    async with generation_semaphore:
        # Becomes True once a StreamingResponse has taken over cleanup, so the
        # `finally` clause does not run it before any token is produced.
        cleanup_deferred_to_stream = False
        try:
            gen_cfg = GenerationConfig(
                temperature=req.temperature,
                top_k=req.top_k,
                top_p=req.top_p,
                repetition_penalty=req.repetition_penalty,
                frequency_penalty=req.frequency_penalty,
                presence_penalty=req.presence_penalty,
                num_beams=req.num_beams if not req.stream else 1,
                length_penalty=req.length_penalty,
                no_repeat_ngram_size=req.no_repeat_ngram_size,
                early_stopping=req.early_stopping,
                do_sample=req.do_sample,
                use_mirostat_mode=1 if req.use_mirostat else 0,
                mirostat_tau=req.mirostat_tau,
                mirostat_eta=req.mirostat_eta,
                max_new_tokens=req.max_new_tokens,
                eos_token_id=req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id"),
                pad_token_id=req.pad_token_id_override if req.pad_token_id_override is not None else global_tokens.get("pad_token_id"),
                bos_token_id=req.bos_token_id_override if req.bos_token_id_override is not None else global_tokenizer.bos_token_id,
                num_return_sequences=req.num_return_sequences if not req.stream else 1,
                bad_words_ids=req.bad_words_ids,
                forced_bos_token_id=req.forced_bos_token_id,
                forced_eos_token_id=req.forced_eos_token_id,
                renormalize_logits=req.renormalize_logits,
                suppress_tokens=req.suppress_tokens,
                begin_suppress_tokens=req.begin_suppress_tokens,
                end_suppress_tokens=req.end_suppress_tokens,
                encoder_no_repeat_ngram_size=req.encoder_no_repeat_ngram_size,
                min_length=req.min_length,
                max_length=req.max_length,
                exponential_decay_length_penalty=req.exponential_decay_length_penalty,
                use_cache=req.use_cache,
                typical_p=req.typical_p,
                epsilon_cutoff=req.epsilon_cutoff,
                eta_cutoff=req.eta_cutoff,
                temperature_cutoff=req.temperature_cutoff,
                encoder_repetition_penalty=req.encoder_repetition_penalty,
                max_time=req.max_time,
                output_watermark=req.output_watermark,
                diversity_penalty=req.diversity_penalty,
                num_beam_groups=req.num_beam_groups if not req.stream else 1,
                length_normalization_factor=req.length_normalization_factor,
                min_new_tokens=req.min_new_tokens,
                do_normalize_logits=req.do_normalize_logits,
                output_scores=req.output_scores,
                output_attentions=req.output_attentions,
                output_hidden_states=req.output_hidden_states,
            )
            if req.stream:
                # Streaming always runs a single greedy/sampled sequence with
                # the KV cache enabled.
                gen_cfg.use_cache = True
                gen_cfg.num_beams = 1
                gen_cfg.num_return_sequences = 1
                gen_cfg.num_beam_groups = 1
                # The stream generator is consumed only after this function
                # returns. Attach cleanup as a background task so Starlette
                # runs it after the body has been sent, not before.
                from fastapi import BackgroundTasks
                post_stream_tasks = BackgroundTasks()
                post_stream_tasks.add_task(cleanup)
                response = StreamingResponse(
                    stream_generation_logic(req, ids, gen_cfg, device),
                    media_type="text/plain" if req.return_only_text else "application/json",
                    background=post_stream_tasks,
                )
                cleanup_deferred_to_stream = True
                return response

            response_payload = await non_stream_generation_logic(req, ids, gen_cfg, device)
            if not req.return_only_text:
                return JSONResponse(response_payload)
            texts = [seq["text"] for seq in response_payload.get("generated_sequences", []) if seq.get("text") is not None]
            if req.num_return_sequences == 1 and texts:
                return PlainTextResponse(texts[0])
            return JSONResponse(texts)
        except HTTPException:
            # Preserve status codes chosen by the generation helpers instead
            # of rewrapping them as 500s.
            raise
        except Exception as e:
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Generation error: {e}") from e
        finally:
            if not cleanup_deferred_to_stream:
                await cleanup()
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app with uvicorn on all interfaces.
    # Server logging is reduced to critical messages only, and per-request
    # access log lines are turned off.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "log_level": "critical",
        "access_log": False,
    }
    uvicorn.run(app, **server_options)