import os

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

app = FastAPI()

# Retrieve the token from the environment variable
hf_token = os.environ.get("HF_AUTH_TOKEN", None)
if hf_token is None:
    print("WARNING: No HF_AUTH_TOKEN found in environment. "
          "Make sure to set a Hugging Face token if the model is gated.")
# -------------------------------------------------------------------------
# Update this to the Llama 2 Chat model you prefer. This example uses the
# 7B chat version. For larger models (13B, 70B), ensure you have enough RAM.
# -------------------------------------------------------------------------
model_name = "meta-llama/Llama-2-7b-chat-hf"

# -------------------------------------------------------------------------
# If the repo is gated, you may need:
#   use_auth_token="YOUR_HF_TOKEN",
#   trust_remote_code=True,
# or you can set environment variables in your HF Space to authenticate.
# -------------------------------------------------------------------------
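
# Example (illustrative): one way to supply the token. Locally you might
# export it before starting the app; in a Hugging Face Space, add it as a
# secret named HF_AUTH_TOKEN in the Space settings instead.
#
#   export HF_AUTH_TOKEN=hf_xxxxxxxxxxxxxxxx   # placeholder value, not a real token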
print(f"Loading model/tokenizer from: {model_name}") | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_name, | |
trust_remote_code=True, | |
use_auth_token=hf_token, | |
) | |

# -------------------------------------------------------------------------
# If you had a GPU available, you might do:
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name,
#       torch_dtype=torch.float16,
#       device_map="auto",
#       trust_remote_code=True,
#   )
# But for CPU, we do a simpler load:
# -------------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_auth_token=hf_token,
)

# Choose device based on availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)

# Register the route on the FastAPI app. The "/predict" path is an assumption
# here; adjust it to whatever path your client or Space frontend calls.
@app.post("/predict")
async def predict(request: Request):
    """
    Endpoint for streaming responses from the Llama 2 chat model.
    Expects JSON: { "prompt": "<Your prompt>" }
    Returns a streamed plain-text response of generated tokens.
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids            # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask  # same shape
    def token_generator():
        """
        A generator that yields tokens one by one for streaming.
        """
        nonlocal input_ids, attention_mask

        # Basic generation hyperparameters
        temperature = 0.7
        top_p = 0.9
        max_new_tokens = 30  # Increase for longer outputs

        for _ in range(max_new_tokens):
            with torch.no_grad():
                # 1) Forward pass: compute logits for the next token
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]

                # 2) Apply temperature scaling
                next_token_logits = next_token_logits / temperature

                # 3) Convert logits -> probabilities
                next_token_probs = F.softmax(next_token_logits, dim=-1)

                # 4) Nucleus (top-p) sampling
                sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                valid_indices = cumulative_probs <= top_p
                filtered_probs = sorted_probs[valid_indices]
                filtered_indices = sorted_indices[valid_indices]

                # 5) If no tokens are valid under top_p, fall back to greedy
                if len(filtered_probs) == 0:
                    next_token_id = torch.argmax(next_token_probs)
                else:
                    sampled_id = torch.multinomial(filtered_probs, 1)
                    next_token_id = filtered_indices[sampled_id]

                # 6) Ensure next_token_id has shape [batch_size, 1]
                if next_token_id.dim() == 0:
                    # shape [] => [1]
                    next_token_id = next_token_id.unsqueeze(0)
                # shape [1] => [1, 1]
                next_token_id = next_token_id.unsqueeze(-1)

                # 7) Append the new token to input_ids
                input_ids = torch.cat([input_ids, next_token_id], dim=-1)

                # 8) Update attention_mask for the new token
                new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
                attention_mask = torch.cat([attention_mask, new_mask], dim=-1)

                # 9) Decode and yield
                token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
                yield token + " "

                # 10) Stop if we encounter EOS
                if tokenizer.eos_token_id is not None:
                    if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                        break

    # Return a StreamingResponse that streams plain-text tokens to the client
    return StreamingResponse(token_generator(), media_type="text/plain")
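
# -------------------------------------------------------------------------
# Example (illustrative): a minimal client sketch for this endpoint. It
# assumes the app is served on port 7860 (the Hugging Face Spaces default)
# and that the route is exposed at /predict as registered above; adjust the
# URL for your deployment.
#
#   import requests
#
#   with requests.post(
#       "http://localhost:7860/predict",
#       json={"prompt": "Tell me a joke about penguins."},
#       stream=True,
#   ) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)
#
# To serve the app locally, a typical command (assuming this file is app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860
# -------------------------------------------------------------------------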