DataChem committed on
Commit d638752 · verified · 1 Parent(s): 8194424

Update app.py

Files changed (1)
  1. app.py +30 -21
app.py CHANGED
@@ -21,57 +21,66 @@ async def predict(request: Request):
     if not prompt:
         return {"error": "Prompt is required"}
 
-    # Tokenize the input and move to correct device
+    # Initial tokenization on the prompt
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
-    input_ids = inputs.input_ids
-    attention_mask = inputs.attention_mask
+    input_ids = inputs.input_ids # Shape: [batch_size, seq_len], often [1, seq_len]
+    attention_mask = inputs.attention_mask # Same shape as input_ids
 
     def token_generator():
-        # Use nonlocal to allow reassigning input_ids inside the nested function
-        nonlocal input_ids
+        nonlocal input_ids, attention_mask
 
-        # Sampling parameters
+        # Generation hyperparameters
         temperature = 0.7
         top_p = 0.9
         max_new_tokens = 30
 
         for _ in range(max_new_tokens):
             with torch.no_grad():
-                # Forward pass
+                # Forward pass: compute logits for the last token
                 outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                 next_token_logits = outputs.logits[:, -1, :]
 
-                # Temperature scaling
+                # Apply temperature
                 next_token_logits = next_token_logits / temperature
 
-                # Convert logits to probabilities
+                # Convert logits -> probabilities
                 next_token_probs = F.softmax(next_token_logits, dim=-1)
 
-                # Apply nucleus (top-p) sampling
+                # Apply top-p (nucleus) sampling
                 sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                 cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
-
-                # Filter out tokens above the top_p threshold
                 valid_indices = cumulative_probs <= top_p
                 filtered_probs = sorted_probs[valid_indices]
                 filtered_indices = sorted_indices[valid_indices]
 
                 if len(filtered_probs) == 0:
-                    # Fallback to greedy if no tokens meet top_p
-                    next_token_id = torch.argmax(next_token_probs).unsqueeze(-1)
+                    # Fallback to greedy if nothing meets top_p
+                    next_token_id = torch.argmax(next_token_probs)
                 else:
-                    # Sample from the filtered distribution
-                    sampled_id = torch.multinomial(filtered_probs, num_samples=1)
-                    next_token_id = filtered_indices[sampled_id].unsqueeze(-1)
-
-                # Append the new token to our running sequence
+                    # Sample a token from the filtered distribution
+                    sampled_id = torch.multinomial(filtered_probs, 1)
+                    next_token_id = filtered_indices[sampled_id]
+
+                # At this point, next_token_id might be shape [] (scalar) or [1].
+                # We need [batch_size, 1], so if it's just a scalar, unsqueeze(0).
+                if next_token_id.dim() == 0:
+                    next_token_id = next_token_id.unsqueeze(0) # shape [1]
+                next_token_id = next_token_id.unsqueeze(-1) # shape [1,1]
+
+                # Append the new token to input_ids
+                # input_ids: [1, seq_len], next_token_id: [1,1] => final shape [1, seq_len+1]
                 input_ids = torch.cat([input_ids, next_token_id], dim=-1)
 
-                # Decode and yield the token
+                # Also update the attention mask so the model attends to the new token
+                # shape: [1, seq_len+1]
+                new_mask = attention_mask.new_ones((attention_mask.size(0), 1))
+                attention_mask = torch.cat([attention_mask, new_mask], dim=-1)
+
+                # Decode and yield the token for streaming
                 token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
                 yield token + " "
 
-                # Stop if EOS token is generated
+                # Stop if we hit the EOS token
                 if tokenizer.eos_token_id is not None:
                     if next_token_id.squeeze().item() == tokenizer.eos_token_id:
                         break
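
For reference, the nucleus-sampling step introduced in this commit can be exercised on its own. The sketch below is not part of app.py; the function name sample_top_p and the toy logits are illustrative only, but the filtering logic mirrors the loop body above (temperature scaling, cumulative top_p cutoff, greedy fallback when nothing survives the cutoff).

import torch
import torch.nn.functional as F

def sample_top_p(logits, temperature=0.7, top_p=0.9):
    # Toy reproduction of one step of the loop above for a single [vocab_size] logits vector.
    probs = F.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Keep only tokens whose cumulative probability stays within top_p.
    valid = cumulative_probs <= top_p
    filtered_probs = sorted_probs[valid]
    filtered_indices = sorted_indices[valid]

    if len(filtered_probs) == 0:
        # Greedy fallback: even the single most likely token exceeds top_p on its own.
        return torch.argmax(probs).item()

    # torch.multinomial accepts unnormalized weights, so no explicit renormalization is needed.
    sampled = torch.multinomial(filtered_probs, 1)
    return filtered_indices[sampled].item()

# Tiny demo over a fake 5-token vocabulary.
print(sample_top_p(torch.tensor([2.0, 1.0, 0.5, 0.1, -1.0])))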
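
The hunk also does not show how token_generator() is returned to the client. A common pattern for a FastAPI endpoint like async def predict(request: Request) is to wrap the generator in a StreamingResponse; the snippet below is a minimal sketch under that assumption, with placeholder prompt handling and a stand-in generator rather than the real app.py code.

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/predict")
async def predict(request: Request):
    body = await request.json()  # placeholder: how app.py reads the prompt is not shown in this hunk
    prompt = body.get("prompt", "")
    if not prompt:
        return {"error": "Prompt is required"}

    def token_generator():
        # Stand-in for the real sampling loop in the diff above.
        for token in prompt.split():
            yield token + " "

    # Stream tokens to the client as they are produced.
    return StreamingResponse(token_generator(), media_type="text/plain")

A client can then read the tokens incrementally, for example with requests.post(url, json={"prompt": "..."}, stream=True) and iter_content().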