Marroco93 committed on
Commit 215f4a9
Parent(s): 71badbc
Files changed (1)
  1. main.py +23 -4
main.py CHANGED
@@ -3,6 +3,7 @@ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from huggingface_hub import InferenceClient
 import uvicorn
+import json  # Make sure to import json
 
 
 app = FastAPI()
@@ -26,6 +27,7 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
+
 def generate(item: Item):
     temperature = float(item.temperature)
     if temperature < 1e-2:
@@ -43,11 +45,28 @@ def generate(item: Item):
 
     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-
-    for response in stream:
-        yield response.token.text.encode("utf-8")
+
+    # Initialize a variable to track whether this is the last item
+    is_last = False
+
+    # Since we're yielding JSON, each chunk must be a complete JSON object.
+    # We'll iterate over the stream and yield each response as a JSON string.
+    for i, response in enumerate(stream):
+        # Check if this is the last item by attempting to peek ahead
+        is_last = True  # Assume it's the last unless proven otherwise in the next iteration
+
+        # Construct the chunk of data to include the text and completion status
+        chunk_data = {
+            "text": response.token.text,
+            "complete": is_last
+        }
+
+        # Yield this chunk as a JSON-encoded string followed by a newline to separate chunks
+        yield json.dumps(chunk_data) + "\n"
 
 @app.post("/generate/")
 async def generate_text(item: Item):
-    return StreamingResponse(generate(item), media_type="text/plain")
+    # Note the change to media_type to indicate we're streaming JSON
+    return StreamingResponse(generate(item), media_type="application/x-ndjson")
+
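After this change the endpoint streams newline-delimited JSON rather than raw text, so a caller reads the body line by line and decodes each line into an object with "text" and "complete" fields. A minimal consumer sketch, assuming the app is served locally on port 8000 and using the requests package; the request payload fields are guesses based on the Item attributes referenced in generate() and are not part of this commit:

import json
import requests  # assumed HTTP client; not part of the commit

# Hypothetical request body -- field names mirror the Item attributes used in generate().
payload = {
    "prompt": "Write a haiku about the sea.",
    "system_prompt": "You are a helpful assistant.",
    "history": [],
    "temperature": 0.7,
    # Item may define further generation fields that are not visible in this diff.
}

# Stream the response and decode one JSON object per line (NDJSON).
with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue  # skip empty keep-alive lines
        chunk = json.loads(line)
        print(chunk["text"], end="", flush=True)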