File size: 7,341 Bytes
fae0e6c
 
 
 
 
 
a1207af
 
 
fae0e6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1207af
 
 
 
 
 
 
 
 
fae0e6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import asyncio
import json
import logging
import os
import uuid
from typing import Any, AsyncIterator, Dict, Optional, Tuple

import aiohttp

# Module-level logger named after this module; stop_chat logs failed requests through it.
logger = logging.getLogger(__name__)


class SSEClient:
    """Async client for a streaming (Server-Sent Events) chat API.

    The endpoint URL is read from the ``API_ENDPOINT`` environment variable
    when the client is constructed.
    """

    def __init__(self):
        self.url = os.getenv("API_ENDPOINT")
        self.headers = {
            'Content-Type': 'application/json',
            'User-Agent': 'HuggingFace-Gradio-Demo'
        }

    async def stream_chat(self, query: str,
                          deep_thinking_mode: bool = False,
                          search_before_planning: bool = False,
                          debug: bool = False,
                          chat_id: Optional[str] = None) -> AsyncIterator[Dict[str, Any]]:
        """
        Async request to SSE interface and return streaming data with event parsing

        Args:
            query: User query content
            deep_thinking_mode: Whether to enable deep thinking mode, default False
            search_before_planning: Whether to search before planning, default False
            debug: Whether to enable debug mode, default False
            chat_id: Chat ID, will be auto-generated if not provided

        Yields:
            Dict[str, Any]: SSE event data with 'event' and 'data' fields
                ('data' is the raw, undecoded string payload)

        Raises:
            ValueError: If no endpoint URL is configured (``API_ENDPOINT`` unset).
            Exception: For any request or stream failure, including a non-200
                status. The message is prefixed with "SSE request error: " and
                the underlying error is chained as ``__cause__``.
        """
        if not self.url:
            # Fail fast with a clear message instead of an obscure aiohttp URL error.
            raise ValueError("API_ENDPOINT environment variable is not set")

        if chat_id is None:
            chat_id = self._generate_chat_id()

        # Build request data
        data = {
            "messages": [{
                "id": chat_id,
                "role": "user",
                "type": "text",
                "content": query
            }],
            "deep_thinking_mode": deep_thinking_mode,
            "search_before_planning": search_before_planning,
            "debug": debug,
            "chatId": chat_id
        }

        # total=None disables the overall timeout so long-running streams are not cut off.
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=None)
        ) as session:
            try:
                async with session.post(
                    self.url,
                    headers=self.headers,
                    json=data
                ) as response:
                    if response.status != 200:
                        raise Exception(f"Request failed with status code: {response.status}")

                    # Event type announced by the most recent "event:" line, if any.
                    current_event = None

                    async for raw_line in response.content:
                        line = raw_line.decode('utf-8').strip()
                        current_event, event = self._parse_sse_line(line, current_event)
                        if event is not None:
                            yield event

            except asyncio.CancelledError:
                # Propagate cancellation untouched (it must not be wrapped).
                raise
            except Exception as e:
                raise Exception(f"SSE request error: {str(e)}") from e

    @staticmethod
    def _field_value(line: str, prefix: str) -> str:
        """Return the value after an SSE field ``prefix`` (e.g. ``"data:"``).

        A single optional space after the colon is dropped, as in the SSE format.
        """
        value = line[len(prefix):]
        if value.startswith(' '):
            value = value[1:]
        return value

    @classmethod
    def _parse_sse_line(cls, line: str,
                        current_event: Optional[str]
                        ) -> Tuple[Optional[str], Optional[Dict[str, Any]]]:
        """Interpret one already-stripped line of an SSE stream.

        Args:
            line: A decoded line with surrounding whitespace removed.
            current_event: Event type set by a preceding "event:" line, or None.

        Returns:
            ``(new_current_event, event)``. ``event`` is the dict to yield
            (with 'event' and 'data' keys), or None when the line produces no
            output.
        """
        if not line:
            # A blank line terminates an SSE event block: forget the pending event type.
            return None, None
        if line.startswith(':'):
            # SSE comment (typically a keep-alive ping); never part of the payload.
            return current_event, None
        if line.startswith('event:'):
            return cls._field_value(line, 'event:'), None
        if line.startswith('data:'):
            payload = cls._field_value(line, 'data:')
            if not payload or payload == '[DONE]':
                # Empty payloads and the end-of-stream sentinel are not emitted.
                return current_event, None
            # The event type applies only to the data line that follows it.
            return None, {'event': current_event or 'message', 'data': payload}
        # Anything else is passed through verbatim as raw data.
        return None, {'event': current_event or 'data', 'data': line}

    def _generate_chat_id(self) -> str:
        """Return a 21-character hex identifier taken from a random UUID4."""
        return str(uuid.uuid4()).replace('-', '')[:21]

    async def stream_chat_parsed(self, query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
        """
        Async request to SSE interface and return parsed JSON data with event structure

        Args:
            query: User query content
            **kwargs: Other parameters passed to stream_chat

        Yields:
            Dict[str, Any]: Event data with 'event' and 'data' fields. 'data'
                contains the decoded JSON value when the payload is valid
                JSON; otherwise it keeps the original string.
        """
        async for event_data in self.stream_chat(query, **kwargs):
            try:
                # Try to parse the data field as JSON
                parsed_data = json.loads(event_data['data'])
                yield {
                    'event': event_data['event'],
                    'data': parsed_data
                }
            except json.JSONDecodeError:
                # If data is not valid JSON, keep original data
                yield event_data
            except (KeyError, TypeError):
                # If event_data doesn't have expected structure, skip
                continue


# Convenience functions
async def request_sse_stream(query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
    """
    Convenience wrapper: stream raw SSE events for ``query`` through a fresh SSEClient.

    Args:
        query: User query content
        **kwargs: Forwarded unchanged to ``SSEClient.stream_chat``

    Yields:
        Dict[str, Any]: Raw event dicts with 'event' and 'data' keys ('data' is a string)
    """
    stream = SSEClient().stream_chat(query, **kwargs)
    async for item in stream:
        yield item


async def request_sse_stream_parsed(query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
    """
    Convenience wrapper: stream SSE events for ``query`` with JSON payloads decoded.

    Args:
        query: User query content
        **kwargs: Forwarded unchanged to ``SSEClient.stream_chat_parsed``

    Yields:
        Dict[str, Any]: Event dicts with 'event' and 'data' keys
    """
    stream = SSEClient().stream_chat_parsed(query, **kwargs)
    async for item in stream:
        yield item


async def stop_chat(chat_id: str):
    """Ask the backend to stop the chat identified by ``chat_id``.

    The endpoint URL is read from the ``STOP_CHAT_API_ENDPOINT`` environment
    variable on every call.

    Args:
        chat_id: ID of the chat to stop; sent as the JSON body ``{"chatId": chat_id}``.

    Returns:
        The decoded JSON body of the response.

    Raises:
        ValueError: If ``STOP_CHAT_API_ENDPOINT`` is not set.
        Exception: If the server responds with a non-200 status.
    """
    url = os.getenv('STOP_CHAT_API_ENDPOINT')
    if not url:
        # Fail fast with a clear message instead of posting to the literal URL "None".
        raise ValueError("STOP_CHAT_API_ENDPOINT environment variable is not set")
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json={"chatId": chat_id}) as response:
            if response.status != 200:
                logger.error("Request failed with status code: %s", response.status)
                raise Exception(f"Request failed with status code: {response.status}")
            return await response.json()

# Example usage
async def main():
    """Demo: stream one query and print each parsed event, separated by a rule."""
    separator = "-" * 40
    print("=== SSE Event Stream ===")
    async for evt in request_sse_stream_parsed("Hello"):
        print(f"Event: {evt.get('event', 'unknown')}")
        print(f"Data: {evt.get('data', {})}")
        print(separator)


if __name__ == "__main__":
    # Run the demo only when this file is executed as a script.
    asyncio.run(main())