Update app.py
app.py
CHANGED
@@ -21,57 +21,66 @@ async def predict(request: Request):
     if not prompt:
         return {"error": "Prompt is required"}
 
-    #
+    # Initial tokenization on the prompt
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
-    input_ids = inputs.input_ids
-    attention_mask = inputs.attention_mask
+    input_ids = inputs.input_ids  # Shape: [batch_size, seq_len], often [1, seq_len]
+    attention_mask = inputs.attention_mask  # Same shape as input_ids
 
     def token_generator():
-
-        nonlocal input_ids
+        nonlocal input_ids, attention_mask
 
-        #
+        # Generation hyperparameters
         temperature = 0.7
         top_p = 0.9
         max_new_tokens = 30
 
         for _ in range(max_new_tokens):
             with torch.no_grad():
-                # Forward pass
+                # Forward pass: compute logits for the last token
                 outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                 next_token_logits = outputs.logits[:, -1, :]
 
-                #
+                # Apply temperature
                 next_token_logits = next_token_logits / temperature
 
-                # Convert logits
+                # Convert logits -> probabilities
                 next_token_probs = F.softmax(next_token_logits, dim=-1)
 
-                # Apply
+                # Apply top-p (nucleus) sampling
                 sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                 cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
-
-                # Filter out tokens above the top_p threshold
                 valid_indices = cumulative_probs <= top_p
                 filtered_probs = sorted_probs[valid_indices]
                 filtered_indices = sorted_indices[valid_indices]
 
                 if len(filtered_probs) == 0:
-                    # Fallback to greedy if
-                    next_token_id = torch.argmax(next_token_probs)
+                    # Fallback to greedy if nothing meets top_p
+                    next_token_id = torch.argmax(next_token_probs)
                 else:
-                    # Sample from the filtered distribution
-                    sampled_id = torch.multinomial(filtered_probs,
-                    next_token_id = filtered_indices[sampled_id]
-
-            #
+                    # Sample a token from the filtered distribution
+                    sampled_id = torch.multinomial(filtered_probs, 1)
+                    next_token_id = filtered_indices[sampled_id]
+
+                # At this point, next_token_id might be shape [] (scalar) or [1].
+                # We need [batch_size, 1], so if it's just a scalar, unsqueeze(0).
+                if next_token_id.dim() == 0:
+                    next_token_id = next_token_id.unsqueeze(0)  # shape [1]
+                next_token_id = next_token_id.unsqueeze(-1)  # shape [1, 1]
+
+            # Append the new token to input_ids
+            # input_ids: [1, seq_len], next_token_id: [1, 1] => final shape [1, seq_len+1]
             input_ids = torch.cat([input_ids, next_token_id], dim=-1)
 
-            #
+            # Also update the attention mask so the model attends to the new token
+            # shape: [1, seq_len+1]
+            new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
+            attention_mask = torch.cat([attention_mask, new_mask], dim=-1)
+
+            # Decode and yield the token for streaming
             token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
             yield token + " "
 
-            # Stop if EOS token
+            # Stop if we hit the EOS token
             if tokenizer.eos_token_id is not None:
                 if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                     break
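The `cumulative_probs <= top_p` mask can come back empty when the single most likely token already carries more probability mass than top_p, which is exactly the case the greedy fallback in the new code covers. A toy check with made-up numbers (not taken from the app) shows that path being hit:

import torch

top_p = 0.9
probs = torch.tensor([0.95, 0.03, 0.02])   # already sorted, descending
cumulative = torch.cumsum(probs, dim=-1)   # tensor([0.9500, 0.9800, 1.0000])
mask = cumulative <= top_p                 # tensor([False, False, False])
print(mask.any().item())                   # False -> the argmax fallback is used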
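The unsqueeze logic added before the concatenation exists because torch.argmax returns a 0-d tensor while torch.multinomial returns a 1-d one, and torch.cat needs both operands shaped [batch_size, seq]. A small standalone check with dummy token ids (values are illustrative only):

import torch

input_ids = torch.tensor([[101, 2023, 2003]])    # [1, seq_len]
next_token_id = torch.tensor(42)                 # 0-d tensor, as argmax returns
if next_token_id.dim() == 0:
    next_token_id = next_token_id.unsqueeze(0)   # [1]
next_token_id = next_token_id.unsqueeze(-1)      # [1, 1]
print(torch.cat([input_ids, next_token_id], dim=-1).shape)   # torch.Size([1, 4])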
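The hunk only shows the body of token_generator; how predict returns it lies outside this diff. A minimal sketch of one way to stream such a generator, assuming the surrounding app is FastAPI with a JSON request body; the route path, body shape, demo_generator stand-in, and StreamingResponse usage are assumptions for illustration, not part of this commit:

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/predict")
async def predict(request: Request):
    body = await request.json()              # assumed body shape: {"prompt": "..."}
    prompt = body.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}

    def demo_generator():
        # Stand-in for the real token_generator(): yields space-separated chunks
        for word in ("streamed", "tokens", "go", "here"):
            yield word + " "

    return StreamingResponse(demo_generator(), media_type="text/plain")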