DataChem committed (verified)
Commit 8194424
1 Parent(s): 74b564f

Update app.py

Files changed (1)
app.py +32 -16
app.py CHANGED
@@ -10,6 +10,7 @@ app = FastAPI()
 model_name = "EleutherAI/gpt-neo-1.3B" # Replace with your desired model
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name)
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
@@ -20,45 +21,60 @@ async def predict(request: Request):
     if not prompt:
         return {"error": "Prompt is required"}
 
-    # Tokenize the input
+    # Tokenize the input and move to correct device
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
     input_ids = inputs.input_ids
     attention_mask = inputs.attention_mask
 
     def token_generator():
+        # Use nonlocal to allow reassigning input_ids inside the nested function
+        nonlocal input_ids
+
+        # Sampling parameters
        temperature = 0.7
        top_p = 0.9
+        max_new_tokens = 30
 
-        for _ in range(100): # Limit to 100 tokens
-            with torch.no_grad(): # Disable gradient computation for inference
+        for _ in range(max_new_tokens):
+            with torch.no_grad():
+                # Forward pass
                 outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                 next_token_logits = outputs.logits[:, -1, :]
 
-                # Apply temperature and softmax
+                # Temperature scaling
                 next_token_logits = next_token_logits / temperature
+
+                # Convert logits to probabilities
                 next_token_probs = F.softmax(next_token_logits, dim=-1)
 
-                # Apply nucleus sampling (top-p)
+                # Apply nucleus (top-p) sampling
                 sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
                 cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
-                sorted_probs = sorted_probs[cumulative_probs <= top_p]
-                sorted_indices = sorted_indices[:len(sorted_probs)]
 
-                # Sample next token
-                if len(sorted_probs) > 0:
-                    next_token_id = sorted_indices[torch.multinomial(sorted_probs, 1)]
+                # Filter out tokens above the top_p threshold
+                valid_indices = cumulative_probs <= top_p
+                filtered_probs = sorted_probs[valid_indices]
+                filtered_indices = sorted_indices[valid_indices]
+
+                if len(filtered_probs) == 0:
+                    # Fallback to greedy if no tokens meet top_p
+                    next_token_id = torch.argmax(next_token_probs).unsqueeze(-1)
                 else:
-                    next_token_id = torch.argmax(next_token_probs)
+                    # Sample from the filtered distribution
+                    sampled_id = torch.multinomial(filtered_probs, num_samples=1)
+                    next_token_id = filtered_indices[sampled_id].unsqueeze(-1)
 
-                # Append the new token to the input sequence
-                input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
+            # Append the new token to our running sequence
+            input_ids = torch.cat([input_ids, next_token_id], dim=-1)
 
             # Decode and yield the token
             token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
             yield token + " "
 
-            # Stop if the end-of-sequence token is generated
-            if next_token_id.squeeze().item() == tokenizer.eos_token_id:
-                break
+            # Stop if EOS token is generated
+            if tokenizer.eos_token_id is not None:
+                if next_token_id.squeeze().item() == tokenizer.eos_token_id:
+                    break
 
+    # Return the streaming response
     return StreamingResponse(token_generator(), media_type="text/plain")
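
The generator above re-runs the full, growing sequence through the model on every step, so whatever is passed as attention_mask needs to stay the same length as input_ids. Below is a minimal sketch of one decode step that extends both tensors together, reusing the variable names from the commit; the decode_step helper itself is illustrative and not part of app.py, and it keeps the most likely token instead of using a separate greedy fallback.

import torch
import torch.nn.functional as F

def decode_step(model, input_ids, attention_mask, temperature=0.7, top_p=0.9):
    """Illustrative single sampling step; returns the extended tensors and the new token id."""
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[:, -1, :]
        probs = F.softmax(logits / temperature, dim=-1)

        # Nucleus (top-p) filtering
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        keep = cumulative_probs <= top_p
        keep[..., 0] = True  # always keep the most likely token
        filtered_probs = sorted_probs[keep]
        filtered_indices = sorted_indices[keep]

        # torch.multinomial accepts unnormalized non-negative weights
        sampled = torch.multinomial(filtered_probs, num_samples=1)
        next_token_id = filtered_indices[sampled].view(1, 1)

    # Grow the sequence and its mask together so their lengths always match
    input_ids = torch.cat([input_ids, next_token_id], dim=-1)
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token_id)], dim=-1)
    return input_ids, attention_mask, next_token_id

Inside token_generator, the loop body would then amount to calling decode_step once per iteration and yielding the decoded token.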
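
Because the handler returns a StreamingResponse with media_type="text/plain", a caller has to read the body incrementally to see tokens as they are generated. A small client sketch, assuming the handler is mounted at /predict and the request body is JSON with a "prompt" field (the URL, port, route, and field name are all assumptions, since the route decorator and request parsing fall outside this hunk):

import requests

# Hypothetical endpoint; replace with the actual Space URL and route.
url = "http://localhost:7860/predict"

with requests.post(url, json={"prompt": "Explain nucleus sampling in one sentence."}, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)  # each chunk is one or more space-separated tokens
print()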