# custom-api/app.py
import os
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
app = FastAPI()
# -------------------------------------------------------------------------
# The model below (Sao10K/L3-8B-Stheno-v3.2) is not gated, so no HF token
# is required and we omit any 'use_auth_token' / 'token' parameter.
# -------------------------------------------------------------------------
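# If you later switch to a gated checkpoint, a token can be passed explicitly.
# Minimal sketch (assumption: an HF_TOKEN environment variable that you set
# yourself; not needed for the model used here):
#   tokenizer = AutoTokenizer.from_pretrained(model_name, token=os.environ.get("HF_TOKEN"))
#   model = AutoModelForCausalLM.from_pretrained(model_name, token=os.environ.get("HF_TOKEN"))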
model_name = "Sao10K/L3-8B-Stheno-v3.2"
print(f"Loading tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True
)
print(f"Loading model from: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True
)
# Choose device based on availability (CPU or GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)
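# Note: loading an 8B model in full fp32 is memory-heavy. On a CUDA device you
# could load the weights in half precision instead -- a sketch, assuming enough
# GPU memory is available:
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, trust_remote_code=True, torch_dtype=torch.float16
#   ).to(device)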
@app.post("/predict")
async def predict(request: Request):
"""
Endpoint for streaming responses from Falcon-7B-Instruct.
Expects JSON: { "prompt": "<Your prompt>" }
Returns a text/event-stream of tokens (SSE).
"""
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}
    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids            # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask  # same shape
    def token_generator():
        nonlocal input_ids, attention_mask
        # Basic generation hyperparameters
        temperature = 0.7
        top_p = 0.9
        max_new_tokens = 30  # Increase if you want longer outputs
        for _ in range(max_new_tokens):
            with torch.no_grad():
                # 1) Forward pass: compute logits for the next token
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]
                # 2) Apply temperature scaling
                next_token_logits = next_token_logits / temperature
                # 3) Convert logits -> probabilities
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                # 4) Nucleus (top-p) sampling
                sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                valid_indices = cumulative_probs <= top_p
                filtered_probs = sorted_probs[valid_indices]
                filtered_indices = sorted_indices[valid_indices]
                # 5) If no tokens remain after filtering, fall back to greedy
                if len(filtered_probs) == 0:
                    next_token_id = torch.argmax(next_token_probs)
                else:
                    sampled_id = torch.multinomial(filtered_probs, 1)
                    next_token_id = filtered_indices[sampled_id]
                # 6) Ensure next_token_id has shape [batch_size, 1]
                if next_token_id.dim() == 0:
                    # shape [] => [1]
                    next_token_id = next_token_id.unsqueeze(0)
                # shape [1] => [1, 1]
                next_token_id = next_token_id.unsqueeze(-1)
                # 7) Append the new token to input_ids
                input_ids = torch.cat([input_ids, next_token_id], dim=-1)
                # 8) Update the attention mask
                new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
                attention_mask = torch.cat([attention_mask, new_mask], dim=-1)
                # 9) Decode and yield the generated token
                token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
                yield token + " "
                # 10) Stop if EOS token is generated (if the model uses one)
                if tokenizer.eos_token_id is not None:
                    if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                        break
    # Stream the generated tokens back to the client as plain-text chunks
    return StreamingResponse(token_generator(), media_type="text/plain")
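
# Local launch sketch (assumption: uvicorn is installed and port 7860, the
# conventional Hugging Face Spaces port, is used; adjust for your own setup).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)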