Marroco93 committed on
Commit
d0c61b6
1 Parent(s): 215f4a9
Files changed (1)
  1. main.py +11 -17
main.py CHANGED
@@ -27,6 +27,7 @@ def format_prompt(message, history):
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
+import json  # Import the JSON module
 
 def generate(item: Item):
     temperature = float(item.temperature)
@@ -45,28 +46,21 @@ def generate(item: Item):
 
     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-
-    # Initialize a variable to track whether this is the last item
-    is_last = False
-
-    # Since we're yielding JSON, each chunk must be a complete JSON object.
-    # We'll iterate over the stream and yield each response as a JSON string.
-    for i, response in enumerate(stream):
-        # Check if this is the last item by attempting to peek ahead
-        is_last = True  # Assume it's the last unless proven otherwise in the next iteration
-
-        # Construct the chunk of data to include the text and completion status
-        chunk_data = {
+
+    # Convert stream to a list to check if it's the last element
+    responses = list(stream)
+    for i, response in enumerate(responses):
+        # Prepare the chunk as a JSON object
+        chunk = {
             "text": response.token.text,
-            "complete": is_last
+            "complete": i == len(responses) - 1  # True if this is the last chunk
         }
-
-        # Yield this chunk as a JSON-encoded string followed by a newline to separate chunks
-        yield json.dumps(chunk_data) + "\n"
+        # Yield the JSON-encoded string with a newline to separate chunks
+        yield json.dumps(chunk).encode("utf-8") + b"\n"
 
 @app.post("/generate/")
 async def generate_text(item: Item):
-    # Note the change to media_type to indicate we're streaming JSON
     return StreamingResponse(generate(item), media_type="application/x-ndjson")
 
 
+
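
For reference, a minimal sketch of how a client might consume the resulting application/x-ndjson stream produced by the /generate/ endpoint. This is illustrative only and not part of the commit: the host and port, the use of the requests library, and the payload values are assumptions, and the Item model may require fields beyond the ones referenced in generate().

import json

import requests  # assumed HTTP client; any client that can stream a response body works

payload = {
    "prompt": "Tell me a short story.",          # illustrative values; field names follow the
    "system_prompt": "You are a helpful model.",  # Item attributes used inside generate()
    "history": [],
    "temperature": 0.7,
}

# stream=True keeps the connection open so NDJSON lines can be read as they arrive
with requests.post("http://localhost:8000/generate/", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if not line:  # skip keep-alive blank lines
            continue
        chunk = json.loads(line)  # each line is one complete JSON object
        print(chunk["text"], end="", flush=True)
        if chunk["complete"]:  # flag set on the last chunk by generate()
            break

Because the patched generate() materializes the token stream with list(stream) before yielding, the lines read by this loop will typically arrive together once generation has finished rather than token by token.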