rkihacker committed
Commit 2767573 · verified · 1 Parent(s): dafbe9c

Update main.py

Files changed (1)
  1. main.py +37 -18
main.py CHANGED
@@ -17,7 +17,7 @@ if not REPLICATE_API_TOKEN:
     raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
 
 # FastAPI Init
-app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.0.0 (Definitive Streaming Fix)")
+app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.1.0 (Enhanced Token Tracking)")
 
 # --- Pydantic Models ---
 class ModelCard(BaseModel):
@@ -81,10 +81,14 @@ def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
     return payload
 
 async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
-    """Handles the full streaming lifecycle with correct whitespace preservation."""
+    """Handles the full streaming lifecycle with enhanced token tracking and timing."""
     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}
 
+    start_time = time.time()
+    prompt_tokens = len(input_payload.get("prompt", "")) // 4  # Rough estimation
+    completion_tokens = 0
+
     async with httpx.AsyncClient(timeout=60.0) as client:
         try:
             response = await client.post(url, headers=headers, json={"input": input_payload, "stream": True})
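The `len(...) // 4` heuristic above approximates one token per four characters, which is roughly right for English prose but drifts for code and non-Latin scripts. A minimal sketch of a counter that prefers a real tokenizer when one is available; the optional `tiktoken` dependency, the `cl100k_base` encoding choice, and the `estimate_tokens` name are assumptions, not part of this commit:

```python
# Hypothetical helper, not part of this commit: count tokens with
# tiktoken when installed, falling back to the chars-per-4 heuristic.
def estimate_tokens(text: str) -> int:
    try:
        import tiktoken  # assumed optional dependency
        enc = tiktoken.get_encoding("cl100k_base")  # assumed encoding choice
        return len(enc.encode(text))
    except ImportError:
        # Same rough heuristic the handler uses: ~4 characters per token.
        return len(text) // 4
```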
@@ -113,11 +117,8 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
                 if line.startswith("event:"):
                     current_event = line[len("event:"):].strip()
                 elif line.startswith("data:"):
-                    # FIXED: Preserve all whitespace including leading/trailing spaces
-                    raw_data = line[5:]  # Remove "data:" prefix
-
-                    # Remove only the optional single space after data: if present
-                    # This is per SSE spec and preserves actual content spaces
+                    # Remove "data:" prefix and optional space
+                    raw_data = line[5:]  # Remove "data:"
                     if raw_data.startswith(" "):
                         data_content = raw_data[1:]  # Remove the first space only
                     else:
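Per the SSE specification, at most one space after the `data:` field name is ignored; any further whitespace belongs to the payload. The branch above implements exactly that rule. A small standalone check of the behavior (the `parse_data_field` name is illustrative only):

```python
# Illustrative only: the single-optional-space rule used above.
def parse_data_field(line: str) -> str:
    raw = line[len("data:"):]
    return raw[1:] if raw.startswith(" ") else raw

assert parse_data_field("data: hello") == "hello"    # one space eaten
assert parse_data_field("data:  hello") == " hello"  # extra space kept
assert parse_data_field("data:hello") == "hello"     # no space needed
```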
@@ -129,13 +130,13 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
 
                     content_token = ""
                     try:
-                        # Handle JSON-encoded strings properly (including spaces)
+                        # Handle JSON-encoded strings properly
                         content_token = json.loads(data_content)
                     except (json.JSONDecodeError, TypeError):
-                        # Handle plain text tokens (preserve as-is)
+                        # Handle plain text tokens
                         content_token = data_content
 
-                    # Create chunk with exact format you specified
+                    completion_tokens += 1
                     chunk = {
                         "choices": [{
                             "delta": {"content": content_token},
@@ -145,15 +146,18 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
                             "native_finish_reason": None
                         }],
                         "created": int(time.time()),
-                        "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",  # Format like your example
+                        "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",
                         "model": replicate_model_id,
                         "object": "chat.completion.chunk",
                         "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate"
                     }
-                    # FIXED: Yield only the JSON data, let EventSourceResponse handle the SSE formatting
                     yield json.dumps(chunk)
 
                 elif current_event == "done":
+                    # Calculate timing
+                    end_time = time.time()
+                    inference_time = end_time - start_time
+
                     # Send usage chunk before done
                     usage_chunk = {
                         "choices": [{
@@ -170,7 +174,7 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
                         "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate",
                         "usage": {
                             "cache_discount": 0,
-                            "completion_tokens": 0,
+                            "completion_tokens": completion_tokens,
                             "completion_tokens_details": {"image_tokens": 0, "reasoning_tokens": 0},
                             "cost": 0,
                             "cost_details": {
@@ -178,11 +182,12 @@ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
                                 "upstream_inference_cost": None,
                                 "upstream_inference_prompt_cost": 0
                             },
-                            "input_tokens": 0,
+                            "input_tokens": prompt_tokens,
                             "is_byok": False,
-                            "prompt_tokens": 0,
+                            "prompt_tokens": prompt_tokens,
                             "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
-                            "total_tokens": 0
+                            "total_tokens": prompt_tokens + completion_tokens,
+                            "inference_time": round(inference_time, 3)
                         }
                     }
                     yield json.dumps(usage_chunk)
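The generator yields bare JSON strings because `sse-starlette`'s `EventSourceResponse` adds the `data: ` prefix and blank-line terminator itself. A minimal sketch of that division of labor, independent of this commit (the `demo` app and route are hypothetical):

```python
# Minimal sketch: EventSourceResponse frames each yielded string as an
# SSE "data:" event; the generator never writes the wire format itself.
from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

demo = FastAPI()  # hypothetical demo app

@demo.get("/demo-stream")
async def demo_stream():
    async def gen():
        yield '{"n": 1}'  # sent on the wire as: data: {"n": 1}\n\n
        yield '{"n": 2}'
    return EventSourceResponse(gen())
```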
@@ -226,19 +231,33 @@ async def create_chat_completion(request: OpenAIChatCompletionRequest):
     if request.stream:
         return EventSourceResponse(stream_replicate_sse(SUPPORTED_MODELS[request.model], replicate_input), media_type="text/event-stream")
 
-    # Non-streaming fallback
+    # Non-streaming fallback with usage data
     url = f"https://api.replicate.com/v1/models/{SUPPORTED_MODELS[request.model]}/predictions"
     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
+    start_time = time.time()
+
     async with httpx.AsyncClient() as client:
         try:
             resp = await client.post(url, headers=headers, json={"input": replicate_input}, timeout=130.0)
             resp.raise_for_status()
             pred = resp.json()
             output = "".join(pred.get("output", []))
+
+            # Calculate timing and tokens
+            end_time = time.time()
+            inference_time = end_time - start_time
+            prompt_tokens = len(replicate_input.get("prompt", "")) // 4  # Rough estimation
+            completion_tokens = len(output) // 4  # Rough estimation
+
             return {
                 "id": pred.get("id"), "object": "chat.completion", "created": int(time.time()), "model": request.model,
                 "choices": [{"index": 0, "message": {"role": "assistant", "content": output}, "finish_reason": "stop"}],
-                "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": prompt_tokens + completion_tokens,
+                    "inference_time": round(inference_time, 3)
+                }
             }
         except httpx.HTTPStatusError as e:
             raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}")
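For the non-streaming path, the `Prefer: wait=120` header asks Replicate to hold the HTTP request open until the prediction finishes (or the window lapses), so no polling loop is needed. Calling the fallback from a client might look like this sketch; the URL, route, and model key are deployment assumptions:

```python
# Sketch: hitting the non-streaming fallback with httpx.
# URL, route, and model key are deployment assumptions.
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "claude-3-5-sonnet",  # hypothetical SUPPORTED_MODELS key
        "messages": [{"role": "user", "content": "Say hi"}],
        "stream": False,
    },
    timeout=130.0,
)
resp.raise_for_status()
body = resp.json()
print(body["choices"][0]["message"]["content"])
print(body["usage"])  # estimated token counts plus inference_time
```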
 