test1 / api_server.py
Doleeee's picture
conversation -> history 변경 #31
55ff7de
import json
import sys
import threading
from queue import Empty, Queue
from threading import Thread
from typing import List, Optional
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from pipeline import pipeline as run_pipeline
from persona.make_persona import make_persona
# FastAPI application instance; the @app.post decorators in this module register their routes on it.
app = FastAPI()
class _ThreadStdoutProxy:
    """File-like stand-in for a stream that routes writes per thread.

    A thread may register a handler object (anything with ``write`` and
    ``flush``). Writes issued from that thread go to its handler; writes from
    every other thread go to the wrapped target stream. Attributes not defined
    here are looked up on the target.
    """

    def __init__(self, target):
        """Wrap *target* and start with no per-thread handlers."""
        self._target = target
        self._handlers = {}
        # Guards _handlers, which is mutated and read from several threads.
        self._lock = threading.RLock()
        # Mirror the target's text attributes, with fallbacks when it has none.
        self.encoding = getattr(target, "encoding", "utf-8")
        self.errors = getattr(target, "errors", None)

    def register(self, thread_id: int, handler) -> None:
        """Route output of the thread identified by *thread_id* to *handler*."""
        with self._lock:
            self._handlers[thread_id] = handler

    def unregister(self, thread_id: int) -> None:
        """Remove the handler for *thread_id*; a no-op when none is registered."""
        with self._lock:
            self._handlers.pop(thread_id, None)

    def _resolve(self):
        """Return ``(handler_or_None, target)`` for the calling thread."""
        thread_id = threading.get_ident()
        with self._lock:
            return self._handlers.get(thread_id), self._target

    def write(self, data):
        """Write *data* to the calling thread's handler, else to the target."""
        handler, target = self._resolve()
        if handler:
            return handler.write(data)
        return target.write(data)

    def flush(self):
        """Flush the calling thread's handler (if any) and then the target."""
        handler, target = self._resolve()
        if handler:
            handler.flush()
        return target.flush()

    def isatty(self):
        """Report the target's ``isatty()``; False when the target has none."""
        return getattr(self._target, "isatty", lambda: False)()

    def fileno(self):
        """Return the target's file descriptor."""
        return self._target.fileno()

    def writable(self):
        """Always True: the proxy accepts writes."""
        return True

    def __getattr__(self, name):
        """Delegate attributes not found on the proxy to the target.

        Raises:
            AttributeError: if the proxy has no ``_target`` yet (for example an
                instance created with ``__new__`` by ``copy`` or ``pickle``),
                or if the target lacks *name*.
        """
        # __getattr__ only runs after normal lookup failed. Reading
        # self._target here would re-enter __getattr__ when _target itself is
        # missing and recurse until RecursionError, so fetch it without the
        # fallback hook.
        try:
            target = object.__getattribute__(self, "_target")
        except AttributeError:
            raise AttributeError(name) from None
        return getattr(target, name)
class _QueueingStdoutTee:
    """Writer that forwards text to a stream and mirrors it onto a queue.

    Each non-empty write is also published on the queue as a
    ``{"type": "stdout", "message": <text>}`` event.
    """

    def __init__(self, target, event_queue: Queue):
        """Remember the wrapped stream and the queue that receives events."""
        self._target = target
        self._event_queue = event_queue

    def write(self, data):
        """Write *data* to the stream and return whatever the stream returned."""
        result = self._target.write(data)
        if not data:
            # Nothing worth reporting for empty writes.
            return result
        self._event_queue.put(dict(type="stdout", message=data))
        return result

    def flush(self):
        """Flush the wrapped stream; the queue needs no flushing."""
        self._target.flush()
# Install the per-thread proxy as the process-wide sys.stdout. Threads that
# register a handler on _stdout_proxy get their output routed to that handler;
# all other threads keep writing to the stream that was sys.stdout at import time.
_stdout_proxy = _ThreadStdoutProxy(sys.stdout)
sys.stdout = _stdout_proxy
class PersonaRequest(BaseModel):
    """Request body for the POST /persona/ endpoint."""

    # Text the persona is generated from; the endpoint answers 400 when it is blank after stripping.
    info: str
    # True (default): respond with a server-sent-event stream; False: respond with one JSON body.
    stream: bool = True
