rkihacker committed on
Commit 415ec30 · verified · 1 Parent(s): 6c8dce7

Update main.py

Files changed (1)
  1. main.py +102 -104
main.py CHANGED
@@ -9,156 +9,154 @@ from typing import List, Dict, Any, Optional, Union, Literal
  from dotenv import load_dotenv
  from sse_starlette.sse import EventSourceResponse

- # Load environment variables from .env file
  load_dotenv()
-
- # --- Configuration ---
  REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
  if not REPLICATE_API_TOKEN:
      raise ValueError("REPLICATE_API_TOKEN environment variable not set.")

- # --- FastAPI App Initialization ---
- app = FastAPI(
-     title="Replicate to OpenAI Compatibility Layer",
-     version="4.0.0 (Stable & Correct)",
- )

  # --- Pydantic Models ---
  class ModelCard(BaseModel):
      id: str; object: str = "model"; created: int = Field(default_factory=lambda: int(time.time())); owned_by: str = "replicate"
-
  class ModelList(BaseModel):
      object: str = "list"; data: List[ModelCard] = []
-
  class ChatMessage(BaseModel):
      role: Literal["system", "user", "assistant", "tool"]; content: Union[str, List[Dict[str, Any]]]
-
  class OpenAIChatCompletionRequest(BaseModel):
      model: str; messages: List[ChatMessage]; temperature: Optional[float] = 0.7; top_p: Optional[float] = 1.0; max_tokens: Optional[int] = None; stream: Optional[bool] = False

- # --- Model Mapping (Simplified for direct endpoint usage) ---
  SUPPORTED_MODELS = {
-     "llama3-8b-instruct": {
-         "id": "meta/meta-llama-3-8b-instruct",
-         "input_type": "messages"
-     },
-     "claude-4.5-haiku": {
-         "id": "anthropic/claude-4.5-haiku",
-         "input_type": "prompt"
-     }
  }

- # --- Helper Functions ---
- def prepare_replicate_input(request: OpenAIChatCompletionRequest, model_details: dict) -> Dict[str, Any]:
-     """Prepares the 'input' dictionary for Replicate, handling model-specific formats."""
-     input_payload = {}

-     if model_details["input_type"] == "prompt":
          prompt_parts = []
          system_prompt = None
          for msg in request.messages:
-             if msg.role == "system": system_prompt = str(msg.content)
-             elif msg.role == "user": prompt_parts.append(f"User: {msg.content}")
-             elif msg.role == "assistant": prompt_parts.append(f"Assistant: {msg.content}")
          prompt_parts.append("Assistant:")
-         input_payload["prompt"] = "\n".join(prompt_parts)
-         if system_prompt: input_payload["system_prompt"] = system_prompt
-     else: # "messages"
-         input_payload["messages"] = [msg.dict() for msg in request.messages]
-
-     if request.max_tokens is not None: input_payload["max_new_tokens"] = request.max_tokens
-     if request.temperature is not None: input_payload["temperature"] = request.temperature
-     if request.top_p is not None: input_payload["top_p"] = request.top_p
-     return input_payload

- async def stream_replicate_native_sse(model_id: str, input_payload: dict):
-     """Connects to Replicate's native SSE stream using the model-specific endpoint."""
-     url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
      headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}

-     # The request body is now simple and correct
-     request_body = {"input": input_payload, "stream": True}
-
-     async with httpx.AsyncClient(timeout=300) as client:
-         prediction = None
          try:
-             response = await client.post(url, headers=headers, json=request_body)
              response.raise_for_status()
              prediction = response.json()
              stream_url = prediction.get("urls", {}).get("stream")

              if not stream_url:
-                 error_detail = prediction.get("detail", "Failed to get stream URL.")
-                 yield json.dumps({"error": {"message": error_detail}})
-                 return
          except httpx.HTTPStatusError as e:
-             try: yield json.dumps({"error": {"message": json.dumps(e.response.json())}})
-             except: yield json.dumps({"error": {"message": e.response.text}})
-             return
-
-         try:
-             async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}) as sse:
-                 sse.raise_for_status()
-                 current_event = ""
-                 async for line in sse.aiter_lines():
-                     if line.startswith("event:"):
-                         current_event = line[len("event:"):].strip()
-                     elif line.startswith("data:"):
-                         data = line[len("data:"):].strip()
-                         if current_event == "output":
                              try:
                                  content = json.loads(data)
                                  chunk = {
-                                     "id": prediction["id"], "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
                                      "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}]
                                  }
                                  yield json.dumps(chunk)
