import os
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

app = FastAPI()
# Retrieve the token from environment variable
hf_token = os.environ.get("HF_AUTH_TOKEN", None)
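# (On a Hugging Face Space, HF_AUTH_TOKEN can be stored as a repository secret
# so that it is exposed to the app as an environment variable.)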
if hf_token is None:
    print("WARNING: No HF_AUTH_TOKEN found in environment. "
          "Make sure to set a Hugging Face token if the model is gated.")
# -------------------------------------------------------------------------
# Update this to the Llama 2 Chat model you prefer. This example uses the
# 7B chat version. For larger models (13B, 70B), ensure you have enough RAM.
# -------------------------------------------------------------------------
model_name = "meta-llama/Llama-2-7b-chat-hf"
# -------------------------------------------------------------------------
# If the repo is gated, you may need:
# use_auth_token="YOUR_HF_TOKEN",
# trust_remote_code=True,
# or you can set environment variables in your HF Space to authenticate.
# -------------------------------------------------------------------------
print(f"Loading model/tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_auth_token=hf_token,
)
# -------------------------------------------------------------------------
# If you had GPU available, you might do:
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     torch_dtype=torch.float16,
#     device_map="auto",
#     trust_remote_code=True,
# )
# But for CPU, we do a simpler load:
# -------------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_auth_token=hf_token,
)
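# NOTE: a CPU-only load keeps the default fp32 weights, so the 7B model needs
# roughly 26-28 GB of RAM; passing low_cpu_mem_usage=True (requires the
# `accelerate` package) can reduce the peak memory used while loading.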
# Choose device based on availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)
@app.post("/predict")
async def predict(request: Request):
    """
    Endpoint for streaming responses from the Llama 2 chat model.
    Expects JSON: { "prompt": "<Your prompt>" }
    Returns the generated tokens as a streamed plain-text response.
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids            # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask  # same shape
    def token_generator():
        """
        A generator that yields decoded tokens one by one for streaming.
        """
        nonlocal input_ids, attention_mask

        # Basic generation hyperparameters
        temperature = 0.7
        top_p = 0.9
        max_new_tokens = 30  # Increase for longer outputs
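        # NOTE: temperature < 1 sharpens the distribution (more deterministic
        # output); top-p sampling keeps only the smallest set of tokens whose
        # cumulative probability reaches top_p before drawing a sample.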
        for _ in range(max_new_tokens):
            with torch.no_grad():
                # 1) Forward pass: compute logits for next token
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]
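                # NOTE: each iteration re-runs the forward pass over the whole
                # sequence; reusing past_key_values (use_cache=True) would avoid
                # that recomputation, but is omitted here for simplicity.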
                # 2) Apply temperature scaling
                next_token_logits = next_token_logits / temperature

                # 3) Convert logits -> probabilities
                next_token_probs = F.softmax(next_token_logits, dim=-1)

                # 4) Nucleus (top-p) sampling
                sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                valid_indices = cumulative_probs <= top_p
                filtered_probs = sorted_probs[valid_indices]
                filtered_indices = sorted_indices[valid_indices]

                # 5) If no token falls under top_p, fall back to greedy decoding
                if len(filtered_probs) == 0:
                    next_token_id = torch.argmax(next_token_probs)
                else:
                    sampled_id = torch.multinomial(filtered_probs, 1)
                    next_token_id = filtered_indices[sampled_id]
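                # NOTE: torch.multinomial treats its input as unnormalized
                # weights, so filtered_probs does not need to be re-normalized
                # after the top-p cut.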
                # 6) Ensure next_token_id has shape [batch_size, 1]
                if next_token_id.dim() == 0:
                    # shape [] => [1]
                    next_token_id = next_token_id.unsqueeze(0)
                # shape [1] => [1, 1]
                next_token_id = next_token_id.unsqueeze(-1)

                # 7) Append token to input_ids
                input_ids = torch.cat([input_ids, next_token_id], dim=-1)

                # 8) Update attention_mask for the new token
                new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
                attention_mask = torch.cat([attention_mask, new_mask], dim=-1)

            # 9) Decode and yield
            token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
            yield token + " "

            # 10) Stop if we encounter EOS
            if tokenizer.eos_token_id is not None:
                if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                    break
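
    # NOTE: the response streams raw text chunks; a true Server-Sent Events
    # stream would wrap each chunk as "data: ...\n\n" and use the
    # "text/event-stream" media type instead.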
    # Stream the generated text back to the client as plain text
    return StreamingResponse(token_generator(), media_type="text/plain")
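# -------------------------------------------------------------------------
# Example client (a minimal sketch, not part of the app): it assumes the
# `requests` package and that the API is reachable at http://localhost:7860,
# the usual port for a Hugging Face Space -- adjust the URL as needed.
#
#   import requests
#
#   with requests.post(
#       "http://localhost:7860/predict",
#       json={"prompt": "Tell me a joke about chemists."},
#       stream=True,
#   ) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)
# -------------------------------------------------------------------------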