import os
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

app = FastAPI()

# Retrieve the Hugging Face token from the HF_AUTH_TOKEN environment variable
hf_token = os.environ.get("HF_AUTH_TOKEN", None)
if hf_token is None:
    print("WARNING: No HF_AUTH_TOKEN found in environment. "
          "Make sure to set a Hugging Face token if the model is gated.")


# -------------------------------------------------------------------------
# Update this to the Llama 2 Chat model you prefer. This example uses the
# 7B chat version. For larger models (13B, 70B), ensure you have enough RAM.
# -------------------------------------------------------------------------
model_name = "meta-llama/Llama-2-7b-chat-hf"

# -------------------------------------------------------------------------
# If the repo is gated, you may need:
#   use_auth_token="YOUR_HF_TOKEN",
#   trust_remote_code=True,
# or you can set environment variables in your HF Space to authenticate.
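# Note: newer transformers releases deprecate `use_auth_token` in favour of a
# `token=` argument, so the calls below may emit a deprecation warning
# depending on the installed version.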
# -------------------------------------------------------------------------
print(f"Loading model/tokenizer from: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_auth_token=hf_token,
)

# -------------------------------------------------------------------------
# If you have a GPU available, you might do:
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     torch_dtype=torch.float16,
#     device_map="auto",
#     trust_remote_code=True
# )
# But for CPU, we do a simpler load:
# -------------------------------------------------------------------------
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_auth_token=hf_token,
)

# Choose device based on availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model.to(device)

@app.post("/predict")
async def predict(request: Request):
    """
    Endpoint for streaming responses from the Llama 2 chat model.
    Expects JSON: { "prompt": "<Your prompt>" }
    Streams the generated text back as plain-text chunks (see the commented
    example client at the bottom of this file).
    """
    data = await request.json()
    prompt = data.get("prompt", "")
    if not prompt:
        return JSONResponse({"error": "Prompt is required"}, status_code=400)

    # Tokenize the input prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids             # shape: [batch_size, seq_len], typically [1, seq_len]
    attention_mask = inputs.attention_mask   # same shape

    def token_generator():
        """
        A generator that yields decoded tokens one at a time for streaming.
        """
        nonlocal input_ids, attention_mask

        # Basic generation hyperparameters
        temperature = 0.7
        top_p = 0.9
        max_new_tokens = 30  # Increase for longer outputs

        for _ in range(max_new_tokens):
            with torch.no_grad():
                # 1) Forward pass: compute logits for next token
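                #    Note: the full sequence is re-encoded on every step; a
                #    faster variant would reuse past_key_values (the KV cache)
                #    between steps instead of recomputing them.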
                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]

                # 2) Apply temperature scaling
                next_token_logits = next_token_logits / temperature

                # 3) Convert logits -> probabilities
                next_token_probs = F.softmax(next_token_logits, dim=-1)

                # 4) Nucleus (top-p) sampling: keep the smallest set of tokens
                #    whose cumulative probability reaches top_p. Shifting the
                #    cumulative sum by one token keeps the boundary token and
                #    guarantees the most likely token is always a candidate.
                sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                valid_indices = (cumulative_probs - sorted_probs) < top_p
                filtered_probs = sorted_probs[valid_indices]
                filtered_indices = sorted_indices[valid_indices]

                # 5) Defensive fallback: if filtering somehow removed every
                #    candidate, fall back to the single most likely token
                if len(filtered_probs) == 0:
                    next_token_id = torch.argmax(next_token_probs)
                else:
                    sampled_id = torch.multinomial(filtered_probs, 1)
                    next_token_id = filtered_indices[sampled_id]

                # 6) Ensure next_token_id has shape [batch_size, 1]
                if next_token_id.dim() == 0:
                    # shape [] => [1]
                    next_token_id = next_token_id.unsqueeze(0)
                # shape [1] => [1,1]
                next_token_id = next_token_id.unsqueeze(-1)

                # 7) Append token to input_ids
                input_ids = torch.cat([input_ids, next_token_id], dim=-1)

                # 8) Update attention_mask for the new token
                new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
                attention_mask = torch.cat([attention_mask, new_mask], dim=-1)

                # 9) Decode and yield
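                #    (decoding one token at a time and appending a space is a
                #    rough approximation: subword pieces of a single word will
                #    come out separated by spaces)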
                token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
                yield token + " "

                # 10) Stop if we encounter EOS
                if tokenizer.eos_token_id is not None:
                    if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                        break

    # Return a StreamingResponse that streams the generated text as plain text
    return StreamingResponse(token_generator(), media_type="text/plain")
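
# -------------------------------------------------------------------------
# Example client (a minimal sketch, not part of the server). Assuming the app
# is served locally, e.g. with `uvicorn app:app --host 0.0.0.0 --port 7860`
# (module name and port are placeholders), the streaming endpoint can be
# consumed with the `requests` library like this:
#
#   import requests
#
#   with requests.post(
#       "http://localhost:7860/predict",   # hypothetical local URL/port
#       json={"prompt": "Tell me a short joke"},
#       stream=True,
#   ) as resp:
#       resp.raise_for_status()
#       for chunk in resp.iter_content(chunk_size=None):
#           print(chunk.decode("utf-8", errors="replace"), end="", flush=True)
# -------------------------------------------------------------------------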