-                             except json.JSONDecodeError:
-                                 # Silently ignore malformed or empty data lines
-                                 pass
-                         elif current_event == "done":
-                             break
-         except Exception as e:
-             yield json.dumps({"error": {"message": f"Streaming error: {str(e)}"}})
-
-         done_chunk = {
-             "id": prediction["id"] if prediction else "unknown", "object": "chat.completion.chunk", "created": int(time.time()), "model": model_id,
-             "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
-         }
-         yield json.dumps(done_chunk)
          yield "[DONE]"

- # --- API Endpoints ---
- @app.get("/v1/models", response_model=ModelList)
  async def list_models():
-     return ModelList(data=[ModelCard(id=model_name) for model_name in SUPPORTED_MODELS.keys()])

  @app.post("/v1/chat/completions")
  async def create_chat_completion(request: OpenAIChatCompletionRequest):
-     model_key = request.model
-     if model_key not in SUPPORTED_MODELS:
-         raise HTTPException(status_code=404, detail=f"Model not found. Supported models: {list(SUPPORTED_MODELS.keys())}")

-     model_details = SUPPORTED_MODELS[model_key]
-     replicate_input = prepare_replicate_input(request, model_details)

      if request.stream:
-         return EventSourceResponse(stream_replicate_native_sse(model_details["id"], replicate_input))
-
-     # Synchronous Request
-     url = f"https://api.replicate.com/v1/models/{model_details['id']}/predictions"
-     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"}
-
-     async with httpx.AsyncClient(timeout=150) as client:
-         try:
-             response = await client.post(url, headers=headers, json={"input": replicate_input})
-             response.raise_for_status()
-             prediction = response.json()
-             output = "".join(prediction.get("output", []))
-             return JSONResponse(content={
-                 "id": prediction["id"], "object": "chat.completion", "created": int(time.time()), "model": model_key,
-                 "choices": [{"index": 0, "message": {"role": "assistant", "content": output}, "finish_reason": "stop"}],
-                 "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
-             })
-         except httpx.HTTPStatusError as e:
-             raise HTTPException(status_code=e.response.status_code, detail=e.response.text)
  from dotenv import load_dotenv
  from sse_starlette.sse import EventSourceResponse

+ # Load environment variables
  load_dotenv()
  REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
  if not REPLICATE_API_TOKEN:
      raise ValueError("REPLICATE_API_TOKEN environment variable not set.")

+ # FastAPI Init
+ app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="4.0.0 (Docs Compliant)")

  # --- Pydantic Models ---
  class ModelCard(BaseModel):
      id: str; object: str = "model"; created: int = Field(default_factory=lambda: int(time.time())); owned_by: str = "replicate"
  class ModelList(BaseModel):
      object: str = "list"; data: List[ModelCard] = []
  class ChatMessage(BaseModel):
      role: Literal["system", "user", "assistant", "tool"]; content: Union[str, List[Dict[str, Any]]]
  class OpenAIChatCompletionRequest(BaseModel):
      model: str; messages: List[ChatMessage]; temperature: Optional[float] = 0.7; top_p: Optional[float] = 1.0; max_tokens: Optional[int] = None; stream: Optional[bool] = False

+ # --- Supported Models ---
+ # Maps OpenAI-friendly names to Replicate model paths
  SUPPORTED_MODELS = {
+     "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
+     "claude-4.5-haiku": "anthropic/claude-4.5-haiku"
  }

+ # --- Core Logic ---
+ def prepare_replicate_input(request: OpenAIChatCompletionRequest, replicate_model_id: str) -> Dict[str, Any]:
+     """Formats the input specifically for the requested Replicate model."""
+     payload = {}

+     # Claude on Replicate strictly requires a 'prompt' string, not 'messages' array.
+     if "anthropic/claude" in replicate_model_id:
          prompt_parts = []
          system_prompt = None
          for msg in request.messages:
