File size: 7,341 Bytes
fae0e6c a1207af fae0e6c a1207af fae0e6c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import asyncio
import json
import os
import uuid
from typing import AsyncIterator, Dict, Any
import aiohttp
import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class SSEClient:
    """Async SSE client for streaming chat API requests.

    The target URL is read from the ``API_ENDPOINT`` environment variable
    when the client is constructed.
    """

    def __init__(self):
        # os.getenv returns None when the variable is unset; stream_chat
        # rejects a missing URL before any request is attempted.
        self.url = os.getenv("API_ENDPOINT")
        self.headers = {
            'Content-Type': 'application/json',
            'User-Agent': 'HuggingFace-Gradio-Demo'
        }

    async def stream_chat(self, query: str,
                          deep_thinking_mode: bool = False,
                          search_before_planning: bool = False,
                          debug: bool = False,
                          chat_id: str = None) -> AsyncIterator[Dict[str, Any]]:
        """
        Async request to SSE interface and return streaming data with event parsing

        Every non-empty ``data:`` line is yielded as its own event; payloads
        spread over several ``data:`` lines are not joined. Comment lines
        (starting with ``:``) and the ``[DONE]`` payload are not yielded.
        Lines that match no SSE field handled here are yielded verbatim.

        Args:
            query: User query content
            deep_thinking_mode: Whether to enable deep thinking mode, default False
            search_before_planning: Whether to search before planning, default False
            debug: Whether to enable debug mode, default False
            chat_id: Chat ID, will be auto-generated if not provided

        Yields:
            Dict[str, Any]: SSE event data with 'event' and 'data' fields

        Raises:
            Exception: If no endpoint URL is configured, the response status
                is not 200, or the request/stream fails (the underlying
                error is chained as ``__cause__``)
        """
        if not self.url:
            raise Exception("SSE endpoint URL is not configured (set API_ENDPOINT)")

        if chat_id is None:
            chat_id = self._generate_chat_id()

        # Build request data; the same ID is used for the message and the chat.
        data = {
            "messages": [{
                "id": chat_id,
                "role": "user",
                "type": "text",
                "content": query
            }],
            "deep_thinking_mode": deep_thinking_mode,
            "search_before_planning": search_before_planning,
            "debug": debug,
            "chatId": chat_id
        }

        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=None)  # No timeout limit
        ) as session:
            try:
                async with session.post(
                    self.url,
                    headers=self.headers,
                    json=data
                ) as response:
                    if response.status != 200:
                        raise Exception(f"Request failed with status code: {response.status}")

                    # Read SSE stream line by line and parse events
                    current_event = None
                    async for raw_line in response.content:
                        line = raw_line.decode('utf-8').strip()
                        event, current_event = self._parse_sse_line(line, current_event)
                        if event is not None:
                            yield event
            except asyncio.CancelledError:
                # Let cancellation propagate untouched
                raise
            except Exception as e:
                # Chain the cause so the original traceback is preserved
                raise Exception(f"SSE request error: {str(e)}") from e

    @staticmethod
    def _parse_sse_line(line: str, current_event):
        """
        Interpret one already-stripped line of an SSE stream

        Args:
            line: Decoded line with surrounding whitespace removed
            current_event: Event name announced by a preceding ``event:`` line, or None

        Returns:
            A ``(event, next_event_name)`` tuple. ``event`` is the dict to
            yield (or None when the line produces no output) and
            ``next_event_name`` is the pending event name for the next line.
        """
        if not line:
            # Blank line marks the end of an event: drop any pending event name
            return None, None
        if line.startswith(':'):
            # SSE comment (often used as keep-alive): ignore, keep pending name
            return None, current_event
        if line.startswith('event:'):
            return None, SSEClient._field_value(line, 'event:')
        if line.startswith('data:'):
            data_content = SSEClient._field_value(line, 'data:')
            if data_content and data_content != '[DONE]':
                return {'event': current_event or 'message', 'data': data_content}, None
            # Empty payload or [DONE]: nothing to yield, event name is reset
            return None, None
        # Handle other formats or raw data
        return {'event': current_event or 'data', 'data': line}, None

    @staticmethod
    def _field_value(line: str, prefix: str) -> str:
        """Return the text after ``prefix``, dropping at most one leading space."""
        value = line[len(prefix):]
        if value.startswith(' '):
            value = value[1:]
        return value

    def _generate_chat_id(self) -> str:
        """Generate chat ID: the first 21 hex characters of a random UUID4"""
        return uuid.uuid4().hex[:21]

    async def stream_chat_parsed(self, query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
        """
        Async request to SSE interface and return parsed JSON data with event structure

        Args:
            query: User query content
            **kwargs: Other parameters passed to stream_chat

        Yields:
            Dict[str, Any]: Event data with 'event' and 'data' fields, where 'data'
                contains parsed JSON, or the original event unchanged when its
                'data' is not valid JSON
        """
        async for event_data in self.stream_chat(query, **kwargs):
            try:
                # Only the lookups and the parse are guarded, not the yield
                parsed_data = json.loads(event_data['data'])
                event_name = event_data['event']
            except json.JSONDecodeError:
                # If data is not valid JSON, keep original data
                yield event_data
                continue
            except (KeyError, TypeError):
                # If event_data doesn't have expected structure, skip
                continue
            yield {
                'event': event_name,
                'data': parsed_data
            }
# Convenience functions
async def request_sse_stream(query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
    """
    Convenience function: stream raw SSE events using a fresh SSEClient

    Args:
        query: User query content
        **kwargs: Forwarded unchanged to SSEClient.stream_chat

    Yields:
        Dict[str, Any]: Raw event data with 'event' and 'data' fields (data as string)
    """
    raw_events = SSEClient().stream_chat(query, **kwargs)
    async for item in raw_events:
        yield item
async def request_sse_stream_parsed(query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
    """
    Convenience function: stream structured SSE events using a fresh SSEClient

    Args:
        query: User query content
        **kwargs: Forwarded unchanged to SSEClient.stream_chat_parsed

    Yields:
        Dict[str, Any]: Event data with 'event' and 'data' fields
    """
    parsed_events = SSEClient().stream_chat_parsed(query, **kwargs)
    async for item in parsed_events:
        yield item
async def stop_chat(chat_id: str) -> Any:
    """
    POST a chat ID to the endpoint configured in STOP_CHAT_API_ENDPOINT

    Args:
        chat_id: Chat ID, sent as ``chatId`` in the JSON request body

    Returns:
        The decoded JSON body of the response

    Raises:
        Exception: If ``STOP_CHAT_API_ENDPOINT`` is unset or empty, or the
            endpoint responds with a status other than 200
    """
    url = os.getenv('STOP_CHAT_API_ENDPOINT')
    if not url:
        # Formatting a missing variable into a string would produce the
        # literal URL "None"; fail early with a clear message instead.
        raise Exception("STOP_CHAT_API_ENDPOINT environment variable is not set")

    async with aiohttp.ClientSession() as session:
        async with session.post(url, json={"chatId": chat_id}) as response:
            if response.status != 200:
                message = f"Request failed with status code: {response.status}"
                logger.error(message)
                raise Exception(message)
            return await response.json()
# Example usage
async def main():
    """Example usage: send a fixed query and print each parsed event"""
    print("=== SSE Event Stream ===")
    event_stream = request_sse_stream_parsed("Hello")
    async for payload in event_stream:
        # Fall back to placeholders when a field is missing
        event_type, data_content = payload.get('event', 'unknown'), payload.get('data', {})
        print(f"Event: {event_type}")
        print(f"Data: {data_content}")
        print("-" * 40)


# Run the example only when this file is executed as a script
if __name__ == "__main__":
    asyncio.run(main())
|