import json
import os
import time
from typing import Any, Dict, List, Literal, Optional, Union

import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

load_dotenv()

REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
if not REPLICATE_API_TOKEN:
    raise ValueError("REPLICATE_API_TOKEN environment variable not set.")

app = FastAPI(title="Replicate to OpenAI Compatibility Layer", version="9.0.0 (Definitive Streaming Fix)")


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "replicate"


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class ChatMessage(BaseModel):
    role: Literal["system", "user", "assistant", "tool"]
    content: Union[str, List[Dict[str, Any]]]


class OpenAIChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False


# Friendly model names mapped to Replicate model identifiers. The llava-13b
# entry pins a specific version hash; the others resolve to the latest version.
SUPPORTED_MODELS = {
    "llama3-8b-instruct": "meta/meta-llama-3-8b-instruct",
    "claude-4.5-haiku": "anthropic/claude-4.5-haiku",
    "claude-4.5-sonnet": "anthropic/claude-4.5-sonnet",
    "llava-13b": "yorickvp/llava-13b:e272157381e2a3bf12df3a8edd1f38d1dbd736bbb7437277c8b34175f8fce358",
}


def prepare_replicate_input(request: OpenAIChatCompletionRequest) -> Dict[str, Any]:
    """
    Formats the input for Replicate's API, flattening the message history into a
    single 'prompt' string and handling images separately.
    """
    payload = {}
    prompt_parts = []
    system_prompt = None
    image_input = None

    for msg in request.messages:
        if msg.role == "system":
            system_prompt = str(msg.content)
        elif msg.role == "assistant":
            prompt_parts.append(f"Assistant: {msg.content}")
        elif msg.role == "user":
            user_text_content = ""
            if isinstance(msg.content, list):
                # Multimodal message: gather the text parts and keep the last image URL.
                for item in msg.content:
                    if item.get("type") == "text":
                        user_text_content += item.get("text", "")
                    elif item.get("type") == "image_url":
                        image_url_data = item.get("image_url", {})
                        image_input = image_url_data.get("url")
            else:
                user_text_content = str(msg.content)
            prompt_parts.append(f"User: {user_text_content}")

    # A trailing "Assistant:" cues the model to respond as the assistant.
    prompt_parts.append("Assistant:")
    payload["prompt"] = "\n\n".join(prompt_parts)

    if system_prompt:
        payload["system_prompt"] = system_prompt
    if image_input:
        payload["image"] = image_input

    if request.max_tokens:
        payload["max_new_tokens"] = request.max_tokens
    if request.temperature:
        payload["temperature"] = request.temperature
    if request.top_p:
        payload["top_p"] = request.top_p

    return payload
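
# Illustrative example (a sketch, not exercised by the code above): given
#   [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]
# prepare_replicate_input returns a payload like
#   {"prompt": "User: Hi\n\nAssistant:", "system_prompt": "Be brief.",
#    "temperature": 0.7, "top_p": 1.0}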


async def stream_replicate_sse(replicate_model_id: str, input_payload: dict):
    """Handles the full streaming lifecycle with correct whitespace preservation."""
    url = f"https://api.replicate.com/v1/models/{replicate_model_id}/predictions"
    headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json"}

    async with httpx.AsyncClient(timeout=60.0) as client:
        try:
            response = await client.post(url, headers=headers, json={"input": input_payload, "stream": True})
            response.raise_for_status()
            prediction = response.json()
            stream_url = prediction.get("urls", {}).get("stream")
            prediction_id = prediction.get("id", "stream-unknown")
            if not stream_url:
                yield json.dumps({"error": {"message": "Model did not return a stream URL."}})
                return
        except httpx.HTTPStatusError as e:
            error_details = e.response.text
            try:
                error_json = e.response.json()
                error_details = error_json.get("detail", error_details)
            except json.JSONDecodeError:
                pass
            yield json.dumps({"error": {"message": f"Upstream Error: {error_details}", "type": "replicate_error"}})
            return

        try:
            async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=None) as sse:
                current_event = None
                async for line in sse.aiter_lines():
                    if not line:
                        continue
                    if line.startswith("event:"):
                        current_event = line[len("event:"):].strip()
                    elif line.startswith("data:"):
                        raw_data = line[5:]
                        # Per the SSE spec, strip at most one leading space after
                        # "data:" so whitespace inside the token is preserved.
                        if raw_data.startswith(" "):
                            data_content = raw_data[1:]
                        else:
                            data_content = raw_data

                        if current_event == "output":
                            if not data_content:
                                continue

                            content_token = ""
                            try:
                                # Tokens may arrive as JSON-encoded strings.
                                content_token = json.loads(data_content)
                            except (json.JSONDecodeError, TypeError):
                                # Otherwise treat the data as plain text.
                                content_token = data_content

                            chunk = {
                                "choices": [{
                                    "delta": {"content": content_token},
                                    "finish_reason": None,
                                    "index": 0,
                                    "logprobs": None,
                                    "native_finish_reason": None
                                }],
                                "created": int(time.time()),
                                "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",
                                "model": replicate_model_id,
                                "object": "chat.completion.chunk",
                                "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate"
                            }
                            yield json.dumps(chunk)

                        elif current_event == "done":
                            # Emit a zeroed usage chunk, then a final chunk with
                            # finish_reason "stop" to close the completion.
                            usage_chunk = {
                                "choices": [{
                                    "delta": {},
                                    "finish_reason": None,
                                    "index": 0,
                                    "logprobs": None,
                                    "native_finish_reason": None
                                }],
                                "created": int(time.time()),
                                "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",
                                "model": replicate_model_id,
                                "object": "chat.completion.chunk",
                                "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate",
                                "usage": {
                                    "cache_discount": 0,
                                    "completion_tokens": 0,
                                    "completion_tokens_details": {"image_tokens": 0, "reasoning_tokens": 0},
                                    "cost": 0,
                                    "cost_details": {
                                        "upstream_inference_completions_cost": 0,
                                        "upstream_inference_cost": None,
                                        "upstream_inference_prompt_cost": 0
                                    },
                                    "input_tokens": 0,
                                    "is_byok": False,
                                    "prompt_tokens": 0,
                                    "prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
                                    "total_tokens": 0
                                }
                            }
                            yield json.dumps(usage_chunk)

                            final_chunk = {
                                "choices": [{
                                    "delta": {},
                                    "finish_reason": "stop",
                                    "index": 0,
                                    "logprobs": None,
                                    "native_finish_reason": "end_turn"
                                }],
                                "created": int(time.time()),
                                "id": f"gen-{int(time.time())}-{prediction_id[-12:]}",
                                "model": replicate_model_id,
                                "object": "chat.completion.chunk",
                                "provider": "Anthropic" if "anthropic" in replicate_model_id else "Replicate"
                            }
                            yield json.dumps(final_chunk)
                            break
        except httpx.ReadTimeout:
            yield json.dumps({"error": {"message": "Stream timed out.", "type": "timeout_error"}})
            return

    yield "[DONE]"
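
# Note: EventSourceResponse (sse_starlette) sends each yielded string as the
# data field of a server-sent event, so clients should receive OpenAI-style
# JSON chunks followed by a literal "[DONE]" sentinel.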


@app.get("/v1/models") |
|
|
async def list_models(): |
|
|
return ModelList(data=[ModelCard(id=k) for k in SUPPORTED_MODELS.keys()]) |
|
|
|
|
|
@app.post("/v1/chat/completions") |
|
|
async def create_chat_completion(request: OpenAIChatCompletionRequest): |
|
|
if request.model not in SUPPORTED_MODELS: |
|
|
raise HTTPException(status_code=404, detail=f"Model not found. Available models: {list(SUPPORTED_MODELS.keys())}") |
|
|
|
|
|
replicate_input = prepare_replicate_input(request) |
|
|
|
|
|
if request.stream: |
|
|
return EventSourceResponse(stream_replicate_sse(SUPPORTED_MODELS[request.model], replicate_input), media_type="text/event-stream") |
|
|
|
|
|
|
|
|
url = f"https://api.replicate.com/v1/models/{SUPPORTED_MODELS[request.model]}/predictions" |
|
|
headers = {"Authorization": f"Bearer {REPLICATE_API_TOKEN}", "Content-Type": "application/json", "Prefer": "wait=120"} |
|
|
async with httpx.AsyncClient() as client: |
|
|
try: |
|
|
resp = await client.post(url, headers=headers, json={"input": replicate_input}, timeout=130.0) |
|
|
resp.raise_for_status() |
|
|
pred = resp.json() |
|
|
output = "".join(pred.get("output", [])) |
|
|
return { |
|
|
"id": pred.get("id"), "object": "chat.completion", "created": int(time.time()), "model": request.model, |
|
|
"choices": [{"index": 0, "message": {"role": "assistant", "content": output}, "finish_reason": "stop"}], |
|
|
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} |
|
|
} |
|
|
except httpx.HTTPStatusError as e: |
|
|
raise HTTPException(status_code=e.response.status_code, detail=f"Error from Replicate API: {e.response.text}") |
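

# Minimal local runner (a sketch, not part of the original module): assumes
# uvicorn is installed alongside fastapi. Example request once running:
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama3-8b-instruct", "messages": [{"role": "user", "content": "Hi"}]}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)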