import os

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

app = FastAPI()

# Retrieve the token from the environment variable
hf_token = os.environ.get("HF_AUTH_TOKEN", None)
if hf_token is None:
    print("WARNING: No HF_AUTH_TOKEN found in environment. "
          "Make sure to set a Hugging Face token if the model is gated.")

# -------------------------------------------------------------------------
# Update this to the Llama 2 Chat model you prefer. This example uses the
# 7B chat version. For larger models (13B, 70B), ensure you have enough RAM.
# -------------------------------------------------------------------------
model_name = "meta-llama/Llama-2-7b-chat-hf"

# -------------------------------------------------------------------------
# If the repo is gated, you may need:
#     use_auth_token="YOUR_HF_TOKEN",
#     trust_remote_code=True,
# or you can set environment variables in your HF Space to authenticate.
# -------------------------------------------------------------------------
print(f"Loading model/tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_auth_token=hf_token,
)
# -------------------------------------------------------------------------
# If you had a GPU available, you might do:
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name,
#         torch_dtype=torch.float16,
#         device_map="auto",
#         trust_remote_code=True,
#     )
# But for CPU, we do a simpler load:
# -------------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_auth_token=hf_token,
)

# Choose device based on availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)


@app.post("/predict")  # route path is an assumption; adjust to match your client
async def predict(request: Request):
    """
    Endpoint for streaming responses from the Llama 2 chat model.
    Expects JSON: { "prompt": "<Your prompt>" }
    Streams the generated tokens back as plain text.
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids            # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask  # same shape

    def token_generator():
        """
        A generator that yields tokens one at a time for streaming.
        """
        nonlocal input_ids, attention_mask

        # Basic generation hyperparameters
        temperature = 0.7
        top_p = 0.9
        max_new_tokens = 30  # Increase for longer outputs

        for _ in range(max_new_tokens):
            with torch.no_grad():
                # 1) Forward pass: compute logits for the next token
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]

                # 2) Apply temperature scaling
                next_token_logits = next_token_logits / temperature
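                # (e.g. temperature=0.7 scales the logits by 1/0.7, about 1.43,
                # sharpening the distribution; a temperature above 1.0 would
                # flatten it instead)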

                # 3) Convert logits -> probabilities
                next_token_probs = F.softmax(next_token_logits, dim=-1)

                # 4) Nucleus (top-p) sampling
                sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                valid_indices = cumulative_probs <= top_p
                filtered_probs = sorted_probs[valid_indices]
                filtered_indices = sorted_indices[valid_indices]
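                # (e.g. with sorted probs [0.5, 0.3, 0.15, 0.05] and top_p=0.9, the
                # cumulative sums are [0.5, 0.8, 0.95, 1.0], so only the first two
                # tokens pass the <= top_p test and stay in the sampling pool)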

                # 5) If no tokens are valid under top_p, fall back to greedy decoding
                if len(filtered_probs) == 0:
                    next_token_id = torch.argmax(next_token_probs)
                else:
                    sampled_id = torch.multinomial(filtered_probs, 1)
                    next_token_id = filtered_indices[sampled_id]

                # 6) Ensure next_token_id has shape [batch_size, 1]
                if next_token_id.dim() == 0:
                    # shape [] => [1]
                    next_token_id = next_token_id.unsqueeze(0)
                # shape [1] => [1, 1]
                next_token_id = next_token_id.unsqueeze(-1)

                # 7) Append the new token to input_ids
                input_ids = torch.cat([input_ids, next_token_id], dim=-1)

                # 8) Update attention_mask for the new token
                new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
                attention_mask = torch.cat([attention_mask, new_mask], dim=-1)

                # 9) Decode and yield the new token
                token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
                yield token + " "

                # 10) Stop if we encounter EOS
                if tokenizer.eos_token_id is not None:
                    if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                        break

    # Stream tokens back to the client as they are generated
    return StreamingResponse(token_generator(), media_type="text/plain")
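
To call the endpoint, a client only needs to POST a JSON body with a "prompt" field and read the response incrementally. The sketch below is a minimal example using the requests library; it assumes the route is mounted at /predict (as in the decorator above) and uses a placeholder URL that you would replace with your own Space's address.

import requests

# Placeholder URL: replace with your actual Space endpoint (assumed /predict route)
url = "https://YOUR-SPACE-URL/predict"

payload = {"prompt": "Tell me a short story about a robot."}

# stream=True lets us consume the plain-text response chunk by chunk
with requests.post(url, json=payload, stream=True) as response:
    response.raise_for_status()
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)
print()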