import os
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
app = FastAPI()
# Retrieve the token from environment variable
hf_token = os.environ.get("HF_AUTH_TOKEN", None)
if hf_token is None:
print("WARNING: No HF_AUTH_TOKEN found in environment. "
"Make sure to set a Hugging Face token if the model is gated.")
# -------------------------------------------------------------------------
# Update this to the Llama 2 Chat model you prefer. This example uses the
# 7B chat version. For larger models (13B, 70B), ensure you have enough RAM.
# -------------------------------------------------------------------------
model_name = "meta-llama/Llama-2-7b-chat-hf"
# -------------------------------------------------------------------------
# If the repo is gated, you may need:
# use_auth_token="YOUR_HF_TOKEN",
# trust_remote_code=True,
# or you can set environment variables in your HF Space to authenticate.
# -------------------------------------------------------------------------
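# As an alternative sketch (assuming the huggingface_hub package that ships with
# transformers is installed), you could authenticate explicitly at startup
# instead of passing the token to every from_pretrained call:
#
#     from huggingface_hub import login
#     if hf_token:
#         login(token=hf_token)
#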
print(f"Loading model/tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_auth_token=hf_token,
)
# -------------------------------------------------------------------------
# If you had GPU available, you might do:
# model = AutoModelForCausalLM.from_pretrained(
# model_name,
# torch_dtype=torch.float16,
# device_map="auto",
# trust_remote_code=True
# )
# But for CPU, we do a simpler load:
# -------------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_auth_token=hf_token,
)
# Choose device based on availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)
model.eval()  # inference only; disables dropout and other training-time behavior
@app.post("/predict")
async def predict(request: Request):
    """
    Endpoint for streaming responses from the Llama 2 chat model.
    Expects JSON: { "prompt": "<Your prompt>" }
    Returns a plain-text stream of generated tokens (see the example
    curl call at the bottom of this file).
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids            # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask  # same shape
    def token_generator():
        """
        A generator that yields tokens one by one for streaming to the client.
        """
        nonlocal input_ids, attention_mask

        # Basic generation hyperparameters
        temperature = 0.7
        top_p = 0.9
        max_new_tokens = 30  # Increase for longer outputs
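
        # NOTE: the manual loop below re-runs a full forward pass over the whole
        # sequence at every step, which is slow (especially on CPU) because it
        # does not reuse the model's past_key_values cache. A rough sketch of a
        # simpler alternative (assuming a transformers version that provides
        # TextIteratorStreamer) would let model.generate() handle sampling and
        # stream decoded text instead:
        #
        #     from threading import Thread
        #     from transformers import TextIteratorStreamer
        #     streamer = TextIteratorStreamer(
        #         tokenizer, skip_prompt=True, skip_special_tokens=True)
        #     Thread(target=model.generate, kwargs=dict(
        #         input_ids=input_ids, attention_mask=attention_mask,
        #         max_new_tokens=max_new_tokens, do_sample=True,
        #         temperature=temperature, top_p=top_p, streamer=streamer,
        #     )).start()
        #     for text in streamer:
        #         yield text
        #
        # The explicit loop is kept here to show each sampling step.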
        for _ in range(max_new_tokens):
            with torch.no_grad():
                # 1) Forward pass: compute logits for the next token
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]

                # 2) Apply temperature scaling
                next_token_logits = next_token_logits / temperature

                # 3) Convert logits -> probabilities
                next_token_probs = F.softmax(next_token_logits, dim=-1)

                # 4) Nucleus (top-p) sampling
                sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                valid_indices = cumulative_probs <= top_p
                filtered_probs = sorted_probs[valid_indices]
                filtered_indices = sorted_indices[valid_indices]

                # 5) If no tokens are valid under top_p, fall back to greedy
                if len(filtered_probs) == 0:
                    next_token_id = torch.argmax(next_token_probs)
                else:
                    sampled_id = torch.multinomial(filtered_probs, 1)
                    next_token_id = filtered_indices[sampled_id]

            # 6) Ensure next_token_id has shape [batch_size, 1]
            if next_token_id.dim() == 0:
                # shape [] => [1]
                next_token_id = next_token_id.unsqueeze(0)
            # shape [1] => [1, 1]
            next_token_id = next_token_id.unsqueeze(-1)

            # 7) Append the sampled token to input_ids
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

            # 8) Update attention_mask for the new token
            new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
            attention_mask = torch.cat([attention_mask, new_mask], dim=-1)

            # 9) Decode and yield
            token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
            yield token + " "

            # 10) Stop if we encounter EOS
            if tokenizer.eos_token_id is not None:
                if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                    break
    # Return a StreamingResponse that streams plain-text tokens to the client
    return StreamingResponse(token_generator(), media_type="text/plain")
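
# ---------------------------------------------------------------------------
# Example usage (a sketch, not part of the Space runtime itself). Assuming the
# server listens on port 7860 (the port Hugging Face Spaces expects), you can
# stream tokens with curl:
#
#     curl -N -X POST http://localhost:7860/predict \
#          -H "Content-Type: application/json" \
#          -d '{"prompt": "Tell me a short joke"}'
#
# The block below is an optional convenience for running the app locally with
# uvicorn; remove it if your Space launches uvicorn for you.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)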