# Progress messages sent as "status" events by the streaming /persona/ endpoint
# while a persona is being generated. The sender iterates over every entry
# except the last one.
# NOTE(review): these literals look like mis-encoded text (mojibake) — confirm
# the file's encoding and the intended wording before editing them.
PERSONA_STATUS_MESSAGES = [
"์ธ๋ฌผ ์ •๋ณด ์ˆ˜์ง‘ ์ค‘...",
"์›น ๊ฒ€์ƒ‰์„ ํ†ตํ•ด ๋ฐฐ๊ฒฝ ์กฐ์‚ฌ ์ค‘...",
"๊ธˆ์œต ์‚ฌ๊ณ  ๋ฐฉ์‹ ๋ถ„์„ ์ค‘...",
"๋ฐ์ดํ„ฐ ๋ถ„์„ ์ ‘๊ทผ๋ฒ• ํ‰๊ฐ€ ์ค‘...",
"๋‹ต๋ณ€ ์Šคํƒ€์ผ ํŠน์„ฑ ํŒŒ์•… ์ค‘...",
"ํ•ต์‹ฌ ํˆฌ์ž ์›์น™ ์ถ”์ถœ ์ค‘...",
"๋Œ€ํ‘œ ์–ด๋ก ์ •๋ฆฌ ์ค‘...",
"ํŽ˜๋ฅด์†Œ๋‚˜ ํ”„๋กœํ•„ ๊ตฌ์„ฑ ์ค‘...",
"์ตœ์ข… ๊ฒ€์ฆ ๋ฐ ์ €์žฅ ์ค€๋น„ ์ค‘...",
]
def _build_persona_payload(persona) -> dict:
    """Convert a persona object into the ``"result"`` event dictionary.

    Every listed attribute must exist on *persona* (a missing one raises
    ``AttributeError``), except ``famous_quotes``, which falls back to ``None``.
    """
    required_attrs = (
        "name",
        "full_name",
        "summary",
        "financial_mindset",
        "data_analysis_approach",
        "response_style",
        "key_principles",
    )
    payload = {"type": "result"}
    for attr in required_attrs:
        payload[attr] = getattr(persona, attr)
    # Optional attribute: tolerate persona objects that do not carry it.
    payload["famous_quotes"] = getattr(persona, "famous_quotes", None)
    return payload
