"""
OpenAI handler module for creating clients and processing OpenAI Direct mode responses.

This module encapsulates all OpenAI-specific logic that was previously in chat_api.py.
"""
import copy  # used for deep-copying streamed chunk dicts in _true_stream_generator
import json
import time
import asyncio
from typing import Dict, Any, AsyncGenerator

from fastapi.responses import JSONResponse, StreamingResponse
import openai
from google.auth.transport.requests import Request as AuthRequest

from models import OpenAIRequest
from config import VERTEX_REASONING_TAG
import config as app_config
from api_helpers import (
    create_openai_error_response,
    openai_fake_stream_generator,
    StreamingReasoningProcessor,
)
from message_processing import extract_reasoning_by_tags
from credentials_manager import _refresh_auth

class OpenAIDirectHandler:
    """Handles OpenAI Direct mode operations including client creation and response processing."""

    def __init__(self, credential_manager):
        self.credential_manager = credential_manager
        # Vertex AI safety filters are fully disabled for OpenAI Direct mode requests.
        self.safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "OFF"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "OFF"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "OFF"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "OFF"},
            {"category": "HARM_CATEGORY_CIVIC_INTEGRITY", "threshold": "OFF"},
        ]

    def create_openai_client(self, project_id: str, gcp_token: str, location: str = "global") -> openai.AsyncOpenAI:
        """Create an OpenAI client configured for the Vertex AI OpenAI-compatible endpoint."""
        endpoint_url = (
            f"https://aiplatform.googleapis.com/v1beta1/"
            f"projects/{project_id}/locations/{location}/endpoints/openapi"
        )
        return openai.AsyncOpenAI(
            base_url=endpoint_url,
            api_key=gcp_token,
        )

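    # For reference, with project_id "my-project" (a placeholder) and the default "global"
    # location, the endpoint_url built above is:
    #   https://aiplatform.googleapis.com/v1beta1/projects/my-project/locations/global/endpoints/openapi
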
    def prepare_openai_params(self, request: OpenAIRequest, model_id: str) -> Dict[str, Any]:
        """Prepare parameters for OpenAI API call."""
        params = {
            "model": model_id,
            "messages": [msg.model_dump(exclude_unset=True) for msg in request.messages],
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "top_p": request.top_p,
            "stream": request.stream,
            "stop": request.stop,
            "seed": request.seed,
            "n": request.n,
        }
        # Drop None values so only explicitly set parameters are sent to the API.
        return {k: v for k, v in params.items() if v is not None}

    def prepare_extra_body(self) -> Dict[str, Any]:
        """Prepare extra body parameters for OpenAI API call."""
        return {
            "extra_body": {
                "google": {
                    "safety_settings": self.safety_settings,
                    "thought_tag_marker": VERTEX_REASONING_TAG,
                }
            }
        }

    async def handle_streaming_response(
        self,
        openai_client: openai.AsyncOpenAI,
        openai_params: Dict[str, Any],
        openai_extra_body: Dict[str, Any],
        request: OpenAIRequest,
    ) -> StreamingResponse:
        """Handle streaming responses for OpenAI Direct mode."""
        if app_config.FAKE_STREAMING_ENABLED:
            print(f"INFO: OpenAI Fake Streaming (SSE Simulation) ENABLED for model '{request.model}'.")
            return StreamingResponse(
                openai_fake_stream_generator(
                    openai_client=openai_client,
                    openai_params=openai_params,
                    openai_extra_body=openai_extra_body,
                    request_obj=request,
                    is_auto_attempt=False,
                ),
                media_type="text/event-stream",
            )
        else:
            print(f"INFO: OpenAI True Streaming ENABLED for model '{request.model}'.")
            return StreamingResponse(
                self._true_stream_generator(openai_client, openai_params, openai_extra_body, request),
                media_type="text/event-stream",
            )

    async def _true_stream_generator(
        self,
        openai_client: openai.AsyncOpenAI,
        openai_params: Dict[str, Any],
        openai_extra_body: Dict[str, Any],
        request: OpenAIRequest,
    ) -> AsyncGenerator[str, None]:
        """Generate a true streaming response, separating tagged reasoning from regular content."""
        try:
            openai_params_for_stream = {**openai_params, "stream": True}
            stream_response = await openai_client.chat.completions.create(
                **openai_params_for_stream,
                extra_body=openai_extra_body,
            )

            # Incrementally strips VERTEX_REASONING_TAG-delimited reasoning out of the content stream.
            reasoning_processor = StreamingReasoningProcessor(VERTEX_REASONING_TAG)
            chunk_count = 0
            has_sent_content = False

            async for chunk in stream_response:
                chunk_count += 1
                try:
                    chunk_as_dict = chunk.model_dump(exclude_unset=True, exclude_none=True)

                    choices = chunk_as_dict.get('choices')
                    if choices and isinstance(choices, list) and len(choices) > 0:
                        delta = choices[0].get('delta')
                        if delta and isinstance(delta, dict):
                            # Drop Vertex-specific payload that OpenAI clients do not understand.
                            if 'extra_content' in delta:
                                del delta['extra_content']

                            content = delta.get('content', '')
                            if content:
                                # Split the incoming text into reasoning (inside the tag marker)
                                # and regular content.
                                processed_content, current_reasoning = reasoning_processor.process_chunk(content)

                                chunks_to_send = []

                                # Deep copies are needed here: a shallow dict.copy() shares the nested
                                # 'choices' objects, so the reasoning and content deltas would overwrite
                                # each other before being serialized.
                                if current_reasoning:
                                    reasoning_chunk = copy.deepcopy(chunk_as_dict)
                                    reasoning_chunk['choices'][0]['delta'] = {'reasoning_content': current_reasoning}
                                    chunks_to_send.append(reasoning_chunk)

                                if processed_content:
                                    content_chunk = copy.deepcopy(chunk_as_dict)
                                    content_chunk['choices'][0]['delta'] = {'content': processed_content}
                                    chunks_to_send.append(content_chunk)
                                    has_sent_content = True

                                for chunk_to_send in chunks_to_send:
                                    yield f"data: {json.dumps(chunk_to_send)}\n\n"
                            else:
                                # Delta carries no content (e.g. role-only); forward the chunk unchanged.
                                yield f"data: {json.dumps(chunk_as_dict)}\n\n"
                        else:
                            # Delta is missing or empty; forward the chunk unchanged.
                            yield f"data: {json.dumps(chunk_as_dict)}\n\n"

                except Exception as chunk_error:
                    error_msg = f"Error processing OpenAI chunk for {request.model}: {str(chunk_error)}"
                    print(f"ERROR: {error_msg}")
                    if len(error_msg) > 1024:
                        error_msg = error_msg[:1024] + "..."
                    error_response = create_openai_error_response(500, error_msg, "server_error")
                    yield f"data: {json.dumps(error_response)}\n\n"
                    yield "data: [DONE]\n\n"
                    return

            # Flush any text still buffered in the reasoning processor once the upstream stream ends.
            remaining_content, remaining_reasoning = reasoning_processor.flush_remaining()

            if remaining_reasoning:
                reasoning_chunk = {
                    "id": f"chatcmpl-{int(time.time())}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{"index": 0, "delta": {"reasoning_content": remaining_reasoning}, "finish_reason": None}]
                }
                yield f"data: {json.dumps(reasoning_chunk)}\n\n"

            if remaining_content:
                final_chunk = {
                    "id": f"chatcmpl-{int(time.time())}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [{"index": 0, "delta": {"content": remaining_content}, "finish_reason": None}]
                }
                yield f"data: {json.dumps(final_chunk)}\n\n"
                has_sent_content = True

            # Always close the stream with an explicit finish chunk and the [DONE] sentinel.
            finish_chunk = {
                "id": f"chatcmpl-{int(time.time())}",
                "object": "chat.completion.chunk",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]
            }
            yield f"data: {json.dumps(finish_chunk)}\n\n"

            yield "data: [DONE]\n\n"

        except Exception as stream_error:
            error_msg = str(stream_error)
            if len(error_msg) > 1024:
                error_msg = error_msg[:1024] + "..."
            error_msg_full = f"Error during OpenAI streaming for {request.model}: {error_msg}"
            print(f"ERROR: {error_msg_full}")
            error_response = create_openai_error_response(500, error_msg_full, "server_error")
            yield f"data: {json.dumps(error_response)}\n\n"
            yield "data: [DONE]\n\n"

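    # Illustrative only: the SSE frames a client receives from _true_stream_generator for a
    # response containing tagged reasoning look roughly like this (ids and timestamps vary):
    #
    #   data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", "choices": [{"index": 0, "delta": {"reasoning_content": "..."}}]}
    #   data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", "choices": [{"index": 0, "delta": {"content": "..."}}]}
    #   data: {"id": "chatcmpl-...", "object": "chat.completion.chunk", "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}
    #   data: [DONE]
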
    async def handle_non_streaming_response(
        self,
        openai_client: openai.AsyncOpenAI,
        openai_params: Dict[str, Any],
        openai_extra_body: Dict[str, Any],
        request: OpenAIRequest,
    ) -> JSONResponse:
        """Handle non-streaming responses for OpenAI Direct mode."""
        try:
            openai_params_non_stream = {**openai_params, "stream": False}
            response = await openai_client.chat.completions.create(
                **openai_params_non_stream,
                extra_body=openai_extra_body,
            )
            response_dict = response.model_dump(exclude_unset=True, exclude_none=True)

            try:
                choices = response_dict.get('choices')
                if choices and isinstance(choices, list) and len(choices) > 0:
                    message_dict = choices[0].get('message')
                    if message_dict and isinstance(message_dict, dict):
                        # Drop Vertex-specific payload that OpenAI clients do not understand.
                        if 'extra_content' in message_dict:
                            del message_dict['extra_content']

                        full_content = message_dict.get('content')
                        actual_content = full_content if isinstance(full_content, str) else ""

                        if actual_content:
                            print(f"INFO: OpenAI Direct Non-Streaming - Applying tag extraction with fixed marker: '{VERTEX_REASONING_TAG}'")
                            # Split reasoning (wrapped in VERTEX_REASONING_TAG) away from the visible answer.
                            reasoning_text, actual_content = extract_reasoning_by_tags(actual_content, VERTEX_REASONING_TAG)
                            message_dict['content'] = actual_content
                            if reasoning_text:
                                message_dict['reasoning_content'] = reasoning_text
                        else:
                            print("WARNING: OpenAI Direct Non-Streaming - No initial content found in message.")
                            message_dict['content'] = ""

            except Exception as e_reasoning:
                print(f"WARNING: Error during non-streaming reasoning processing for model {request.model}: {e_reasoning}")

            return JSONResponse(content=response_dict)

        except Exception as e:
            error_msg = f"Error calling OpenAI client for {request.model}: {str(e)}"
            print(f"ERROR: {error_msg}")
            return JSONResponse(
                status_code=500,
                content=create_openai_error_response(500, error_msg, "server_error"),
            )

    async def process_request(self, request: OpenAIRequest, base_model_name: str):
        """Main entry point for processing OpenAI Direct mode requests."""
        print(f"INFO: Using OpenAI Direct Path for model: {request.model}")

        # Obtain the next set of GCP credentials from the credential manager.
        rotated_credentials, rotated_project_id = self.credential_manager.get_credentials()

        if not rotated_credentials or not rotated_project_id:
            error_msg = "OpenAI Direct Mode requires GCP credentials, but none were available or loaded successfully."
            print(f"ERROR: {error_msg}")
            return JSONResponse(
                status_code=500,
                content=create_openai_error_response(500, error_msg, "server_error"),
            )

        print(f"INFO: [OpenAI Direct Path] Using credentials for project: {rotated_project_id}")
        gcp_token = _refresh_auth(rotated_credentials)

        if not gcp_token:
            error_msg = f"Failed to obtain valid GCP token for OpenAI client (Project: {rotated_project_id})."
            print(f"ERROR: {error_msg}")
            return JSONResponse(
                status_code=500,
                content=create_openai_error_response(500, error_msg, "server_error"),
            )

        openai_client = self.create_openai_client(rotated_project_id, gcp_token)
        model_id = f"google/{base_model_name}"
        openai_params = self.prepare_openai_params(request, model_id)
        openai_extra_body = self.prepare_extra_body()

        if request.stream:
            return await self.handle_streaming_response(
                openai_client, openai_params, openai_extra_body, request
            )
        else:
            return await self.handle_non_streaming_response(
                openai_client, openai_params, openai_extra_body, request
            )
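

# Illustrative usage sketch (not part of the original module): one way this handler could be
# wired into a FastAPI route. `CredentialManager`, the route path, and the model-name split are
# assumptions about the surrounding project and may differ from the real wiring in chat_api.py.
#
#   from fastapi import APIRouter
#   from credentials_manager import CredentialManager  # assumed class name
#   from models import OpenAIRequest
#
#   router = APIRouter()
#   handler = OpenAIDirectHandler(CredentialManager())
#
#   @router.post("/v1/chat/completions")
#   async def chat_completions(request: OpenAIRequest):
#       # Derive the underlying Gemini model name from the requested model id (assumed convention).
#       base_model_name = request.model.split("/")[-1]
#       return await handler.process_request(request, base_model_name)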