rkihacker committed on
Commit
89b138d
·
verified ·
1 Parent(s): 1f3c557

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +279 -0
main.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import httpx
3
+ import json
4
+ import time
5
+ from fastapi import FastAPI, Request, HTTPException, Header
6
+ from fastapi.responses import JSONResponse
7
+ from pydantic import BaseModel, Field
8
+ from typing import List, Dict, Any, Optional, Union, Literal
9
+ from dotenv import load_dotenv
10
+ from sse_starlette.sse import EventSourceResponse
11
+
12
# Load environment variables from .env file
load_dotenv()

# --- Configuration ---
# The Replicate API token is mandatory: fail fast at import time so the
# server never starts without credentials (empty string also fails here).
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    raise ValueError("REPLICATE_API_TOKEN environment variable not set.")
19
+
20
# --- FastAPI App Initialization ---
# Single ASGI app exposing OpenAI-compatible /v1/models and
# /v1/chat/completions endpoints backed by Replicate.
app = FastAPI(
    title="Replicate to OpenAI Compatibility Layer",
    version="1.0.0",
)
25
+
26
# --- Pydantic Models for OpenAI Compatibility ---

# /v1/models endpoint
class ModelCard(BaseModel):
    """One entry of the OpenAI-style model list."""
    # Client-facing model identifier (a key of SUPPORTED_MODELS).
    id: str
    object: str = "model"
    # Unix timestamp; defaults to "now" because no real creation date is available.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "replicate"
34
+
35
class ModelList(BaseModel):
    """OpenAI-style list envelope returned by GET /v1/models."""
    object: str = "list"
    # The exposed model cards; empty by default.
    data: List[ModelCard] = []
38
+
39
# /v1/chat/completions endpoint
class ChatMessage(BaseModel):
    """A single chat turn in the OpenAI request schema."""
    role: Literal["system", "user", "assistant", "tool"]
    # Either plain text, or a list of OpenAI content parts
    # (e.g. {"type": "text", ...} / {"type": "image_url", ...} for vision).
    content: Union[str, List[Dict[str, Any]]]
43
+
44
class ToolFunction(BaseModel):
    """Function description inside an OpenAI `tools` entry."""
    name: str
    description: str
    # JSON-schema describing the function's arguments.
    parameters: Dict[str, Any]
48
+
49
class Tool(BaseModel):
    """One OpenAI tool definition; only function tools are supported."""
    type: Literal["function"]
    function: ToolFunction
52
+
53
class OpenAIChatCompletionRequest(BaseModel):
    """Subset of the OpenAI chat-completions request schema handled here."""
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    # When true the endpoint returns an SSE stream of chunks.
    stream: Optional[bool] = False
    tools: Optional[List[Tool]] = None
    # NOTE(review): tool_choice is accepted but never read anywhere in this
    # file — confirm whether it should influence prompting.
    tool_choice: Optional[Union[str, Dict]] = None
62
+
63
# --- Replicate Model Mapping ---
# We hardcode the models we want to expose.
# Maps the OpenAI-facing model name (key) to the Replicate model id (value)
# used to build the predictions URL.
SUPPORTED_MODELS = {
    "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
    "claude-4.5-haiku": "anthropic/claude-4.5-haiku"
}
69
+
70
+
71
# --- Helper Functions ---

def format_tools_for_prompt(tools: List[Tool]) -> str:
    """Render the OpenAI tool definitions as plain-text system-prompt
    instructions, since Replicate models have no native tool parameter."""
    if not tools:
        return ""

    header = (
        "You have access to the following tools. To use a tool, respond with a JSON object in the following format:\n"
        '{"type": "tool_call", "name": "tool_name", "arguments": {"arg_name": "value"}}\n\n'
        "Available tools:\n"
    )
    # One pretty-printed JSON spec per tool, each followed by a newline.
    specs = [json.dumps(entry.function.dict(), indent=2) + "\n" for entry in tools]
    return header + "".join(specs)