+             if msg.role == "system":
+                 # Extract system prompt if present
+                 system_prompt = str(msg.content)
+             elif msg.role == "user":
+                 # Handle both simple string content and list content (for potential future vision support)
+                 content = msg.content
+                 if isinstance(content, list):
+                     text_parts = [item.get("text", "") for item in content if item.get("type") == "text"]
+                     content = " ".join(text_parts)
+                 prompt_parts.append(f"User: {content}")
+             elif msg.role == "assistant":
+                 prompt_parts.append(f"Assistant: {msg.content}")
+
+         # Standard Claude prompting convention
          prompt_parts.append("Assistant:")
+         payload["prompt"] = "\n\n".join(prompt_parts)
+         if system_prompt:
+             payload["system_prompt"] = system_prompt
+
+     # Llama 3 and others often support the 'messages' array natively.
+     else:
+         # Convert Pydantic models to pure dicts
+ payload["prompt"] = [msg.dict() for msg in request.messages]
+
+     # Map common OpenAI parameters to Replicate equivalents
+     if request.max_tokens: payload["max_new_tokens"] = request.max_tokens
+     if request.temperature: payload["temperature"] = request.temperature
+     if request.top_p: payload["top_p"] = request.top_p
+
+     return payload

+ async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
+     """Handles the full streaming lifecycle using standard Replicate endpoints."""
+     # 1. Start Prediction specifically at the named model endpoint
+     url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
      headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}

+     async with httpx.AsyncClient(timeout=60.0) as client:
          try:
+             # Explicitly request stream=True in the body, though often implicit
+             response = await client.post(url, headers=headers, json={"input": input_payload, "stream": True})
              response.raise_for_status()
              prediction = response.json()
              stream_url = prediction.get("urls", {}).get("stream")
+             prediction_id = prediction.get("id")

              if not stream_url:
+                 yield json.dumps({"error": {"message": "Model did not return a stream URL."}})
+                 return
+
          except httpx.HTTPStatusError as e:
+             yield json.dumps({"error": {"message": e.response.text, "type": "upstream_error"}})
+             return
+
+         # 2. Connect to the provided Stream URL
+         async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
+             current_event = None
+             async for line in sse.aiter_lines():
+                 if line.startswith("event:"):
+                     current_event = line[len("event:"):].strip()
+                 elif line.startswith("data:"):
+                     data = line[len("data:"):].strip()
+
+                     if current_event == "output":
+                         # CRITICAL: Wrap in try/except to ignore empty keep-alive lines that crash standard parsers
+                         try:
+                             # Replicate sometimes sends raw strings, sometimes JSON.
+                             # For chat models, it's usually a raw string token.
+                             # We try to load as JSON first, if it fails, use raw data.
                              try:
                                  content = json.loads(data)
+                             except json.JSONDecodeError:
+                                 content = data
+
+                             if content: # Ensure we don't send empty chunks
                                  chunk = {
+                                     "id": prediction_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": replicate_model_id,
                                      "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}]
                                  }
                                  yield json.dumps(chunk)
+                         except Exception:
+                             pass # Safely ignore malformed lines
+
+                     elif current_event == "done":
+                         break
+
+         # 3. Send final [DONE] event
+         yield json.dumps({"id": prediction_id, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]})
          yield "[DONE]"

+ # --- Endpoints ---
+ @app.get("/v1/models")
  async def list_models():
+     return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()])

  @app.post("/v1/chat/completions")
  async def create_chat_completion(request: OpenAIChatCompletionRequest):
+     if request.model not in SUPPORTED_MODELS:
+         raise HTTPException(404, f"Model not found. Available: {list(SUPPORTED_MODELS.keys())}")

+     replicate_id = SUPPORTED_MODELS[request.model]
+     replicate_input = prepare_replicate_input(request, replicate_id)

      if request.stream:
+         return EventSourceResponse(stream_replicate_sse(replicate_id, replicate_input))
+
+     # Non-streaming fallback
+     url = f"https://api.replicate.com/v1/models/{replicate_id}/predictions"
+     headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=60"}
+     async with httpx.AsyncClient(timeout=90.0) as client:  # allow the blocking "Prefer: wait=60" call to finish
+         resp = await client.post(url, headers=headers, json={"input": replicate_input})
+         if resp.is_error: raise HTTPException(resp.status_code, resp.text)
+         pred = resp.json()
+         output = "".join(pred.get("output", []))
+         return {"id": pred["id"], "choices": [{"message": {"role": "assistant", "content": output}, "finish_reason": "stop"}]}