@app.post("/persona/")
async def create_persona(request: PersonaRequest):
    """Generate a persona from ``request.info``.

    Returns:
        * 400 JSON ``{"error": ...}`` when ``info`` is blank after stripping.
        * ``stream`` false: the persona's ``model_dump()`` as JSON, or a 500
          JSON error when generation raises or returns ``None``.
        * ``stream`` true: a ``text/event-stream`` response carrying
          ``status``, ``stdout``, ``result`` or ``error`` events and a final
          ``done`` event.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    info = (request.info or "").strip()
    if not info:
        return JSONResponse(status_code=400, content={"error": "info ํ•„๋“œ๊ฐ€ ๋น„์–ด ์žˆ์Šต๋‹ˆ๋‹ค."})

    if not request.stream:
        try:
            # make_persona is called synchronously (it is not awaited), so
            # running it inline would hold the event loop for its whole
            # duration. Run it in the thread pool instead.
            persona = await run_in_threadpool(make_persona, info)
        except Exception as exc:
            return JSONResponse(status_code=500, content={"error": str(exc)})
        if persona is None:
            return JSONResponse(status_code=500, content={"error": "ํŽ˜๋ฅด์†Œ๋‚˜ ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."})
        return JSONResponse(content=persona.model_dump())

    # Seconds between two consecutive progress messages.
    status_interval_seconds = 8

    def event_stream():
        event_queue: Queue = Queue()
        # Set once the work is finished or the stream is closed, so the status
        # thread stops instead of sleeping through the remaining messages.
        stop_status = threading.Event()

        def status_sender():
            # The last entry of PERSONA_STATUS_MESSAGES is skipped here; the
            # original comment reserves it for completion time.
            for message in PERSONA_STATUS_MESSAGES[:-1]:
                if stop_status.is_set():
                    return
                event_queue.put({"type": "status", "message": message})
                # wait() returns True as soon as the event is set, ending the
                # pause early.
                if stop_status.wait(status_interval_seconds):
                    return

        def worker():
            thread_id = threading.get_ident()
            # Mirror everything this thread prints into the event queue.
            _stdout_proxy.register(thread_id, _QueueingStdoutTee(_stdout_proxy._target, event_queue))
            try:
                # Start the thread that emits periodic status messages.
                Thread(target=status_sender, daemon=True).start()
                persona = make_persona(info)
                if persona is None:
                    event_queue.put({"type": "error", "message": "ํŽ˜๋ฅด์†Œ๋‚˜ ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."})
                else:
                    event_queue.put(_build_persona_payload(persona))
            except Exception as exc:
                event_queue.put({"type": "error", "message": str(exc)})
            finally:
                stop_status.set()
                _stdout_proxy.unregister(thread_id)
                # Always terminate the stream, whatever happened above.
                event_queue.put({"type": "done"})

        yield _sse({"type": "status", "message": "ํŽ˜๋ฅด์†Œ๋‚˜ ์ƒ์„ฑ ์ค€๋น„ ์ค‘..."})
        Thread(target=worker, daemon=True).start()
        try:
            while True:
                try:
                    event = event_queue.get(timeout=0.2)
                except Empty:
                    continue
                yield _sse(jsonable_encoder(event))
                if event.get("type") == "done":
                    break
        finally:
            # Also reached when the generator is closed early.
            stop_status.set()

    headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)
class QueryRequest(BaseModel):
    """Request body for the POST /analyze/ endpoint."""

    # Query text; the endpoint strips it and answers 400 when nothing remains.
    query: str
    # Earlier chat messages. "ChatMessage" is written as a string forward
    # reference because that class is declared after this one.
    history: List["ChatMessage"] = Field(default_factory=list)
    # True (default): respond with a server-sent-event stream; False: respond with one JSON body.
    stream: bool = True
    # Optional persona name; the endpoint treats a blank value as None.
    persona_name: Optional[str] = None
class ChatMessage(BaseModel):
    """One chat history entry, used as the element type of QueryRequest.history."""

    # Speaker role; _normalize_history_input strips and lower-cases it before use.
    role: str
    # Message text; entries whose content is empty after stripping are dropped by _normalize_history_input.
    content: str
def _normalize_chat_role(role: str) -> str:
    """Return *role* stripped and lower-cased; falsy input yields ``""``."""
    return (role or "").strip().lower()
def _normalize_history_input(history_input):
    """Turn chat history entries into a list of ``{"role", "content"}`` dicts.

    Accepts ``ChatMessage`` instances and plain dicts; entries of any other
    type are skipped. Roles are normalized with ``_normalize_chat_role`` and
    contents are stripped. Entries whose role or content ends up empty are
    dropped. ``None`` or an empty input gives an empty list.
    """
    normalized = []
    for entry in history_input or []:
        if isinstance(entry, ChatMessage):
            raw_role, raw_content = entry.role, entry.content
        elif isinstance(entry, dict):
            raw_role, raw_content = entry.get("role", ""), entry.get("content", "")
        else:
            continue

        role = _normalize_chat_role(raw_role)
        content = (raw_content or "").strip()
        if role and content:
            normalized.append({"role": role, "content": content})
    return normalized
def _sse(payload: dict) -> str:
    """Format *payload* as one server-sent-event ``data:`` frame.

    The payload is JSON-encoded with non-ASCII characters kept as-is, and the
    frame ends with the blank line that terminates an event.
    """
    encoded = json.dumps(payload, ensure_ascii=False)
    return "data: " + encoded + "\n\n"
def _build_result_payload(result, stdout: str = "") -> dict:
    """Convert a pipeline result object into the ``"result"`` event dictionary.

    ``timestamp`` is optional on *result* and defaults to ``None``. A
    ``"stdout"`` key is added only when *stdout* is non-empty.
    """
    base = dict(
        type="result",
        query=result.query,
        ticker=result.ticker,
        analysis_type=result.analysis_type,
        data_context=result.data_context,
        llm_response=result.llm_response,
        timestamp=getattr(result, "timestamp", None),
    )
    if not stdout:
        return base
    return {**base, "stdout": stdout}
@app.post("/analyze/")
async def analyze(request: QueryRequest):
    """Run the analysis pipeline for ``request.query``.

    Returns:
        * 400 JSON ``{"error": ...}`` when ``query`` is blank after stripping.
        * ``stream`` false: one JSON result payload that also carries the text
          the pipeline printed to stdout, or a 500 JSON error when the
          pipeline raises.
        * ``stream`` true: a ``text/event-stream`` response carrying
          ``status``, ``delta``, ``stdout``, ``result`` or ``error`` events
          and a final ``done`` event.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    query = (request.query or "").strip()
    history = _normalize_history_input(request.history)
    persona_name = (request.persona_name or "").strip() or None
    if not query:
        return JSONResponse(status_code=400, content={"error": "query ํ•„๋“œ๊ฐ€ ๋น„์–ด ์žˆ์Šต๋‹ˆ๋‹ค."})

    if not request.stream:
        stdout_messages = []

        class _ListStdoutTee:
            """Writer that forwards to a stream and records each non-empty write."""

            def __init__(self, target):
                self._target = target

            def write(self, data):
                written = self._target.write(data)
                if data:
                    stdout_messages.append(data)
                return written

            def flush(self):
                self._target.flush()

        def run_blocking():
            # Runs in a thread-pool thread. The capture handler is registered
            # for that thread, so only this pipeline run's output is recorded.
            thread_id = threading.get_ident()
            _stdout_proxy.register(thread_id, _ListStdoutTee(_stdout_proxy._target))
            try:
                return run_pipeline(
                    query,
                    history=history,
                    persona_name=persona_name,
                    status_callback=None,
                    stream_callback=None,
                    stream=False,
                )
            finally:
                _stdout_proxy.unregister(thread_id)

        try:
            # run_pipeline is called synchronously (it is not awaited), so
            # running it inline would hold the event loop for its whole duration.
            result = await run_in_threadpool(run_blocking)
        except Exception as exc:
            # Same JSON error shape as the /persona/ endpoint's non-stream path.
            return JSONResponse(status_code=500, content={"error": str(exc)})
        return JSONResponse(
            content=jsonable_encoder(_build_result_payload(result, stdout="".join(stdout_messages)))
        )

    def event_stream():
        event_queue: Queue = Queue()

        def on_status(message: str):
            event_queue.put({"type": "status", "message": message})

        def on_delta(delta: str):
            event_queue.put({"type": "delta", "delta": delta})

        def worker():
            thread_id = threading.get_ident()
            # Mirror everything this thread prints into the event queue.
            _stdout_proxy.register(thread_id, _QueueingStdoutTee(_stdout_proxy._target, event_queue))
            try:
                # This branch is only reached when streaming was requested,
                # so the delta callback is always passed.
                result = run_pipeline(
                    query,
                    history=history,
                    persona_name=persona_name,
                    status_callback=on_status,
                    stream_callback=on_delta,
                    stream=True,
                )
                event_queue.put(_build_result_payload(result))
            except Exception as exc:
                event_queue.put({"type": "error", "message": str(exc)})
            finally:
                _stdout_proxy.unregister(thread_id)
                # Always terminate the stream, whatever happened above.
                event_queue.put({"type": "done"})

        yield _sse({"type": "status", "message": "์š”์ฒญ ์ˆ˜์‹ . ๋ถ„์„ ์ค€๋น„ ์ค‘..."})
        Thread(target=worker, daemon=True).start()
        while True:
            try:
                event = event_queue.get(timeout=0.2)
            except Empty:
                continue
            yield _sse(jsonable_encoder(event))
            if event.get("type") == "done":
                break

    headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)