DataChem committed on
Commit
45123df
·
verified ·
1 Parent(s): f5371d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -1,5 +1,6 @@
1
  from fastapi import FastAPI, Request
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
3
  import torch
4
 
5
  app = FastAPI()
@@ -23,8 +24,17 @@ async def predict(request: Request):
23
  # Tokenize the input
24
  inputs = tokenizer(prompt, return_tensors="pt").to("cpu") # Use "cuda" if GPU is enabled
25
 
26
- # Generate tokens
27
- outputs = model.generate(inputs.input_ids, max_length=40, num_return_sequences=1)
28
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
-
30
- return {"response": response}
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, Request
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from fastapi.responses import StreamingResponse
4
  import torch
5
 
6
  app = FastAPI()
 
24
  # Tokenize the input
25
  inputs = tokenizer(prompt, return_tensors="pt").to("cpu") # Use "cuda" if GPU is enabled
26
 
27
+ # Generator function to stream tokens
28
+ def token_generator():
29
+ outputs = model.generate(
30
+ inputs.input_ids,
31
+ max_length=40,
32
+ do_sample=True,
33
+ num_return_sequences=1
34
+ )
35
+ for token_id in outputs[0]:
36
+ token = tokenizer.decode(token_id, skip_special_tokens=True)
37
+ yield f"{token} "
38
+
39
+ # Return StreamingResponse
40
+ return StreamingResponse(token_generator(), media_type="text/plain")