84
+
85
def prepare_replicate_input(request: "OpenAIChatCompletionRequest") -> Dict[str, Any]:
    """Translate an OpenAI-style chat request into a Replicate input payload.

    Flattens the message list into a single prompt string, collects an
    optional system prompt and (for vision models) the last image URL,
    and maps the sampling parameters.

    Args:
        request: The parsed OpenAI-compatible chat completion request.

    Returns:
        A dict suitable for the "input" field of a Replicate prediction.
    """
    input_data: Dict[str, Any] = {}
    prompt_parts: List[str] = []
    system_prompt = ""
    image_url: Optional[str] = None

    # Handle messages, separating system, user, assistant and vision content.
    for message in request.messages:
        if message.role == "system":
            # FIX: OpenAI allows list-of-parts content on any message; the
            # old code did `str + list` and raised TypeError for such input.
            if isinstance(message.content, list):
                for item in message.content:
                    if item.get("type") == "text":
                        system_prompt += item.get("text", "") + "\n"
            else:
                system_prompt += message.content + "\n"
        elif message.role == "user":
            if isinstance(message.content, list):  # Vision support
                for item in message.content:
                    if item.get("type") == "text":
                        prompt_parts.append(f"User: {item.get('text', '')}")
                    elif item.get("type") == "image_url":
                        # Only the last image in the conversation is kept.
                        image_url = item.get("image_url", {}).get("url")
            else:
                prompt_parts.append(f"User: {message.content}")
        elif message.role == "assistant":
            prompt_parts.append(f"Assistant: {message.content}")
        # NOTE(review): "tool" role messages are silently dropped — confirm
        # this is intended before relying on multi-turn tool results.

    # Advertise the available tools through the system prompt.
    if request.tools:
        system_prompt += "\n" + format_tools_for_prompt(request.tools)

    input_data["prompt"] = "\n".join(prompt_parts)
    if system_prompt:
        input_data["system_prompt"] = system_prompt
    if image_url:
        input_data["image"] = image_url

    # Map other parameters.
    if request.temperature is not None:
        input_data["temperature"] = request.temperature
    if request.top_p is not None:
        input_data["top_p"] = request.top_p
    if request.max_tokens is not None:
        # Replicate uses `max_new_tokens` or `max_tokens` depending on model.
        input_data["max_new_tokens"] = request.max_tokens

    return input_data
129
+
130
+
131
async def stream_replicate_response(model_id: str, payload: dict):
    """Async generator yielding OpenAI-style chat.completion.chunk events.

    Creates a streaming prediction on Replicate, relays its SSE stream, and
    re-shapes each token into an OpenAI chunk.

    FIX: values yielded here are plain strings — the caller wraps this
    generator in sse_starlette's EventSourceResponse, which adds the
    `data: ` framing itself. The old code pre-wrapped every event as
    "data: ...\\n\\n", producing malformed double-framed `data: data: ...`
    output on the wire.
    """
    url = f"https://api.replicate.com/v1/models/{model_id}/predictions"
    headers = {
        "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
        "Content-Type": "application/json",
    }

    async with httpx.AsyncClient(timeout=300) as client:
        # 1. Create the prediction and get the stream URL.
        # FIX: `stream` is a top-level prediction option, not a model input
        # (TODO confirm against the Replicate API docs for these models),
        # and the caller's payload dict is no longer mutated.
        body = {"input": dict(payload), "stream": True}
        try:
            response = await client.post(url, headers=headers, json=body)
            response.raise_for_status()
            prediction = response.json()
            stream_url = prediction.get("urls", {}).get("stream")

            if not stream_url:
                yield json.dumps({"error": "Failed to get stream URL"})
                return
        except httpx.HTTPStatusError as e:
            yield json.dumps({"error": str(e.response.text)})
            return

        # 2. Connect to the SSE stream and relay it token by token.
        try:
            async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}) as sse:
                async for line in sse.aiter_lines():
                    if not line.startswith("data:"):
                        continue
                    event_data = line[len("data:"):].strip()
                    # FIX: Replicate emits raw token text; the old code
                    # json.loads'ed it and silently dropped anything that
                    # was not valid JSON (i.e. most tokens). Fall back to
                    # the raw string instead of discarding it.
                    try:
                        content = json.loads(event_data)
                    except json.JSONDecodeError:
                        content = event_data
                    chunk = {
                        "id": prediction["id"],
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": model_id,
                        "choices": [{
                            "index": 0,
                            "delta": {"content": content},
                            "finish_reason": None,
                        }],
                    }
                    yield json.dumps(chunk)
        except Exception as e:
            yield json.dumps({"error": f"Streaming error: {str(e)}"})

        # 3. Terminate the stream the way OpenAI clients expect.
        done_chunk = {
            "id": prediction["id"],
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model_id,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        }
        yield json.dumps(done_chunk)
        yield "[DONE]"
