DataChem committed · verified
Commit 74b564f · 1 Parent(s): 5102dda

Update app.py

Files changed (1): app.py (+25 -34)
app.py CHANGED
@@ -13,10 +13,6 @@ model = AutoModelForCausalLM.from_pretrained(model_name)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
-@app.get("/")
-def read_root():
-    return {"Hello": "World"}
-
 @app.post("/predict")
 async def predict(request: Request):
     data = await request.json()
@@ -29,45 +25,40 @@ async def predict(request: Request):
     input_ids = inputs.input_ids
     attention_mask = inputs.attention_mask
 
-    # Generator function to stream tokens
     def token_generator():
         temperature = 0.7
         top_p = 0.9
 
-        for _ in range(100):  # Limit the number of generated tokens
-            # Get the model outputs
-            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-            next_token_logits = outputs.logits[:, -1, :]  # Logits for the last token
-
-            # Apply temperature scaling
-            next_token_logits = next_token_logits / temperature
+        for _ in range(100):  # Limit to 100 tokens
+            with torch.no_grad():  # Disable gradient computation for inference
+                outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+                next_token_logits = outputs.logits[:, -1, :]
 
-            # Convert logits to probabilities
-            next_token_probs = F.softmax(next_token_logits, dim=-1)
+            # Apply temperature and softmax
+            next_token_logits = next_token_logits / temperature
+            next_token_probs = F.softmax(next_token_logits, dim=-1)
 
-            # Apply top-p nucleus sampling
-            sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
-            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
-            sorted_probs = sorted_probs[cumulative_probs <= top_p]
-            sorted_indices = sorted_indices[:len(sorted_probs)]
+            # Apply nucleus sampling (top-p)
+            sorted_probs, sorted_indices = torch.sort(next_token_probs, descending=True)
+            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+            sorted_probs = sorted_probs[cumulative_probs <= top_p]
+            sorted_indices = sorted_indices[:len(sorted_probs)]
 
-            # Sample from the filtered distribution
-            if len(sorted_probs) > 0:
-                next_token_id = sorted_indices[torch.multinomial(sorted_probs, 1)]
-            else:
-                # Fallback to greedy selection if no tokens meet top-p
-                next_token_id = torch.argmax(next_token_probs)
+            # Sample next token
+            if len(sorted_probs) > 0:
+                next_token_id = sorted_indices[torch.multinomial(sorted_probs, 1)]
+            else:
+                next_token_id = torch.argmax(next_token_probs)
 
-            # Append the generated token to the input
-            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
+            # Append the new token to the input sequence
+            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1)], dim=-1)
 
-            # Decode the token and yield it
-            token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
-            yield token + " "
+            # Decode and yield the token
+            token = tokenizer.decode(next_token_id.squeeze(), skip_special_tokens=True)
+            yield token + " "
 
-            # Stop if the model generates the end-of-sequence token
-            if next_token_id.squeeze().item() == tokenizer.eos_token_id:
-                break
+            # Stop if the end-of-sequence token is generated
+            if next_token_id.squeeze().item() == tokenizer.eos_token_id:
+                break
 
-    # Return the generator as a streaming response
     return StreamingResponse(token_generator(), media_type="text/plain")
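
For reference, a client can consume the streamed tokens roughly as sketched below. This is a minimal sketch, not part of the commit: the server URL and the "prompt" field in the JSON body are assumptions, since the request-parsing and tokenization lines of app.py fall outside this hunk.

import requests

# Hypothetical payload; the actual JSON schema is defined earlier in app.py.
payload = {"prompt": "Explain nucleus sampling in one sentence."}

# stream=True lets requests hand back the plain-text chunks as the server yields them.
with requests.post("http://localhost:8000/predict", json=payload, stream=True) as response:
    response.raise_for_status()
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)

Because the endpoint returns a StreamingResponse with media_type="text/plain" and yields space-separated tokens, printing each chunk as it arrives reproduces the generation incrementally instead of waiting for the full completion.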