191
+
192
+
193
+ # --- API Endpoints ---
194
+
195
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Lists the available models that this compatibility layer supports."""
    cards = []
    for exposed_name in SUPPORTED_MODELS:
        cards.append(ModelCard(id=exposed_name))
    return ModelList(data=cards)
202
+
203
@app.post("/v1/chat/completions")
async def create_chat_completion(request: "OpenAIChatCompletionRequest"):
    """Creates a chat completion, either streaming or synchronous.

    Validates the requested model, translates the request into a Replicate
    input, and either streams SSE chunks or blocks for a full prediction
    which is reshaped into the OpenAI response schema.

    Raises:
        HTTPException: 404 for unknown models; the upstream status code is
            propagated when Replicate rejects the prediction.
    """
    model_key = request.model
    if model_key not in SUPPORTED_MODELS:
        raise HTTPException(status_code=404, detail=f"Model not found. Supported models: {list(SUPPORTED_MODELS.keys())}")

    replicate_model_id = SUPPORTED_MODELS[model_key]
    replicate_input = prepare_replicate_input(request)

    if request.stream:
        return EventSourceResponse(stream_replicate_response(replicate_model_id, replicate_input))

    # Synchronous request: ask Replicate to hold the connection open.
    url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
    headers = {
        "Authorization": f"Bearer {REPLICATE_API_TOKEN}",
        "Content-Type": "application/json",
        "Prefer": "wait=120"  # Wait up to 120 seconds for a response
    }

    async with httpx.AsyncClient(timeout=150) as client:
        try:
            response = await client.post(url, headers=headers, json={"input": replicate_input})
            response.raise_for_status()
            prediction = response.json()

            output = prediction.get("output", "")
            if isinstance(output, list):
                output = "".join(output)

            # Detect the ad-hoc tool-call convention: the model answers with
            # a JSON object shaped like {"type": "tool_call", ...}.
            message_content = output
            tool_calls = None
            try:
                tool_call_data = json.loads(output)
                # FIX: json.loads can return a non-dict (a bare number,
                # string, or list); calling .get() on it raised an uncaught
                # AttributeError -> HTTP 500. Guard with isinstance.
                if isinstance(tool_call_data, dict) and tool_call_data.get("type") == "tool_call":
                    message_content = None
                    tool_calls = [{
                        "id": f"call_{int(time.time())}",
                        "type": "function",
                        "function": {
                            "name": tool_call_data["name"],
                            "arguments": json.dumps(tool_call_data["arguments"])
                        }
                    }]
            except (json.JSONDecodeError, TypeError):
                pass  # Plain-text answer: keep it as the message content.

            # Format response in OpenAI format.
            completion_response = {
                "id": prediction["id"],
                "object": "chat.completion",
                "created": int(time.time()),
                "model": model_key,
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": message_content,
                        "tool_calls": tool_calls,
                    },
                    "finish_reason": "stop"  # Or map from Replicate if available
                }],
                "usage": {  # Note: Replicate doesn't provide token usage in the same way
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0
                }
            }
            return JSONResponse(content=completion_response)

        except httpx.HTTPStatusError as e:
            raise HTTPException(status_code=e.response.status_code, detail=e.response.text)