Auto commit at 23-2025-08 10:05:36
- lily_llm_api/api/__init__.py +3 -0
- lily_llm_api/api/routers/__init__.py +3 -0
- lily_llm_api/api/routers/advanced_context_router.py +302 -0
- lily_llm_api/api/routers/context_router.py +273 -0
- lily_llm_api/api/routers/document_router.py +434 -0
- lily_llm_api/api/routers/generation_router.py +128 -0
- lily_llm_api/api/routers/lora_router.py +223 -0
- lily_llm_api/api/routers/model_router.py +56 -0
- lily_llm_api/api/routers/multimodal_rag_router.py +567 -0
- lily_llm_api/api/routers/ocr_router.py +404 -0
- lily_llm_api/api/routers/user_memory_router.py +341 -0
- lily_llm_api/app.py +0 -0
- lily_llm_api/app_v2.py +0 -0
- lily_llm_api/app_v2_modular.py +34 -0
- lily_llm_api/core/__init__.py +3 -0
- lily_llm_api/core/app_factory.py +125 -0
- lily_llm_api/models/back/configuration.py +0 -125
- lily_llm_api/models/back/modeling.py +0 -973
- lily_llm_api/models/schemas.py +184 -0
- lily_llm_api/services/__init__.py +3 -0
- lily_llm_api/services/generation_service.py +583 -0
- lily_llm_api/services/model_service.py +91 -0
- lily_llm_api/utils/__init__.py +3 -0
- lily_llm_api/utils/lora_utils.py +124 -0
- lily_llm_api/utils/system_utils.py +65 -0
- lily_llm_core/document_processor.py +16 -16
- run_server.py +1 -1
- run_server_v2.py +1 -1
lily_llm_api/api/__init__.py
ADDED
@@ -0,0 +1,3 @@
+"""
+API package for Lily LLM API
+"""
lily_llm_api/api/routers/__init__.py
ADDED
@@ -0,0 +1,3 @@
+"""
+API Routers package for Lily LLM API
+"""
lily_llm_api/api/routers/advanced_context_router.py
ADDED
@@ -0,0 +1,302 @@
+"""
+Advanced context management router for Lily LLM API
+"""
+from fastapi import APIRouter, HTTPException, Form
+from typing import Optional
+import logging
+import json
+
+logger = logging.getLogger(__name__)
+router = APIRouter()
+
+@router.post("/context/set-system-prompt")
+async def set_system_prompt(prompt: str = Form(...)):
+    """Set the system prompt"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            context_manager.set_system_prompt(prompt)
+            return {
+                "success": True,
+                "message": "System prompt has been set.",
+                "prompt_length": len(prompt)
+            }
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Failed to set system prompt: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.post("/context/add-message")
+async def add_context_message(
+    role: str = Form(...),  # 'user' or 'assistant'
+    content: str = Form(...),
+    message_id: str = Form(None),
+    metadata: str = Form("{}")  # JSON string
+):
+    """Add a message to the context"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            metadata_dict = json.loads(metadata) if metadata else {}
+
+            if role == "user":
+                msg_id = context_manager.add_user_message(content, message_id, metadata_dict)
+            elif role == "assistant":
+                msg_id = context_manager.add_assistant_message(content, message_id, metadata_dict)
+            else:
+                return {"success": False, "error": "Invalid role. Use 'user' or 'assistant'."}
+
+            return {
+                "success": True,
+                "message": "Message added to the context.",
+                "message_id": msg_id,
+                "context_summary": context_manager.get_context_summary()
+            }
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Failed to add context message: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.get("/context/get")
+async def get_context(
+    include_system: bool = True,
+    max_length: Optional[int] = None,
+    recent_turns: Optional[int] = None
+):
+    """Get the current context"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            if recent_turns:
+                context = context_manager.get_recent_context(recent_turns)
+            else:
+                context = context_manager.get_context(include_system, max_length)
+
+            return {
+                "success": True,
+                "context": context,
+                "context_summary": context_manager.get_context_summary(),
+                "memory_efficiency": context_manager.get_memory_efficiency()
+            }
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Failed to get context: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.get("/context/summary")
+async def get_context_summary():
+    """Get context summary information"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            return {
+                "success": True,
+                "summary": context_manager.get_context_summary(),
+                "memory_efficiency": context_manager.get_memory_efficiency()
+            }
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Failed to get context summary: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.post("/context/clear")
+async def clear_context():
+    """Clear the context"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            context_manager.clear_context()
+            return {
+                "success": True,
+                "message": "Context has been cleared."
+            }
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Failed to clear context: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.delete("/context/message/{message_id}")
+async def remove_context_message(message_id: str):
+    """Remove a specific message from the context"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            success = context_manager.remove_message(message_id)
+            if success:
+                return {
+                    "success": True,
+                    "message": "Message removed.",
+                    "context_summary": context_manager.get_context_summary()
+                }
+            else:
+                return {"success": False, "error": "Message not found."}
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Failed to remove message: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.put("/context/message/{message_id}")
+async def edit_context_message(
+    message_id: str,
+    new_content: str = Form(...)
+):
+    """Edit a context message"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            success = context_manager.edit_message(message_id, new_content)
+            if success:
+                return {
+                    "success": True,
+                    "message": "Message updated.",
+                    "context_summary": context_manager.get_context_summary()
+                }
+            else:
+                return {"success": False, "error": "Message not found."}
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Failed to edit message: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.get("/context/search")
+async def search_context(query: str, max_results: int = 5):
+    """Search within the context"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            results = context_manager.search_context(query, max_results)
+            return {
+                "success": True,
+                "query": query,
+                "results": results,
+                "total_results": len(results)
+            }
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Context search failed: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.post("/context/export")
+async def export_context(file_path: str = Form(None)):
+    """Export the context"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            exported_path = context_manager.export_context(file_path)
+            return {
+                "success": True,
+                "message": "Context exported.",
+                "file_path": exported_path
+            }
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Context export failed: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.post("/context/import")
+async def import_context(file_path: str = Form(...)):
+    """Import a context"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            success = context_manager.import_context(file_path)
+            if success:
+                return {
+                    "success": True,
+                    "message": "Context imported.",
+                    "context_summary": context_manager.get_context_summary()
+                }
+            else:
+                return {"success": False, "error": "Failed to import context."}
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Context import failed: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.post("/context/compress")
+async def compress_context(compression_ratio: float = Form(0.5)):
+    """Compress the context"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            success = context_manager.compress_context(compression_ratio)
+            if success:
+                return {
+                    "success": True,
+                    "message": "Context compressed.",
+                    "compression_ratio": compression_ratio,
+                    "context_summary": context_manager.get_context_summary()
+                }
+            else:
+                return {"success": False, "error": "Failed to compress context."}
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Context compression failed: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.post("/context/optimize")
+async def optimize_context(optimization_strategy: str = Form("memory")):
+    """Optimize the context"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            success = context_manager.optimize_context(optimization_strategy)
+            if success:
+                return {
+                    "success": True,
+                    "message": "Context optimized.",
+                    "strategy": optimization_strategy,
+                    "context_summary": context_manager.get_context_summary()
+                }
+            else:
+                return {"success": False, "error": "Failed to optimize context."}
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Context optimization failed: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.get("/context/health")
+async def get_context_health():
+    """Check context system health"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            health_info = context_manager.get_health_info()
+            return {
+                "success": True,
+                "health": health_info
+            }
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Context health check failed: {e}")
+        return {"success": False, "error": str(e)}
+
+@router.get("/context/analytics")
+async def get_context_analytics():
+    """Get context analytics"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            analytics = context_manager.get_analytics()
+            return {
+                "success": True,
+                "analytics": analytics
+            }
+        except ImportError:
+            return {"success": False, "error": "Context manager not available"}
+    except Exception as e:
+        logger.error(f"❌ Context analytics lookup failed: {e}")
+        return {"success": False, "error": str(e)}
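
Every endpoint in this router follows the same shape: form-encoded input, a JSON body with a `success` flag, and a soft failure (`Context manager not available`) when `lily_llm_core` cannot be imported. A minimal client sketch, assuming the app is served at http://localhost:8000 with this router mounted at the root path (both assumptions; adjust to your deployment):

# Minimal client sketch for the advanced context endpoints.
# Base URL and mount path are assumptions, not part of this commit.
import json
import requests

BASE = "http://localhost:8000"

# Set a system prompt (form-encoded, matching the Form(...) parameters)
r = requests.post(f"{BASE}/context/set-system-prompt",
                  data={"prompt": "You are a helpful assistant."})
print(r.json())  # {"success": True, "message": ..., "prompt_length": ...}

# Add a user message; metadata is a JSON string, mirroring the Form("{}") default
r = requests.post(f"{BASE}/context/add-message",
                  data={"role": "user",
                        "content": "Hello!",
                        "metadata": json.dumps({"source": "example"})})
print(r.json())

# Read the two most recent turns back
body = requests.get(f"{BASE}/context/get", params={"recent_turns": 2}).json()
print(body["context"] if body.get("success") else body.get("error"))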
lily_llm_api/api/routers/context_router.py
ADDED
@@ -0,0 +1,273 @@
+"""
+Context management router for Lily LLM API
+"""
+from fastapi import APIRouter, HTTPException, Form
+from typing import Optional
+import logging
+
+from ...models.schemas import (
+    ContextStatusResponse, ContextHistoryResponse,
+    AutoCleanupConfigResponse, AutoCleanupConfigRequest
+)
+
+logger = logging.getLogger(__name__)
+router = APIRouter()
+
+@router.get("/context/status", response_model=ContextStatusResponse)
+async def get_context_status():
+    """Check the context manager status"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            if not context_manager:
+                return ContextStatusResponse(
+                    status="error",
+                    context_manager_available=False,
+                    total_sessions=0,
+                    sessions={},
+                    max_tokens=0,
+                    max_turns=0,
+                    strategy="unknown",
+                    message="Context manager not available"
+                )
+
+            # Collect per-session information
+            session_info = {}
+            for session_id, conversation in context_manager.session_conversations.items():
+                session_info[session_id] = {
+                    "turns": len(conversation),
+                    "user_messages": len([t for t in conversation if t.role == "user"]),
+                    "assistant_messages": len([t for t in conversation if t.role == "assistant"])
+                }
+
+            return ContextStatusResponse(
+                status="success",
+                context_manager_available=True,
+                total_sessions=len(context_manager.session_conversations),
+                sessions=session_info,
+                max_tokens=getattr(context_manager, 'max_tokens', 0),
+                max_turns=getattr(context_manager, 'max_turns', 0),
+                strategy=getattr(context_manager, 'strategy', 'unknown')
+            )
+        except ImportError:
+            return ContextStatusResponse(
+                status="error",
+                context_manager_available=False,
+                total_sessions=0,
+                sessions={},
+                max_tokens=0,
+                max_turns=0,
+                strategy="unknown",
+                message="Context manager import failed"
+            )
+    except Exception as e:
+        logger.error(f"Context status check failed: {e}")
+        return ContextStatusResponse(
+            status="error",
+            context_manager_available=False,
+            total_sessions=0,
+            sessions={},
+            max_tokens=0,
+            max_turns=0,
+            strategy="unknown",
+            message=str(e)
+        )
+
+@router.get("/context/history", response_model=ContextHistoryResponse)
+async def get_context_history(session_id: str = None):
+    """Get the context history"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            if not context_manager:
+                return ContextHistoryResponse(
+                    status="error",
+                    context="",
+                    history_length=0,
+                    message="Context manager not available"
+                )
+
+            if session_id:
+                # Get the context of a specific session only
+                context = context_manager.get_context(include_system=True, max_length=4000, session_id=session_id)
+                session_summary = context_manager.get_context_summary(session_id)
+                return ContextHistoryResponse(
+                    status="success",
+                    session_id=session_id,
+                    context=context,
+                    history_length=session_summary.get("total_turns", 0),
+                    session_summary=session_summary
+                )
+            else:
+                # Get the full context
+                context = context_manager.get_context(include_system=True, max_length=4000)
+                return ContextHistoryResponse(
+                    status="success",
+                    context=context,
+                    history_length=len(context_manager.conversation_history),
+                    all_sessions=True
+                )
+        except ImportError:
+            return ContextHistoryResponse(
+                status="error",
+                context="",
+                history_length=0,
+                message="Context manager import failed"
+            )
+    except Exception as e:
+        logger.error(f"Context history lookup failed: {e}")
+        return ContextHistoryResponse(
+            status="error",
+            context="",
+            history_length=0,
+            message=str(e)
+        )
+
+@router.get("/context/auto-cleanup", response_model=AutoCleanupConfigResponse)
+async def get_auto_cleanup_config():
+    """Get the auto-cleanup configuration"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            if not context_manager:
+                return AutoCleanupConfigResponse(
+                    status="error",
+                    auto_cleanup_config={},
+                    message="Context manager not available"
+                )
+
+            config = context_manager.get_auto_cleanup_config()
+            return AutoCleanupConfigResponse(
+                status="success",
+                auto_cleanup_config=config
+            )
+        except ImportError:
+            return AutoCleanupConfigResponse(
+                status="error",
+                auto_cleanup_config={},
+                message="Context manager import failed"
+            )
+    except Exception as e:
+        logger.error(f"Auto-cleanup config lookup failed: {e}")
+        return AutoCleanupConfigResponse(
+            status="error",
+            auto_cleanup_config={},
+            message=str(e)
+        )
+
+@router.post("/context/auto-cleanup")
+async def set_auto_cleanup_config(
+    enabled: bool = Form(True),
+    interval_turns: int = Form(8),
+    interval_time: int = Form(300),
+    strategy: str = Form("smart")
+):
+    """Update the auto-cleanup configuration"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            if not context_manager:
+                raise HTTPException(status_code=500, detail="Context manager not available")
+
+            success = context_manager.set_auto_cleanup_config(
+                enabled=enabled,
+                interval_turns=interval_turns,
+                interval_time=interval_time,
+                strategy=strategy
+            )
+
+            if success:
+                return {"status": "success", "message": "Auto-cleanup configuration updated"}
+            else:
+                raise HTTPException(status_code=500, detail="Failed to update auto-cleanup configuration")
+        except ImportError:
+            raise HTTPException(status_code=500, detail="Context manager import failed")
+    except Exception as e:
+        logger.error(f"Auto-cleanup config update failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Auto-cleanup config update failed: {str(e)}")
+
+@router.post("/context/cleanup")
+async def cleanup_context(session_id: str = Form(None)):
+    """Clean up the context"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            if not context_manager:
+                raise HTTPException(status_code=500, detail="Context manager not available")
+
+            if session_id:
+                # Clean up a specific session
+                success = context_manager.cleanup_session(session_id)
+                if success:
+                    return {"status": "success", "message": f"Session {session_id} cleaned up"}
+                else:
+                    raise HTTPException(status_code=500, detail=f"Failed to clean up session {session_id}")
+            else:
+                # Clean up the full context
+                success = context_manager.cleanup_context()
+                if success:
+                    return {"status": "success", "message": "Full context cleaned up"}
+                else:
+                    raise HTTPException(status_code=500, detail="Failed to clean up full context")
+        except ImportError:
+            raise HTTPException(status_code=500, detail="Context manager import failed")
+    except Exception as e:
+        logger.error(f"Context cleanup failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Context cleanup failed: {str(e)}")
+
+@router.post("/context/summary")
+async def generate_context_summary(session_id: str = Form(...)):
+    """Generate a context summary"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            if not context_manager:
+                raise HTTPException(status_code=500, detail="Context manager not available")
+
+            summary = context_manager.generate_summary(session_id)
+            if summary:
+                return {"status": "success", "summary": summary}
+            else:
+                raise HTTPException(status_code=500, detail="Failed to generate context summary")
+        except ImportError:
+            raise HTTPException(status_code=500, detail="Context manager import failed")
+    except Exception as e:
+        logger.error(f"Context summary generation failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Context summary generation failed: {str(e)}")
+
+@router.delete("/context/session/{session_id}")
+async def delete_session(session_id: str):
+    """Delete a specific session"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            if not context_manager:
+                raise HTTPException(status_code=500, detail="Context manager not available")
+
+            success = context_manager.delete_session(session_id)
+            if success:
+                return {"status": "success", "message": f"Session {session_id} deleted"}
+            else:
+                raise HTTPException(status_code=500, detail=f"Failed to delete session {session_id}")
+        except ImportError:
+            raise HTTPException(status_code=500, detail="Context manager import failed")
+    except Exception as e:
+        logger.error(f"Session deletion failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Session deletion failed: {str(e)}")
+
+@router.get("/context/sessions")
+async def list_sessions():
+    """List available sessions"""
+    try:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            if not context_manager:
+                raise HTTPException(status_code=500, detail="Context manager not available")
+
+            sessions = list(context_manager.session_conversations.keys())
+            return {"status": "success", "sessions": sessions}
+        except ImportError:
+            raise HTTPException(status_code=500, detail="Context manager import failed")
+    except Exception as e:
+        logger.error(f"Session list lookup failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Session list lookup failed: {str(e)}")
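
Unlike the advanced router above, this one returns typed response models (`ContextStatusResponse`, `ContextHistoryResponse`, `AutoCleanupConfigResponse`) for reads and signals failed mutations with HTTP 500 instead of a `success: False` body. A sketch of a session-maintenance pass, again assuming http://localhost:8000 and a hypothetical session ID:

# Sketch: drive the session-level context endpoints.
# BASE and the session ID below are assumptions for illustration.
import requests

BASE = "http://localhost:8000"

# Typed status read (ContextStatusResponse schema)
status = requests.get(f"{BASE}/context/status").json()
print(status["total_sessions"], status["strategy"])

# Tighten auto-cleanup: every 4 turns or 120 seconds, "smart" strategy
r = requests.post(f"{BASE}/context/auto-cleanup",
                  data={"enabled": True, "interval_turns": 4,
                        "interval_time": 120, "strategy": "smart"})
r.raise_for_status()  # mutation endpoints raise HTTP 500 on failure

# Clean up one (hypothetical) session, then list what remains
requests.post(f"{BASE}/context/cleanup",
              data={"session_id": "room_default_user_demo_0"})
print(requests.get(f"{BASE}/context/sessions").json()["sessions"])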
lily_llm_api/api/routers/document_router.py
ADDED
@@ -0,0 +1,434 @@
+"""
+Document processing and RAG router for Lily LLM API
+"""
+from fastapi import APIRouter, HTTPException, UploadFile, File, Form
+from typing import Optional, List
+import logging
+import time
+
+from ...models.schemas import (
+    DocumentUploadResponse, RAGQueryRequest, RAGQueryResponse,
+    DocumentProcessResponse, MultimodalRAGResponse
+)
+
+logger = logging.getLogger(__name__)
+router = APIRouter()
+
+@router.post("/document/upload", response_model=DocumentUploadResponse)
+async def upload_document(
+    file: UploadFile = File(...),
+    user_id: str = Form("anonymous"),
+    room_id: str = Form("default")
+):
+    """Upload and process a document"""
+    try:
+        start_time = time.time()
+
+        # Read the file
+        content = await file.read()
+        filename = file.filename
+
+        # Use the document processor
+        try:
+            from lily_llm_core.document_processor import document_processor
+
+            # Process the document
+            result = document_processor.process_document(
+                content=content,
+                filename=filename,
+                user_id=user_id,
+                room_id=room_id
+            )
+
+            if result.get("success"):
+                processing_time = time.time() - start_time
+                return DocumentUploadResponse(
+                    success=True,
+                    document_id=result.get("document_id", "unknown"),
+                    message="Document uploaded and processed",
+                    chunks=result.get("chunks", 0),
+                    latex_count=result.get("latex_count", 0),
+                    auto_response=result.get("auto_response")
+                )
+            else:
+                return DocumentUploadResponse(
+                    success=False,
+                    document_id="",
+                    message="Document processing failed",
+                    error=result.get("error", "Unknown error")
+                )
+
+        except ImportError:
+            return DocumentUploadResponse(
+                success=False,
+                document_id="",
+                message="Document processor import failed",
+                error="Document processor not available"
+            )
+
+    except Exception as e:
+        logger.error(f"Document upload failed: {e}")
+        return DocumentUploadResponse(
+            success=False,
+            document_id="",
+            message="Error during document upload",
+            error=str(e)
+        )
+
+@router.post("/rag/query", response_model=RAGQueryResponse)
+async def rag_query(
+    query: str = Form(...),
+    user_id: str = Form("anonymous"),
+    room_id: str = Form("default"),
+    max_results: int = Form(5),
+    include_sources: bool = Form(True)
+):
+    """Handle a RAG query"""
+    try:
+        start_time = time.time()
+
+        try:
+            from lily_llm_core.rag_processor import rag_processor
+
+            # Run the RAG query
+            result = rag_processor.query(
+                query=query,
+                user_id=user_id,
+                room_id=room_id,
+                max_results=max_results,
+                include_sources=include_sources
+            )
+
+            if result.get("success"):
+                processing_time = time.time() - start_time
+                return RAGQueryResponse(
+                    success=True,
+                    response=result.get("response", ""),
+                    sources=result.get("sources", []),
+                    search_results=len(result.get("sources", [])),
+                    processing_time=processing_time
+                )
+            else:
+                return RAGQueryResponse(
+                    success=False,
+                    response="",
+                    sources=[],
+                    search_results=0,
+                    processing_time=0,
+                    error=result.get("error", "RAG query failed")
+                )
+
+        except ImportError:
+            return RAGQueryResponse(
+                success=False,
+                response="",
+                sources=[],
+                search_results=0,
+                processing_time=0,
+                error="RAG processor not available"
+            )
+
+    except Exception as e:
+        logger.error(f"RAG query failed: {e}")
+        return RAGQueryResponse(
+            success=False,
+            response="",
+            sources=[],
+            search_results=0,
+            processing_time=0,
+            error=str(e)
+        )
+
+@router.post("/rag/generate", response_model=RAGQueryResponse)
+async def rag_generate(
+    prompt: str = Form(...),
+    user_id: str = Form("anonymous"),
+    room_id: str = Form("default"),
+    max_results: int = Form(5)
+):
+    """RAG-based text generation"""
+    try:
+        start_time = time.time()
+
+        try:
+            from lily_llm_core.rag_processor import rag_processor
+
+            # Run RAG generation
+            result = rag_processor.generate_with_context(
+                prompt=prompt,
+                user_id=user_id,
+                room_id=room_id,
+                max_results=max_results
+            )
+
+            if result.get("success"):
+                processing_time = time.time() - start_time
+                return RAGQueryResponse(
+                    success=True,
+                    response=result.get("response", ""),
+                    sources=result.get("sources", []),
+                    search_results=len(result.get("sources", [])),
+                    processing_time=processing_time
+                )
+            else:
+                return RAGQueryResponse(
+                    success=False,
+                    response="",
+                    sources=[],
+                    search_results=0,
+                    processing_time=0,
+                    error=result.get("error", "RAG generation failed")
+                )
+
+        except ImportError:
+            return RAGQueryResponse(
+                success=False,
+                response="",
+                sources=[],
+                search_results=0,
+                processing_time=0,
+                error="RAG processor not available"
+            )
+
+    except Exception as e:
+        logger.error(f"RAG generation failed: {e}")
+        return RAGQueryResponse(
+            success=False,
+            response="",
+            sources=[],
+            search_results=0,
+            processing_time=0,
+            error=str(e)
+        )
+
+@router.post("/rag/summary")
+async def generate_rag_summary(
+    user_id: str = Form("anonymous"),
+    room_id: str = Form("default")
+):
+    """Generate a RAG document summary"""
+    try:
+        try:
+            from lily_llm_core.rag_processor import rag_processor
+
+            # Generate the RAG summary
+            result = rag_processor.generate_summary(
+                user_id=user_id,
+                room_id=room_id
+            )
+
+            if result.get("success"):
+                return {"status": "success", "summary": result.get("summary", "")}
+            else:
+                raise HTTPException(status_code=500, detail=result.get("error", "RAG summary generation failed"))
+
+        except ImportError:
+            raise HTTPException(status_code=500, detail="RAG processor not available")
+
+    except Exception as e:
+        logger.error(f"RAG summary generation failed: {e}")
+        raise HTTPException(status_code=500, detail=f"RAG summary generation failed: {str(e)}")
+
+@router.post("/rag/clear")
+async def clear_rag_context(
+    user_id: str = Form("anonymous"),
+    room_id: str = Form("default")
+):
+    """Clear the RAG context"""
+    try:
+        try:
+            from lily_llm_core.rag_processor import rag_processor
+
+            # Clear the RAG context
+            success = rag_processor.clear_context(
+                user_id=user_id,
+                room_id=room_id
+            )
+
+            if success:
+                return {"status": "success", "message": "RAG context cleared"}
+            else:
+                raise HTTPException(status_code=500, detail="Failed to clear RAG context")
+
+        except ImportError:
+            raise HTTPException(status_code=500, detail="RAG processor not available")
+
+    except Exception as e:
+        logger.error(f"RAG context clear failed: {e}")
+        raise HTTPException(status_code=500, detail=f"RAG context clear failed: {str(e)}")
+
+@router.post("/rag/batch-process")
+async def batch_process_documents(
+    files: List[UploadFile] = File(...),
+    user_id: str = Form("anonymous"),
+    room_id: str = Form("default")
+):
+    """Process multiple documents in one batch"""
+    try:
+        start_time = time.time()
+        results = []
+
+        try:
+            from lily_llm_core.document_processor import document_processor
+
+            for file in files:
+                content = await file.read()
+                filename = file.filename
+
+                result = document_processor.process_document(
+                    content=content,
+                    filename=filename,
+                    user_id=user_id,
+                    room_id=room_id
+                )
+
+                results.append({
+                    "filename": filename,
+                    "success": result.get("success", False),
+                    "document_id": result.get("document_id", ""),
+                    "chunks": result.get("chunks", 0),
+                    "error": result.get("error")
+                })
+
+            processing_time = time.time() - start_time
+            return {
+                "status": "success",
+                "results": results,
+                "total_files": len(files),
+                "processing_time": processing_time
+            }
+
+        except ImportError:
+            raise HTTPException(status_code=500, detail="Document processor not available")
+
+    except Exception as e:
+        logger.error(f"Batch document processing failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Batch document processing failed: {str(e)}")
+
+@router.get("/rag/search-history")
+async def search_rag_history(
+    user_id: str = "anonymous",
+    room_id: str = "default",
+    query: str = "",
+    limit: int = 10
+):
+    """Get the RAG search history"""
+    try:
+        try:
+            from lily_llm_core.rag_processor import rag_processor
+
+            # Look up the RAG search history
+            history = rag_processor.get_search_history(
+                user_id=user_id,
+                room_id=room_id,
+                query=query,
+                limit=limit
+            )
+
+            return {"status": "success", "history": history}
+
+        except ImportError:
+            raise HTTPException(status_code=500, detail="RAG processor not available")
+
+    except Exception as e:
+        logger.error(f"RAG search history lookup failed: {e}")
+        raise HTTPException(status_code=500, detail=f"RAG search history lookup failed: {str(e)}")
+
+@router.post("/multimodal-rag/upload")
+async def upload_multimodal_document(
+    file: UploadFile = File(...),
+    user_id: str = Form("anonymous"),
+    room_id: str = Form("default")
+):
+    """Upload a multimodal document"""
+    try:
+        start_time = time.time()
+
+        # Read the file
+        content = await file.read()
+        filename = file.filename
+
+        try:
+            from lily_llm_core.hybrid_rag_processor import hybrid_rag_processor
+
+            # Process the multimodal document
+            result = hybrid_rag_processor.process_document(
+                content=content,
+                filename=filename,
+                user_id=user_id,
+                room_id=room_id
+            )
+
+            if result.get("success"):
+                processing_time = time.time() - start_time
+                return {
+                    "status": "success",
+                    "document_id": result.get("document_id", ""),
+                    "processing_time": processing_time,
+                    "message": "Multimodal document uploaded"
+                }
+            else:
+                raise HTTPException(status_code=500, detail=result.get("error", "Multimodal document processing failed"))
+
+        except ImportError:
+            raise HTTPException(status_code=500, detail="Hybrid RAG processor not available")
+
+    except Exception as e:
+        logger.error(f"Multimodal document upload failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Multimodal document upload failed: {str(e)}")
+
+@router.post("/multimodal-rag/generate", response_model=MultimodalRAGResponse)
+async def generate_multimodal_rag(
+    prompt: str = Form(...),
+    user_id: str = Form("anonymous"),
+    room_id: str = Form("default")
+):
+    """Multimodal RAG-based text generation"""
+    try:
+        start_time = time.time()
+
+        try:
+            from lily_llm_core.hybrid_rag_processor import hybrid_rag_processor
+
+            # Multimodal RAG generation
+            result = hybrid_rag_processor.generate(
+                prompt=prompt,
+                user_id=user_id,
+                room_id=room_id
+            )
+
+            if result.get("success"):
+                processing_time = time.time() - start_time
+                return MultimodalRAGResponse(
+                    success=True,
+                    response=result.get("response", ""),
+                    image_processed=result.get("image_processed", False),
+                    processing_time=processing_time
+                )
+            else:
+                return MultimodalRAGResponse(
+                    success=False,
+                    response="",
+                    image_processed=False,
+                    processing_time=0,
+                    error=result.get("error", "Multimodal RAG generation failed")
+                )
+
+        except ImportError:
+            return MultimodalRAGResponse(
+                success=False,
+                response="",
+                image_processed=False,
+                processing_time=0,
+                error="Hybrid RAG processor not available"
+            )
+
+    except Exception as e:
+        logger.error(f"Multimodal RAG generation failed: {e}")
+        return MultimodalRAGResponse(
+            success=False,
+            response="",
+            image_processed=False,
+            processing_time=0,
+            error=str(e)
+        )
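
The upload/query pair is the core RAG loop: `/document/upload` chunks a file through `document_processor`, keyed by `user_id` and `room_id`, and `/rag/query` retrieves against those chunks. A sketch, assuming http://localhost:8000 and a local `sample.pdf` (both hypothetical):

# Sketch: upload a document, then run a RAG query against it.
import requests

BASE = "http://localhost:8000"

with open("sample.pdf", "rb") as f:  # hypothetical local file
    up = requests.post(f"{BASE}/document/upload",
                       files={"file": ("sample.pdf", f)},
                       data={"user_id": "demo", "room_id": "room1"})
print(up.json())  # DocumentUploadResponse: success, document_id, chunks, ...

q = requests.post(f"{BASE}/rag/query",
                  data={"query": "What does the document conclude?",
                        "user_id": "demo", "room_id": "room1",
                        "max_results": 3, "include_sources": True})
body = q.json()
if body["success"]:
    print(body["response"])
    for src in body["sources"]:
        print("-", src)
else:
    print("error:", body.get("error"))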
lily_llm_api/api/routers/generation_router.py
ADDED
@@ -0,0 +1,128 @@
+"""
+Generation router for Lily LLM API
+"""
+from fastapi import APIRouter, HTTPException, Request, UploadFile, File, Form, Depends
+from typing import Optional, List
+import logging
+import time
+
+from ...models.schemas import GenerateResponse, MultimodalGenerateResponse
+from ...services.generation_service import generate_sync
+from ...services.model_service import is_model_loaded
+from ...utils.system_utils import select_model_interactive
+
+logger = logging.getLogger(__name__)
+router = APIRouter()
+
+@router.post("/generate", response_model=GenerateResponse)
+async def generate(request: Request,
+                   prompt: str = Form(...),
+                   image1: UploadFile = File(None),
+                   image2: UploadFile = File(None),
+                   image3: UploadFile = File(None),
+                   image4: UploadFile = File(None),
+                   user_id: str = Form("anonymous"),
+                   room_id: str = Form("default"),
+                   use_context: bool = Form(True),
+                   session_id: str = Form(None)):
+
+    if not is_model_loaded():
+        raise HTTPException(status_code=503, detail="Model is not loaded.")
+
+    start_time = time.time()
+
+    # Auto-generate a session ID when none is given (unique per chat room)
+    if not session_id:
+        # Build a unique session from chat room + user + timestamp
+        timestamp = int(time.time())
+        session_id = f"room_{room_id}_user_{user_id}_{timestamp}"
+        print(f"🔍 [DEBUG] Auto-generated session ID: {session_id} (room: {room_id}, user: {user_id})")
+
+    if use_context:
+        try:
+            from lily_llm_core.context_manager import context_manager
+            context_manager.add_user_message(prompt, metadata={"session_id": session_id})
+            print(f"🔍 [DEBUG] User message added (session: {session_id})")
+        except Exception as e:
+            logger.warning(f"⚠️ Context manager unavailable: {e}")
+
+    # Handle image data
+    image_data_list = []
+    for img_file in [image1, image2, image3, image4]:
+        if img_file:
+            try:
+                data = await img_file.read()
+                image_data_list.append(data)
+            except Exception as e:
+                logger.warning(f"Image load failed: {e}")
+
+    try:
+        # Call generate_sync (context included)
+        result = generate_sync(prompt, image_data_list, use_context=use_context, session_id=session_id, user_id=user_id, room_id=room_id)
+
+        if "error" in result:
+            raise HTTPException(status_code=500, detail=result["error"])
+
+        if use_context:
+            try:
+                from lily_llm_core.context_manager import context_manager
+                context_manager.add_assistant_message(result["generated_text"], metadata={"session_id": session_id})
+            except Exception as e:
+                logger.warning(f"⚠️ Context manager unavailable: {e}")
+
+        return GenerateResponse(
+            generated_text=result["generated_text"],
+            processing_time=result["processing_time"],
+            model_name=result["model_name"],
+            image_processed=result["image_processed"]
+        )
+
+    except Exception as e:
+        logger.error(f"❌ Error during generation: {e}", exc_info=True)
+        raise HTTPException(status_code=500, detail=f"Error during model generation: {str(e)}")
+
+
+@router.post("/generate-multimodal", response_model=MultimodalGenerateResponse)
+async def generate_multimodal(prompt: str = Form(...),
+                              image: UploadFile = File(None),
+                              model_id: Optional[str] = Form(None),
+                              max_length: Optional[int] = Form(None),
+                              temperature: Optional[float] = Form(None),
+                              top_p: Optional[float] = Form(None),
+                              do_sample: Optional[bool] = Form(None)):
+
+    if not is_model_loaded():
+        raise HTTPException(status_code=500, detail="Model is not loaded")
+
+    start_time = time.time()
+
+    # Handle image data
+    image_data_list = []
+    if image:
+        try:
+            data = await image.read()
+            image_data_list.append(data)
+        except Exception as e:
+            logger.error(f"Image processing failed: {e}")
+
+    try:
+        # Call generate_sync
+        result = generate_sync(prompt, image_data_list, max_length=max_length,
+                               temperature=temperature, top_p=top_p, do_sample=do_sample)
+
+        if "error" in result:
+            raise HTTPException(status_code=500, detail=result["error"])
+
+        from ...services.model_service import get_current_profile
+        current_profile = get_current_profile()
+
+        return MultimodalGenerateResponse(
+            generated_text=result["generated_text"],
+            processing_time=result["processing_time"],
+            model_name=current_profile.display_name,
+            model_id=model_id or current_profile.get_model_info().get("model_name"),
+            image_processed=bool(image_data_list)
+        )
+    except Exception as e:
+        logger.error(f"❌ Multimodal generation error: {e}")
+        raise HTTPException(status_code=500, detail=f"Multimodal generation failed: {str(e)}")
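
Note that `/generate` derives a fresh `room_{room_id}_user_{user_id}_{timestamp}` session when `session_id` is omitted, and returns 503 until a model is loaded. A sketch of a text-plus-optional-image call, assuming http://localhost:8000 and a local `cat.jpg` (both hypothetical; all four image slots are optional):

# Sketch: call /generate with a prompt and, if present, one image.
from pathlib import Path
import requests

BASE = "http://localhost:8000"

data = {"prompt": "Describe this image in one sentence.",
        "user_id": "demo", "room_id": "room1", "use_context": True}

img = Path("cat.jpg")  # hypothetical; text-only requests also work
files = {"image1": ("cat.jpg", img.read_bytes())} if img.exists() else None

r = requests.post(f"{BASE}/generate", data=data, files=files)
r.raise_for_status()  # 503 if the model is not loaded
out = r.json()
print(out["generated_text"], f"({out['processing_time']:.2f}s)")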
lily_llm_api/api/routers/lora_router.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
LoRA router for Lily LLM API
"""
from fastapi import APIRouter, HTTPException, Form, UploadFile, File
from typing import Optional
import logging

from ...models.schemas import LoRAStatusResponse
from ...utils.lora_utils import setup_lora_for_model

logger = logging.getLogger(__name__)
router = APIRouter()

@router.get("/lora/status", response_model=LoRAStatusResponse)
async def get_lora_status():
    """Check the current LoRA status"""
    try:
        try:
            from lily_llm_core.lora_manager import lora_manager
            if lora_manager is None:
                return LoRAStatusResponse(
                    status="error",
                    lora_available=False,
                    base_model_loaded=False,
                    device="unknown",
                    message="The LoRA feature is unavailable"
                )

            return LoRAStatusResponse(
                status="success",
                lora_available=True,
                current_adapter=getattr(lora_manager, 'current_adapter_name', None),
                base_model_loaded=hasattr(lora_manager, 'base_model') and lora_manager.base_model is not None,
                device=getattr(lora_manager, 'device', 'unknown')
            )
        except ImportError:
            return LoRAStatusResponse(
                status="error",
                lora_available=False,
                base_model_loaded=False,
                device="unknown",
                message="Failed to import the LoRA manager"
            )
    except Exception as e:
        logger.error(f"LoRA status check failed: {e}")
        return LoRAStatusResponse(
            status="error",
            lora_available=False,
            base_model_loaded=False,
            device="unknown",
            message=str(e)
        )

@router.post("/lora/load-base-model")
async def load_base_model(model_path: str = Form(...)):
    """Load the LoRA base model"""
    try:
        from lily_llm_core.lora_manager import lora_manager
        if not lora_manager:
            raise HTTPException(status_code=500, detail="The LoRA manager is not available")

        success = lora_manager.load_base_model(model_path)
        if success:
            return {"status": "success", "message": f"Base model loaded: {model_path}"}
        else:
            raise HTTPException(status_code=500, detail="Failed to load the base model")
    except Exception as e:
        logger.error(f"LoRA base model load failed: {e}")
        raise HTTPException(status_code=500, detail=f"LoRA base model load failed: {str(e)}")

@router.post("/lora/create-config")
async def create_lora_config(
    r: int = Form(16),
    lora_alpha: int = Form(32),
    lora_dropout: float = Form(0.1),
    bias: str = Form("none"),
    task_type: str = Form("CAUSAL_LM"),
    target_modules: str = Form("query_key_value")
):
    """Create a LoRA configuration"""
    try:
        from lily_llm_core.lora_manager import lora_manager
        if not lora_manager:
            raise HTTPException(status_code=500, detail="The LoRA manager is not available")

        # Convert target_modules (comma-separated string) into a list
        target_modules_list = [m.strip() for m in target_modules.split(",")]

        config = lora_manager.create_lora_config(
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            bias=bias,
            task_type=task_type,
            target_modules=target_modules_list
        )

        return {"status": "success", "config": config}
    except Exception as e:
        logger.error(f"LoRA config creation failed: {e}")
        raise HTTPException(status_code=500, detail=f"LoRA config creation failed: {str(e)}")

@router.post("/lora/apply")
async def apply_lora(adapter_name: str = Form(...)):
    """Apply a LoRA adapter"""
    try:
        from lily_llm_core.lora_manager import lora_manager
        if not lora_manager:
            raise HTTPException(status_code=500, detail="The LoRA manager is not available")

        success = lora_manager.apply_lora_to_model(adapter_name)
        if success:
            return {"status": "success", "message": f"LoRA adapter applied: {adapter_name}"}
        else:
            raise HTTPException(status_code=500, detail="Failed to apply the LoRA adapter")
    except Exception as e:
        logger.error(f"LoRA adapter apply failed: {e}")
        raise HTTPException(status_code=500, detail=f"LoRA adapter apply failed: {str(e)}")

@router.get("/lora/adapters")
async def list_lora_adapters():
    """List the available LoRA adapters"""
    try:
        from lily_llm_core.lora_manager import lora_manager
        if not lora_manager:
            raise HTTPException(status_code=500, detail="The LoRA manager is not available")

        adapters = lora_manager.list_adapters()
        return {"status": "success", "adapters": adapters}
    except Exception as e:
        logger.error(f"LoRA adapter listing failed: {e}")
        raise HTTPException(status_code=500, detail=f"LoRA adapter listing failed: {str(e)}")

@router.get("/lora/stats")
async def get_lora_stats():
    """LoRA statistics"""
    try:
        from lily_llm_core.lora_manager import lora_manager
        if not lora_manager:
            raise HTTPException(status_code=500, detail="The LoRA manager is not available")

        stats = lora_manager.get_stats()
        return {"status": "success", "stats": stats}
    except Exception as e:
        logger.error(f"LoRA stats lookup failed: {e}")
        raise HTTPException(status_code=500, detail=f"LoRA stats lookup failed: {str(e)}")

@router.post("/lora/switch")
async def switch_lora_adapter(adapter_name: str = Form(...)):
    """Switch the LoRA adapter"""
    try:
        from lily_llm_core.lora_manager import lora_manager
        if not lora_manager:
            raise HTTPException(status_code=500, detail="The LoRA manager is not available")

        success = lora_manager.switch_adapter(adapter_name)
        if success:
            return {"status": "success", "message": f"LoRA adapter switched: {adapter_name}"}
        else:
            raise HTTPException(status_code=500, detail="Failed to switch the LoRA adapter")
    except Exception as e:
        logger.error(f"LoRA adapter switch failed: {e}")
        raise HTTPException(status_code=500, detail=f"LoRA adapter switch failed: {str(e)}")

@router.post("/lora/unload")
async def unload_lora_adapter():
    """Unload the LoRA adapter"""
    try:
        from lily_llm_core.lora_manager import lora_manager
        if not lora_manager:
            raise HTTPException(status_code=500, detail="The LoRA manager is not available")

        success = lora_manager.unload_adapter()
        if success:
            return {"status": "success", "message": "LoRA adapter unloaded"}
        else:
            raise HTTPException(status_code=500, detail="Failed to unload the LoRA adapter")
    except Exception as e:
        logger.error(f"LoRA adapter unload failed: {e}")
        raise HTTPException(status_code=500, detail=f"LoRA adapter unload failed: {str(e)}")

@router.post("/lora/generate")
async def generate_with_lora(
    prompt: str = Form(...),
    max_length: int = Form(100),
    temperature: float = Form(0.7)
):
    """Generate text with the LoRA-applied model"""
    try:
        from lily_llm_core.lora_manager import lora_manager
        if not lora_manager:
            raise HTTPException(status_code=500, detail="The LoRA manager is not available")

        if not lora_manager.current_adapter_name:
            raise HTTPException(status_code=400, detail="No LoRA adapter is loaded")

        result = lora_manager.generate_text(
            prompt=prompt,
            max_length=max_length,
            temperature=temperature
        )

        return {"status": "success", "generated_text": result}
    except Exception as e:
        logger.error(f"LoRA text generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"LoRA text generation failed: {str(e)}")

@router.post("/lora/merge")
async def merge_lora_with_base():
    """Merge the LoRA adapter into the base model"""
    try:
        from lily_llm_core.lora_manager import lora_manager
        if not lora_manager:
            raise HTTPException(status_code=500, detail="The LoRA manager is not available")

        success = lora_manager.merge_adapter_with_base()
        if success:
            return {"status": "success", "message": "LoRA adapter merged"}
        else:
            raise HTTPException(status_code=500, detail="Failed to merge the LoRA adapter")
    except Exception as e:
        logger.error(f"LoRA adapter merge failed: {e}")
        raise HTTPException(status_code=500, detail=f"LoRA adapter merge failed: {str(e)}")
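
The router above is only exercised over HTTP, so a small client script is the quickest smoke test. The sketch below is illustrative rather than part of the commit: the base URL, the model path, and the adapter name are all assumptions, and every endpoint takes form fields, so payloads go through data= rather than json=.

# Minimal client sketch for the LoRA endpoints above (not part of the commit).
# Assumes the server listens at http://localhost:8000 and the router is mounted
# without a prefix; model path and adapter name below are hypothetical.
import requests

BASE = "http://localhost:8000"  # assumed base URL

# Check LoRA availability before doing anything else
status = requests.get(f"{BASE}/lora/status").json()
print(status)

if status.get("lora_available"):
    # All POST endpoints here declare Form(...) parameters, so send form data
    requests.post(f"{BASE}/lora/load-base-model",
                  data={"model_path": "/models/base"}).raise_for_status()
    requests.post(f"{BASE}/lora/apply",
                  data={"adapter_name": "my_adapter"}).raise_for_status()
    out = requests.post(f"{BASE}/lora/generate",
                        data={"prompt": "Hello", "max_length": 50,
                              "temperature": 0.7}).json()
    print(out.get("generated_text"))
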
lily_llm_api/api/routers/model_router.py
ADDED
@@ -0,0 +1,56 @@
"""
Model router for Lily LLM API
"""
from fastapi import APIRouter, HTTPException, Form
from typing import Optional
import logging

from ...models.schemas import HealthResponse
from ...services.model_service import load_model_async, get_current_profile, is_model_loaded
from ...models import list_available_models

logger = logging.getLogger(__name__)
router = APIRouter()

@router.post("/load-model")
async def load_model_endpoint(model_id: str):
    """HTTP endpoint for loading a model"""
    try:
        logger.info(f"Model load requested over HTTP: {model_id}")
        await load_model_async(model_id)
        return {"success": True, "message": f"Model '{model_id}' loaded"}
    except Exception as e:
        logger.error(f"HTTP model load failed: {e}")
        return {"success": False, "error": str(e)}

@router.get("/models")
async def list_models():
    """List the available models"""
    return {
        "models": list_available_models(),
        "current_model": get_current_profile().get_model_info() if get_current_profile() else None
    }

@router.post("/switch-model")
async def switch_model(model_id: str):
    """Switch the active model"""
    try:
        await load_model_async(model_id)
        return {
            "message": f"Model switched: {model_id}",
            "current_model": get_current_profile().display_name
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model switch failed: {str(e)}")

@router.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    available_models = list_available_models()

    return HealthResponse(
        status="healthy",
        model_loaded=is_model_loaded(),
        current_model=get_current_profile().display_name if get_current_profile() else "None",
        available_models=available_models
    )
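
One detail worth noting in load_model_endpoint and switch_model: model_id is declared as a bare str on a POST route, so FastAPI binds it as a query parameter rather than a body field. A minimal client sketch follows; the base URL and model id are assumptions, not part of the commit.

# Minimal client sketch for the model endpoints above (not part of the commit).
# Because model_id is a bare str parameter on a POST route, FastAPI reads it
# from the query string, hence params= instead of a request body.
import requests

BASE = "http://localhost:8000"  # assumed base URL

print(requests.get(f"{BASE}/health").json())
print(requests.get(f"{BASE}/models").json())
requests.post(f"{BASE}/switch-model",
              params={"model_id": "some-model-id"}).raise_for_status()  # hypothetical id
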
lily_llm_api/api/routers/multimodal_rag_router.py
ADDED
@@ -0,0 +1,567 @@
"""
Multimodal RAG router for Lily LLM API
"""
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from typing import Optional, List
import logging
import time
from pathlib import Path

from ...models.schemas import DocumentUploadResponse, RAGResponse

logger = logging.getLogger(__name__)
router = APIRouter()

# ============================================================================
# Multimodal RAG system endpoints
# ============================================================================

@router.post("/hybrid-rag/upload", response_model=DocumentUploadResponse)
async def upload_hybrid_document(
    file: UploadFile = File(...),
    user_id: str = Form("default_user"),
    document_id: Optional[str] = Form(None)
):
    """Upload a document for multimodal RAG"""
    try:
        # Save the file
        upload_dir = Path("uploads/hybrid_rag")
        upload_dir.mkdir(parents=True, exist_ok=True)

        if not document_id:
            document_id = f"{user_id}_{int(time.time())}_{file.filename}"

        file_path = upload_dir / document_id
        with open(file_path, "wb") as buffer:
            content = await file.read()
            buffer.write(content)

        # Multimodal processing
        try:
            from lily_llm_core.hybrid_rag_processor import hybrid_rag_processor
            result = hybrid_rag_processor.process_document_hybrid(str(file_path), user_id, document_id)
        except ImportError:
            result = {
                "success": False,
                "error": "Hybrid RAG processor not available"
            }

        if result["success"]:
            # Count the subsystems that processed the document successfully
            success_systems = []
            for key, value in result.items():
                if key.endswith('_processing') and value and value.get('success', False):
                    system_name = key.replace('_processing', '').replace('_', ' ').title()
                    success_systems.append(system_name)

            return DocumentUploadResponse(
                success=True,
                document_id=document_id,
                message=f"Multimodal processing complete: handled by the {', '.join(success_systems)} systems",
                chunks=len(success_systems)
            )
        else:
            return DocumentUploadResponse(
                success=False,
                error=result.get("error", "Multimodal processing failed")
            )

    except Exception as e:
        logger.error(f"Multimodal RAG document upload error: {e}")
        return DocumentUploadResponse(
            success=False,
            error=f"An error occurred during upload: {str(e)}"
        )

@router.post("/hybrid-rag/generate", response_model=RAGResponse)
async def generate_hybrid_rag_response(
    query: str = Form(...),
    user_id: str = Form("default_user"),
    document_id: str = Form(...),
    use_text: bool = Form(True),
    use_image: bool = Form(True),
    use_latex: bool = Form(True),
    use_latex_ocr: bool = Form(False),  # LaTeX-OCR feature is disabled
    max_length: Optional[int] = Form(None),
    temperature: Optional[float] = Form(None),
    top_p: Optional[float] = Form(None),
    do_sample: Optional[bool] = Form(None)
):
    """Generate a multimodal RAG response"""
    try:
        try:
            from lily_llm_core.hybrid_rag_processor import hybrid_rag_processor
            result = hybrid_rag_processor.generate_hybrid_response(
                query, user_id, document_id,
                use_text, use_image, use_latex, use_latex_ocr,
                max_length, temperature, top_p, do_sample
            )
        except ImportError:
            result = {
                "success": False,
                "response": "Hybrid RAG processor not available",
                "context": "",
                "sources": [],
                "search_results": 0,
                "processing_time": 0.0
            }

        return RAGResponse(
            success=result["success"],
            response=result["response"],
            context=result["context"],
            sources=result["sources"],
            search_results=result["search_results"],
            processing_time=result["processing_time"]
        )

    except Exception as e:
        logger.error(f"Multimodal RAG response generation error: {e}")
        return RAGResponse(
            success=False,
            response=f"An error occurred while generating the multimodal RAG response: {str(e)}",
            context="",
            sources=[],
            search_results=0,
            processing_time=0.0
        )

@router.get("/hybrid-rag/document/{user_id}/{document_id}")
async def get_hybrid_document_info(user_id: str, document_id: str):
    """Get multimodal RAG document info"""
    try:
        try:
            from lily_llm_core.hybrid_rag_processor import hybrid_rag_processor
            result = hybrid_rag_processor.get_document_info(user_id, document_id)
        except ImportError:
            result = {"success": False, "error": "Hybrid RAG processor not available"}
        return result
    except Exception as e:
        logger.error(f"Multimodal RAG document info lookup error: {e}")
        return {"success": False, "error": str(e)}

@router.get("/hybrid-rag/status")
async def get_hybrid_rag_status():
    """Check multimodal RAG system status"""
    try:
        return {
            "text_rag_available": True,
            "image_rag_available": True,
            "latex_rag_available": True,
            "latex_ocr_faiss_available": False,  # LaTeX-OCR feature is disabled
            "status": "ready"
        }
    except Exception as e:
        logger.error(f"Multimodal RAG status check error: {e}")
        return {"status": "error", "error": str(e)}

# ============================================================================
# Integration API for the RAG system and the advanced context manager
# ============================================================================

@router.post("/rag/context-integrated/query")
async def rag_query_with_context_integration(
    user_id: str = Form(...),
    document_id: str = Form(...),
    query: str = Form(...),
    session_id: str = Form(...),
    max_results: int = Form(5),
    enable_context_integration: bool = Form(True)
):
    """RAG query + context integration, wired to the advanced context manager"""
    try:
        logger.info(f"RAG + context-integrated query started: user {user_id}, document {document_id}, session {session_id}")

        # Check the context manager
        try:
            from lily_llm_core.context_manager import context_manager
            if not context_manager:
                return {"status": "error", "message": "The context manager is not available."}
        except ImportError:
            return {"status": "error", "message": "Context manager not available"}

        # Generate the RAG response (with context integration enabled)
        try:
            from lily_llm_core.rag_processor import rag_processor
            rag_result = rag_processor.generate_rag_response(
                user_id=user_id,
                document_id=document_id,
                query=query,
                session_id=session_id if enable_context_integration else None,
                context_manager=context_manager if enable_context_integration else None
            )
        except ImportError:
            rag_result = {"success": False, "error": "RAG processor not available"}

        if not rag_result["success"]:
            return rag_result

        # Merge the RAG result into the context
        if enable_context_integration:
            try:
                # Add the RAG search result to the context
                rag_summary = f"RAG search result: found {rag_result.get('search_results', 0)} documents relevant to {query}"

                # Add it to the context as a system message
                context_manager.add_system_message(
                    rag_summary,
                    metadata={"session_id": session_id, "type": "rag_integration", "query": query}
                )

                logger.info(f"RAG result merged into the context (session: {session_id})")

            except Exception as e:
                logger.warning(f"Context integration failed: {e}")

        # Return the combined result
        result = {
            "status": "success",
            "rag_response": rag_result,
            "context_integration": enable_context_integration,
            "session_id": session_id,
            "context_summary": context_manager.get_context_summary(session_id) if enable_context_integration else None
        }

        logger.info("RAG + context-integrated query complete")
        return result

    except Exception as e:
        logger.error(f"RAG + context-integrated query failed: {e}")
        return {"status": "error", "message": str(e)}

@router.get("/rag/context-integrated/summary/{session_id}")
async def get_rag_context_summary(session_id: str):
    """Get the RAG-integrated context summary"""
    try:
        try:
            from lily_llm_core.context_manager import context_manager
            if not context_manager:
                return {"status": "error", "message": "The context manager is not available."}
        except ImportError:
            return {"status": "error", "message": "Context manager not available"}

        # Context summary info
        context_summary = context_manager.get_context_summary(session_id)

        # Extract the RAG-related entries
        rag_contexts = []
        if session_id in context_manager.session_conversations:
            for turn in context_manager.session_conversations[session_id]:
                if (hasattr(turn, 'metadata') and turn.metadata and
                        turn.metadata.get('type') == 'rag_integration'):
                    rag_contexts.append({
                        "query": turn.metadata.get('query', ''),
                        "content": turn.content,
                        "timestamp": turn.timestamp
                    })

        return {
            "status": "success",
            "session_id": session_id,
            "context_summary": context_summary,
            "rag_contexts": rag_contexts,
            "rag_context_count": len(rag_contexts)
        }

    except Exception as e:
        logger.error(f"RAG context summary lookup failed: {e}")
        return {"status": "error", "message": str(e)}

@router.post("/rag/context-integrated/clear/{session_id}")
async def clear_rag_context(session_id: str):
    """Clear the RAG-integrated context"""
    try:
        try:
            from lily_llm_core.context_manager import context_manager
            if not context_manager:
                return {"status": "error", "message": "The context manager is not available."}
        except ImportError:
            return {"status": "error", "message": "Context manager not available"}

        # Remove only the RAG-related context entries
        if session_id in context_manager.session_conversations:
            conversation_history = context_manager.session_conversations[session_id]
            rag_turns = []

            for turn in conversation_history:
                if (hasattr(turn, 'metadata') and turn.metadata and
                        turn.metadata.get('type') == 'rag_integration'):
                    rag_turns.append(turn)

            # Remove the RAG-related turns
            for turn in rag_turns:
                context_manager.remove_message(turn.message_id, session_id)

            logger.info(f"RAG context cleanup complete: {len(rag_turns)} turns removed (session: {session_id})")

            return {
                "status": "success",
                "session_id": session_id,
                "removed_rag_turns": len(rag_turns),
                "message": f"{len(rag_turns)} RAG context turns were removed."
            }

        return {
            "status": "success",
            "session_id": session_id,
            "removed_rag_turns": 0,
            "message": "No RAG context to remove."
        }

    except Exception as e:
        logger.error(f"RAG context cleanup failed: {e}")
        return {"status": "error", "message": str(e)}

@router.get("/rag/performance/stats")
async def get_rag_performance_stats():
    """Get RAG system performance statistics"""
    try:
        # RAG processor performance stats
        try:
            from lily_llm_core.rag_processor import rag_processor
            rag_stats = rag_processor.get_performance_stats()
        except ImportError:
            rag_stats = {"total_requests": 0, "success_rate": 0.0, "avg_processing_time": 0.0}

        # Vector store performance stats
        try:
            from lily_llm_core.vector_store_manager import vector_store_manager
            vector_stats = vector_store_manager.get_performance_stats()
        except ImportError:
            vector_stats = {"total_operations": 0, "success_rate": 0.0, "avg_operation_time": 0.0}

        # Combined performance stats
        combined_stats = {
            "rag_processor": rag_stats,
            "vector_store": vector_stats,
            "overall": {
                "total_operations": rag_stats.get("total_requests", 0) + vector_stats.get("total_operations", 0),
                "success_rate": (rag_stats.get("success_rate", 0.0) + vector_stats.get("success_rate", 0.0)) / 2,
                "avg_processing_time": (rag_stats.get("avg_processing_time", 0.0) + vector_stats.get("avg_operation_time", 0.0)) / 2
            },
            "timestamp": time.time()
        }

        return {
            "status": "success",
            "performance_stats": combined_stats
        }

    except Exception as e:
        logger.error(f"RAG performance stats lookup failed: {e}")
        return {"status": "error", "message": str(e)}

@router.post("/rag/performance/reset")
async def reset_rag_performance_stats():
    """Reset RAG system performance statistics"""
    try:
        # Reset the RAG processor stats
        try:
            from lily_llm_core.rag_processor import rag_processor
            rag_processor.reset_stats()
        except ImportError:
            pass

        # Reset the vector store stats
        try:
            from lily_llm_core.vector_store_manager import vector_store_manager
            vector_store_manager.reset_stats()
        except ImportError:
            pass

        logger.info("RAG system performance stats reset complete")

        return {
            "status": "success",
            "message": "The RAG system performance statistics were reset."
        }

    except Exception as e:
        logger.error(f"RAG performance stats reset failed: {e}")
        return {"status": "error", "message": str(e)}

@router.get("/rag/health/check")
async def rag_health_check():
    """Check RAG system health"""
    try:
        # RAG processor status
        try:
            from lily_llm_core.rag_processor import rag_processor
            rag_status = {
                "rag_processor": "healthy",
                "enable_context_integration": getattr(rag_processor, 'enable_context_integration', False),
                "max_context_length": getattr(rag_processor, 'max_context_length', 0),
                "max_search_results": getattr(rag_processor, 'max_search_results', 0)
            }
        except ImportError:
            rag_status = {"rag_processor": "not_available"}

        # Vector store status
        try:
            from lily_llm_core.vector_store_manager import vector_store_manager
            vector_status = vector_store_manager.health_check()
        except ImportError:
            vector_status = {"status": "not_available"}

        # Document processor status
        try:
            from lily_llm_core.document_processor import document_processor
            doc_processor_status = {
                "status": "healthy",
                "supported_formats": getattr(document_processor, 'supported_formats', []),
                "ocr_available": hasattr(document_processor, 'ocr_reader') and document_processor.ocr_reader is not None
            }
        except ImportError:
            doc_processor_status = {"status": "not_available"}

        # Overall status
        overall_status = "healthy"
        if vector_status.get("status") != "healthy":
            overall_status = "degraded"

        return {
            "status": "success",
            "overall_status": overall_status,
            "rag_processor": rag_status,
            "vector_store": vector_status,
            "document_processor": doc_processor_status,
            "timestamp": time.time()
        }

    except Exception as e:
        logger.error(f"RAG system health check failed: {e}")
        return {
            "status": "error",
            "overall_status": "unhealthy",
            "error": str(e),
            "timestamp": time.time()
        }

@router.post("/rag/context-integrated/batch-process")
async def batch_process_with_context_integration(
    user_id: str = Form(...),
    session_id: str = Form(...),
    documents: List[UploadFile] = File(...),
    enable_context_integration: bool = Form(True)
):
    """Batch document processing + context integration"""
    try:
        logger.info(f"Batch processing + context integration started: user {user_id}, session {session_id}, {len(documents)} documents")

        results = []

        for i, doc in enumerate(documents):
            try:
                # Save to a temporary file
                temp_path = f"./temp_{user_id}_{session_id}_{i}_{int(time.time())}"
                with open(temp_path, "wb") as f:
                    f.write(doc.file.read())

                # Build a document ID
                document_id = f"batch_{session_id}_{i}_{int(time.time())}"

                # RAG processing
                try:
                    from lily_llm_core.rag_processor import rag_processor
                    rag_result = rag_processor.process_and_store_document(
                        user_id=user_id,
                        document_id=document_id,
                        file_path=temp_path
                    )
                except ImportError:
                    rag_result = {"success": False, "error": "RAG processor not available"}

                # Context integration
                if enable_context_integration and rag_result["success"]:
                    try:
                        from lily_llm_core.context_manager import context_manager
                        context_manager.add_system_message(
                            f"Batch document processed: {doc.filename} ({rag_result.get('chunks', 0)} chunks)",
                            metadata={"session_id": session_id, "type": "batch_rag", "filename": doc.filename}
                        )
                    except Exception as e:
                        logger.warning(f"Context integration failed: {e}")

                # Clean up the temporary file
                try:
                    import os
                    os.remove(temp_path)
                except:
                    pass

                results.append({
                    "filename": doc.filename,
                    "document_id": document_id,
                    "rag_result": rag_result,
                    "context_integration": enable_context_integration
                })

            except Exception as e:
                logger.error(f"Failed to process document {doc.filename}: {e}")
                results.append({
                    "filename": doc.filename,
                    "error": str(e),
                    "context_integration": enable_context_integration
                })

        # Success/failure counts
        success_count = sum(1 for r in results if r.get("rag_result", {}).get("success", False))
        error_count = len(results) - success_count

        logger.info(f"Batch document processing complete: {success_count} succeeded, {error_count} failed")

        return {
            "status": "success",
            "user_id": user_id,
            "session_id": session_id,
            "total_documents": len(documents),
            "success_count": success_count,
            "error_count": error_count,
            "results": results,
            "context_integration": enable_context_integration
        }

    except Exception as e:
        logger.error(f"Batch processing + context integration failed: {e}")
        return {"status": "error", "message": str(e)}

@router.get("/rag/context-integrated/search-history/{session_id}")
async def get_rag_search_history(session_id: str, limit: int = 10):
    """Get the RAG search history"""
    try:
        try:
            from lily_llm_core.context_manager import context_manager
            if not context_manager:
                return {"status": "error", "message": "The context manager is not available."}
        except ImportError:
            return {"status": "error", "message": "Context manager not available"}

        # Extract the RAG-related search history
        search_history = []
        if session_id in context_manager.session_conversations:
            for turn in context_manager.session_conversations[session_id]:
                if (hasattr(turn, 'metadata') and turn.metadata and
                        turn.metadata.get('type') in ['rag_integration', 'rag_context', 'batch_rag']):
                    search_history.append({
                        "timestamp": turn.timestamp,
                        "type": turn.metadata.get('type'),
                        "query": turn.metadata.get('query', ''),
                        "filename": turn.metadata.get('filename', ''),
                        "content": turn.content
                    })

        # Sort newest first and apply the limit
        search_history.sort(key=lambda x: x['timestamp'], reverse=True)
        limited_history = search_history[:limit]

        return {
            "status": "success",
            "session_id": session_id,
            "search_history": limited_history,
            "total_count": len(search_history),
            "limited_count": len(limited_history)
        }

    except Exception as e:
        logger.error(f"RAG search history lookup failed: {e}")
        return {"status": "error", "message": str(e)}
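
The two hybrid-rag endpoints above form a two-step flow: upload a file, keep the returned document_id, then query against it. The sketch below illustrates that flow from the client side; the base URL and the file name sample.pdf are assumptions, not part of the commit.

# Minimal client sketch for the hybrid RAG flow above (not part of the commit).
import requests

BASE = "http://localhost:8000"  # assumed base URL

# 1) Upload: multipart file plus form fields
with open("sample.pdf", "rb") as fh:  # hypothetical input file
    up = requests.post(f"{BASE}/hybrid-rag/upload",
                       files={"file": ("sample.pdf", fh)},
                       data={"user_id": "demo_user"}).json()
doc_id = up["document_id"]

# 2) Query: plain form fields; FastAPI coerces the boolean flags from strings
ans = requests.post(f"{BASE}/hybrid-rag/generate",
                    data={"query": "Summarize the document",
                          "user_id": "demo_user",
                          "document_id": doc_id,
                          "use_text": "true", "use_image": "true",
                          "use_latex": "true"}).json()
print(ans["response"])
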
lily_llm_api/api/routers/ocr_router.py
ADDED
@@ -0,0 +1,404 @@
"""
OCR (Image OCR, LaTeX-OCR) router for Lily LLM API
"""
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from typing import Optional
import logging
import time
import os
import uuid

from ...models.schemas import DocumentUploadResponse, RAGResponse

logger = logging.getLogger(__name__)
router = APIRouter()

# ============================================================================
# Image OCR endpoints
# ============================================================================

@router.post("/image-ocr/upload", response_model=DocumentUploadResponse)
async def upload_image_document(
    file: UploadFile = File(...),
    user_id: str = Form("default_user"),
    document_id: Optional[str] = Form(None)
):
    """Upload a document for image OCR"""
    start_time = time.time()

    try:
        # Generate a document ID if none was provided
        if not document_id:
            document_id = str(uuid.uuid4())[:8]

        # Save a temporary file
        temp_file_path = f"./temp_image_{document_id}_{file.filename}"
        with open(temp_file_path, "wb") as f:
            content = await file.read()
            f.write(content)

        # Run image OCR and store the result in the vector store
        try:
            from lily_llm_core.image_rag_processor import image_rag_processor
            result = image_rag_processor.process_and_store_image_document(
                user_id, document_id, temp_file_path
            )
        except ImportError:
            result = {
                "success": False,
                "error": "Image RAG processor not available"
            }

        # Delete the temporary file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)

        processing_time = time.time() - start_time
        logger.info(f"Image OCR document upload complete ({processing_time:.2f}s): {file.filename}")

        return DocumentUploadResponse(
            success=result["success"],
            document_id=document_id,
            message=result.get("message", ""),
            chunks=result.get("chunks"),
            latex_count=result.get("latex_count"),
            error=result.get("error"),
            auto_response=result.get("auto_response", "")
        )

    except Exception as e:
        logger.error(f"Image OCR document upload failed: {e}")
        return DocumentUploadResponse(
            success=False,
            document_id=document_id if 'document_id' in locals() else "unknown",
            message="An error occurred during image OCR document upload.",
            error=str(e)
        )

@router.post("/image-ocr/generate", response_model=RAGResponse)
async def generate_image_ocr_response(
    query: str = Form(...),
    user_id: str = Form("default_user"),
    document_id: str = Form(...)
):
    """Generate an image-OCR-based RAG response"""
    start_time = time.time()

    try:
        # Generate the image OCR RAG response
        try:
            from lily_llm_core.image_rag_processor import image_rag_processor
            result = image_rag_processor.generate_image_rag_response(
                user_id, document_id, query
            )
        except ImportError:
            result = {
                "success": False,
                "response": "Image RAG processor not available",
                "context": "",
                "sources": [],
                "search_results": 0
            }

        processing_time = time.time() - start_time
        result["processing_time"] = processing_time

        logger.info(f"Image OCR RAG response generated ({processing_time:.2f}s)")
        return result

    except Exception as e:
        logger.error(f"Image OCR RAG response generation failed: {e}")
        return RAGResponse(
            success=False,
            response=f"An error occurred while generating the image OCR RAG response: {str(e)}",
            context="",
            sources=[],
            search_results=0,
            processing_time=time.time() - start_time
        )

@router.get("/image-ocr/document/{user_id}/{document_id}")
async def get_image_document_info(user_id: str, document_id: str):
    """Get image OCR document info"""
    try:
        try:
            from lily_llm_core.image_rag_processor import image_rag_processor
            result = image_rag_processor.get_image_document_info(user_id, document_id)
        except ImportError:
            result = {
                "success": False,
                "error": "Image RAG processor not available"
            }
        return result
    except Exception as e:
        logger.error(f"Image OCR document info lookup failed: {e}")
        return {
            "success": False,
            "error": str(e)
        }

@router.delete("/image-ocr/document/{user_id}/{document_id}")
async def delete_image_document(user_id: str, document_id: str):
    """Delete an image OCR document"""
    try:
        # Delete the document from the vector store
        try:
            from lily_llm_core.vector_store_manager import vector_store_manager
            success = vector_store_manager.delete_document(user_id, document_id)
        except ImportError:
            success = False

        if success:
            return {
                "success": True,
                "message": "The image OCR document was deleted."
            }
        else:
            return {
                "success": False,
                "error": "Failed to delete the image OCR document."
            }
    except Exception as e:
        logger.error(f"Image OCR document deletion failed: {e}")
        return {
            "success": False,
            "error": str(e)
        }

# ============================================================================
# LaTeX-OCR endpoints
# ============================================================================

@router.post("/latex-ocr/upload", response_model=DocumentUploadResponse)
async def upload_latex_document(
    file: UploadFile = File(...),
    user_id: str = Form("default_user"),
    document_id: Optional[str] = Form(None)
):
    """Upload a document for LaTeX-OCR"""
    start_time = time.time()

    try:
        # Generate a document ID if none was provided
        if not document_id:
            document_id = str(uuid.uuid4())[:8]

        # Save a temporary file
        temp_file_path = f"./temp_latex_{document_id}_{file.filename}"
        with open(temp_file_path, "wb") as f:
            content = await file.read()
            f.write(content)

        # Run LaTeX-OCR and store the result in the vector store
        try:
            from lily_llm_core.latex_rag_processor import latex_rag_processor
            result = latex_rag_processor.process_and_store_latex_document(
                user_id, document_id, temp_file_path
            )
        except ImportError:
            result = {
                "success": False,
                "error": "LaTeX RAG processor not available"
            }

        # Delete the temporary file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)

        processing_time = time.time() - start_time
        logger.info(f"LaTeX-OCR document upload complete ({processing_time:.2f}s): {file.filename}")

        return DocumentUploadResponse(
            success=result["success"],
            document_id=document_id,
            message=result.get("message", ""),
            chunks=result.get("chunks"),
            latex_count=result.get("latex_count"),
            error=result.get("error"),
            auto_response=result.get("auto_response", "")
        )

    except Exception as e:
        logger.error(f"LaTeX-OCR document upload failed: {e}")
        return DocumentUploadResponse(
            success=False,
            document_id=document_id if 'document_id' in locals() else "unknown",
            message="An error occurred during LaTeX-OCR document upload.",
            error=str(e)
        )

@router.post("/latex-ocr/generate", response_model=RAGResponse)
async def generate_latex_ocr_response(
    query: str = Form(...),
    user_id: str = Form("default_user"),
    document_id: str = Form(...)
):
    """Generate a LaTeX-OCR-based RAG response"""
    start_time = time.time()

    try:
        # Generate the LaTeX-OCR RAG response
        try:
            from lily_llm_core.latex_rag_processor import latex_rag_processor
            result = latex_rag_processor.generate_latex_rag_response(
                user_id, document_id, query
            )
        except ImportError:
            result = {
                "success": False,
                "response": "LaTeX RAG processor not available",
                "context": "",
                "sources": [],
                "search_results": 0
            }

        processing_time = time.time() - start_time
        result["processing_time"] = processing_time

        logger.info(f"LaTeX-OCR RAG response generated ({processing_time:.2f}s)")
        return result

    except Exception as e:
        logger.error(f"LaTeX-OCR RAG response generation failed: {e}")
        return RAGResponse(
            success=False,
            response=f"An error occurred while generating the LaTeX-OCR RAG response: {str(e)}",
            context="",
            sources=[],
            search_results=0,
            processing_time=time.time() - start_time
        )

@router.get("/latex-ocr/document/{user_id}/{document_id}")
async def get_latex_document_info(user_id: str, document_id: str):
    """Get LaTeX-OCR document info"""
    try:
        try:
            from lily_llm_core.latex_rag_processor import latex_rag_processor
            result = latex_rag_processor.get_latex_document_info(user_id, document_id)
        except ImportError:
            result = {
                "success": False,
                "error": "LaTeX RAG processor not available"
            }
        return result
    except Exception as e:
        logger.error(f"LaTeX-OCR document info lookup failed: {e}")
        return {
            "success": False,
            "error": str(e)
        }

@router.delete("/latex-ocr/document/{user_id}/{document_id}")
async def delete_latex_document(user_id: str, document_id: str):
    """Delete a LaTeX-OCR document"""
    try:
        # Delete the document from the vector store
        try:
            from lily_llm_core.vector_store_manager import vector_store_manager
            success = vector_store_manager.delete_document(user_id, document_id)
        except ImportError:
            success = False

        if success:
            return {
                "success": True,
                "message": "The LaTeX-OCR document was deleted."
            }
        else:
            return {
                "success": False,
                "error": "Failed to delete the LaTeX-OCR document."
            }
    except Exception as e:
        logger.error(f"LaTeX-OCR document deletion failed: {e}")
        return {
            "success": False,
            "error": str(e)
        }

# ============================================================================
# LaTeX-OCR + FAISS integrated system endpoints (currently disabled)
# ============================================================================

@router.post("/latex-ocr-faiss/process", response_model=DocumentUploadResponse)
async def process_pdf_with_latex_faiss(
    file: UploadFile = File(...),
    user_id: str = Form("default_user"),
    system_type: str = Form("simple")  # "simple" or "integrated"
):
    """Extract LaTeX formulas from a PDF and store them in FAISS (currently disabled)"""
    try:
        # Save the file
        from pathlib import Path
        upload_dir = Path("uploads/latex_ocr_faiss")
        upload_dir.mkdir(parents=True, exist_ok=True)

        file_path = upload_dir / f"{user_id}_{file.filename}"
        with open(file_path, "wb") as f:
            content = await file.read()
            f.write(content)

        # Feature is currently disabled
        return DocumentUploadResponse(
            success=False,
            document_id="",
            message="The LaTeX-OCR + FAISS feature is currently disabled",
            error="Feature disabled because its module was removed"
        )

    except Exception as e:
        logger.error(f"LaTeX-OCR + FAISS processing error: {e}")
        return DocumentUploadResponse(
            success=False,
            document_id="",
            message="An error occurred during processing",
            error=f"An error occurred during processing: {str(e)}"
        )

@router.post("/latex-ocr-faiss/search", response_model=RAGResponse)
async def search_latex_formulas(
    query: str = Form(...),
    user_id: str = Form("default_user"),
    document_path: Optional[str] = Form(None),
    system_type: str = Form("simple"),
    k: int = Form(5)
):
    """Search stored LaTeX formulas (currently disabled)"""
    try:
        # Feature is currently disabled
        return RAGResponse(
            success=False,
            response="The LaTeX-OCR + FAISS search feature is currently disabled",
            context="",
            sources=[],
            search_results=0,
            processing_time=0.0,
            error="Feature disabled because its module was removed"
        )

    except Exception as e:
        logger.error(f"LaTeX formula search error: {e}")
        return RAGResponse(
            success=False,
            response="An error occurred during the search.",
            context="",
            sources=[],
            search_results=0,
            processing_time=0.0,
            error=str(e)
        )

@router.get("/latex-ocr-faiss/status")
async def get_latex_ocr_faiss_status():
    """Check LaTeX-OCR + FAISS system status (currently disabled)"""
    try:
        return {
            "simple_system_initialized": False,
            "integrated_system_initialized": False,
            "status": "disabled",
            "message": "The LaTeX-OCR + FAISS feature is currently disabled"
        }
    except Exception as e:
        logger.error(f"Status check error: {e}")
        return {"status": "error", "error": str(e)}
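
The image-ocr endpoints follow the same upload-then-query pattern as the hybrid RAG ones. A minimal client sketch, again with an assumed base URL and input file, neither taken from the commit:

# Minimal client sketch for the image OCR endpoints above (not part of the commit).
import requests

BASE = "http://localhost:8000"  # assumed base URL

with open("page.png", "rb") as fh:  # hypothetical input image
    up = requests.post(f"{BASE}/image-ocr/upload",
                       files={"file": ("page.png", fh)},
                       data={"user_id": "demo_user"}).json()

if up.get("success"):
    ans = requests.post(f"{BASE}/image-ocr/generate",
                        data={"query": "What text is in the image?",
                              "user_id": "demo_user",
                              "document_id": up["document_id"]}).json()
    print(ans)
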
lily_llm_api/api/routers/user_memory_router.py
ADDED
@@ -0,0 +1,341 @@
+"""
+User memory settings management router for Lily LLM API
+"""
+from fastapi import APIRouter, HTTPException, Form
+import logging
+import time
+
+logger = logging.getLogger(__name__)
+router = APIRouter()
+
+# ============================================================================
+# User memory settings management API
+# ============================================================================
+
+@router.get("/user/memory/settings/{user_id}")
+async def get_user_memory_settings(user_id: str):
+    """Get a user's memory settings"""
+    try:
+        try:
+            from lily_llm_core.user_memory_manager import user_memory_manager
+
+            # read the stored setting
+            keep_memory = user_memory_manager.get_memory_setting(user_id, "keep_memory_on_room_change")
+
+            return {
+                "status": "success",
+                "user_id": user_id,
+                "settings": {
+                    "keep_memory_on_room_change": keep_memory if keep_memory is not None else True
+                }
+            }
+        except ImportError:
+            return {"status": "error", "message": "User memory manager not available"}
+    except Exception as e:
+        return {"status": "error", "message": str(e)}
+
+@router.post("/user/memory/settings/{user_id}")
+async def update_user_memory_settings(
+    user_id: str,
+    keep_memory_on_room_change: bool = Form(True)
+):
+    """Update a user's memory settings"""
+    try:
+        try:
+            from lily_llm_core.user_memory_manager import user_memory_manager
+
+            # update the setting
+            success = user_memory_manager.update_memory_setting(
+                user_id, "keep_memory_on_room_change", keep_memory_on_room_change
+            )
+
+            if success:
+                return {
+                    "status": "success",
+                    "message": f"์ฌ์ฉ์ {user_id} ๋ฉ๋ชจ๋ฆฌ ์ค์  ์
๋ฐ์ดํธ ์๋ฃ",
+                    "settings": {
+                        "keep_memory_on_room_change": keep_memory_on_room_change
+                    }
+                }
+            else:
+                return {"status": "error", "message": "์ค์  ์
๋ฐ์ดํธ ์คํจ"}
+        except ImportError:
+            return {"status": "error", "message": "User memory manager not available"}
+    except Exception as e:
+        return {"status": "error", "message": str(e)}
+
+@router.post("/user/memory/room-change/{user_id}")
+async def handle_room_change(user_id: str, new_room_id: str = Form(...)):
+    """Handle memory when the user changes rooms"""
+    try:
+        try:
+            from lily_llm_core.user_memory_manager import user_memory_manager
+            from lily_llm_core.integrated_memory_manager import integrated_memory_manager
+        except ImportError:
+            return {"status": "error", "message": "Memory managers not available"}
+
+        # check the user's setting
+        keep_memory = user_memory_manager.get_memory_setting(user_id, "keep_memory_on_room_change")
+
+        if keep_memory:
+            # keep memory (default behavior)
+            logger.info(f"๐ ์ฌ์ฉ์ {user_id}๊ฐ room {new_room_id}๋ก ์ด๋ - ๋ฉ๋ชจ๋ฆฌ ์ ์ง")
+            return {
+                "status": "success",
+                "message": f"Room {new_room_id}๋ก ์ด๋ - ๋ฉ๋ชจ๋ฆฌ ์ ์ง๋จ",
+                "memory_preserved": True
+            }
+        else:
+            # reset memory
+            logger.info(f"๐ ์ฌ์ฉ์ {user_id}๊ฐ room {new_room_id}๋ก ์ด๋ - ๋ฉ๋ชจ๋ฆฌ ์ด๊ธฐํ")
+
+            # clear session contexts
+            try:
+                from lily_llm_core.context_manager import context_manager
+                if context_manager:
+                    # find and clear the user's sessions
+                    user_sessions = [
+                        session_id for session_id in context_manager.session_conversations.keys()
+                        if f"user_{user_id}" in session_id
+                    ]
+
+                    for session_id in user_sessions:
+                        context_manager.clear_session_context(session_id)
+                        logger.info(f"๐๏ธ ์ธ์
 ์ปจํ
์คํธ ์ด๊ธฐํ: {session_id}")
+            except ImportError:
+                logger.warning("โ ๏ธ Context manager not available for session cleanup")
+
+            # clear the room context (remove the user's documents)
+            try:
+                room_context = integrated_memory_manager.room_context_manager.get_room_context(new_room_id)
+                if room_context and room_context.documents:
+                    # remove documents uploaded by this user
+                    original_count = len(room_context.documents)
+                    room_context.documents = [
+                        doc for doc in room_context.documents
+                        if (isinstance(doc, dict) and doc.get('uploaded_by') != user_id) or
+                        (hasattr(doc, 'uploaded_by') and getattr(doc, 'uploaded_by') != user_id)
+                    ]
+
+                    # persist the change
+                    integrated_memory_manager.room_context_manager.save_room_context(new_room_id, room_context)
+
+                    removed_count = original_count - len(room_context.documents)
+                    logger.info(f"๐๏ธ Room {new_room_id}์์ ์ฌ์ฉ์ {user_id} ๋ฌธ์ {removed_count}๊ฐ ์ ๊ฑฐ")
+            except Exception as e:
+                logger.warning(f"โ ๏ธ Room ์ปจํ
์คํธ ์ด๊ธฐํ ์คํจ: {e}")
+
+            return {
+                "status": "success",
+                "message": f"Room {new_room_id}๋ก ์ด๋ - ๋ฉ๋ชจ๋ฆฌ ์ด๊ธฐํ๋จ",
+                "memory_preserved": False,
+                "context_cleared": True
+            }
+
+    except Exception as e:
+        logger.error(f"โ Room ๋ณ๊ฒฝ ์ฒ๋ฆฌ ์คํจ: {e}")
+        return {"status": "error", "message": str(e)}
+
+@router.get("/user/memory/status/{user_id}")
+async def get_user_memory_status(user_id: str):
+    """Get a user's memory status"""
+    try:
+        try:
+            from lily_llm_core.user_memory_manager import user_memory_manager
+            from lily_llm_core.integrated_memory_manager import integrated_memory_manager
+
+            # user memory settings
+            memory_settings = {
+                "keep_memory_on_room_change": user_memory_manager.get_memory_setting(user_id, "keep_memory_on_room_change")
+            }
+
+            # session info for this user
+            session_info = {}
+            try:
+                from lily_llm_core.context_manager import context_manager
+                if context_manager:
+                    user_sessions = [
+                        session_id for session_id in context_manager.session_conversations.keys()
+                        if f"user_{user_id}" in session_id
+                    ]
+
+                    for session_id in user_sessions:
+                        session_info[session_id] = {
+                            "turns": len(context_manager.session_conversations[session_id]),
+                            "context_summary": context_manager.get_context_summary(session_id)
+                        }
+            except ImportError:
+                pass
+
+            # document info for this user
+            document_info = {}
+            try:
+                # fetch documents uploaded by this user
+                user_documents = integrated_memory_manager.get_user_documents(user_id)
+                document_info = {
+                    "total_documents": len(user_documents),
+                    "document_types": list(set(doc.get('type', 'unknown') for doc in user_documents if isinstance(doc, dict))),
+                    "recent_uploads": sorted(user_documents, key=lambda x: x.get('upload_time', 0), reverse=True)[:5] if user_documents else []
+                }
+            except Exception as e:
+                logger.warning(f"โ ๏ธ ์ฌ์ฉ์ ๋ฌธ์ ์ ๋ณด ์กฐํ ์คํจ: {e}")
+
+            return {
+                "status": "success",
+                "user_id": user_id,
+                "memory_settings": memory_settings,
+                "session_info": session_info,
+                "document_info": document_info,
+                "timestamp": time.time()
+            }
+
+        except ImportError:
+            return {"status": "error", "message": "Memory managers not available"}
+
+    except Exception as e:
+        logger.error(f"โ ์ฌ์ฉ์ ๋ฉ๋ชจ๋ฆฌ ์ํ ์กฐํ ์คํจ: {e}")
+        return {"status": "error", "message": str(e)}
+
+@router.post("/user/memory/clear/{user_id}")
+async def clear_user_memory(user_id: str, clear_type: str = Form("all")):
+    """Clear a user's memory"""
+    try:
+        try:
+            from lily_llm_core.user_memory_manager import user_memory_manager
+            from lily_llm_core.integrated_memory_manager import integrated_memory_manager
+        except ImportError:
+            return {"status": "error", "message": "Memory managers not available"}
+
+        cleared_items = {}
+
+        if clear_type in ["all", "sessions"]:
+            # clear session contexts
+            try:
+                from lily_llm_core.context_manager import context_manager
+                if context_manager:
+                    user_sessions = [
+                        session_id for session_id in context_manager.session_conversations.keys()
+                        if f"user_{user_id}" in session_id
+                    ]
+
+                    for session_id in user_sessions:
+                        context_manager.clear_session_context(session_id)
+
+                    cleared_items["sessions"] = len(user_sessions)
+                    logger.info(f"๐๏ธ ์ฌ์ฉ์ {user_id} ์ธ์
 ์ปจํ
์คํธ {len(user_sessions)}๊ฐ ์ ๋ฆฌ ์๋ฃ")
+            except ImportError:
+                pass
+
+        if clear_type in ["all", "documents"]:
+            # clear the user's documents
+            try:
+                user_documents = integrated_memory_manager.get_user_documents(user_id)
+                for doc in user_documents:
+                    if isinstance(doc, dict) and doc.get('document_id'):
+                        integrated_memory_manager.remove_document(doc['document_id'])
+
+                cleared_items["documents"] = len(user_documents)
+                logger.info(f"๐๏ธ ์ฌ์ฉ์ {user_id} ๋ฌธ์ {len(user_documents)}๊ฐ ์ ๋ฆฌ ์๋ฃ")
+            except Exception as e:
+                logger.warning(f"โ ๏ธ ์ฌ์ฉ์ ๋ฌธ์ ์ ๋ฆฌ ์คํจ: {e}")
+
+        if clear_type in ["all", "settings"]:
+            # reset memory settings
+            try:
+                user_memory_manager.reset_user_settings(user_id)
+                cleared_items["settings"] = True
+                logger.info(f"๐ ์ฌ์ฉ์ {user_id} ๋ฉ๋ชจ๋ฆฌ ์ค์  ์ด๊ธฐํ ์๋ฃ")
+            except Exception as e:
+                logger.warning(f"โ ๏ธ ๋ฉ๋ชจ๋ฆฌ ์ค์  ์ด๊ธฐํ ์คํจ: {e}")
+
+        return {
+            "status": "success",
+            "message": f"์ฌ์ฉ์ {user_id} ๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ ์๋ฃ",
+            "user_id": user_id,
+            "clear_type": clear_type,
+            "cleared_items": cleared_items
+        }
+
+    except Exception as e:
+        logger.error(f"โ ์ฌ์ฉ์ ๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ ์คํจ: {e}")
+        return {"status": "error", "message": str(e)}
+
+@router.get("/user/memory/analytics/{user_id}")
+async def get_user_memory_analytics(user_id: str):
+    """Analyze a user's memory usage"""
+    try:
+        try:
+            from lily_llm_core.user_memory_manager import user_memory_manager
+            from lily_llm_core.integrated_memory_manager import integrated_memory_manager
+        except ImportError:
+            return {"status": "error", "message": "Memory managers not available"}
+
+        analytics = {
+            "user_id": user_id,
+            "timestamp": time.time(),
+            "memory_usage": {},
+            "session_stats": {},
+            "document_stats": {}
+        }
+
+        # session statistics
+        try:
+            from lily_llm_core.context_manager import context_manager
+            if context_manager:
+                user_sessions = [
+                    session_id for session_id in context_manager.session_conversations.keys()
+                    if f"user_{user_id}" in session_id
+                ]
+
+                total_turns = sum(len(context_manager.session_conversations[session_id]) for session_id in user_sessions)
+                total_tokens = sum(
+                    context_manager._estimate_tokens(
+                        context_manager.get_context(include_system=False, session_id=session_id)
+                    ) for session_id in user_sessions
+                )
+
+                analytics["session_stats"] = {
+                    "total_sessions": len(user_sessions),
+                    "total_turns": total_turns,
+                    "total_tokens": total_tokens,
+                    "avg_turns_per_session": total_turns / len(user_sessions) if user_sessions else 0
+                }
+        except ImportError:
+            pass
+
+        # document statistics
+        try:
+            user_documents = integrated_memory_manager.get_user_documents(user_id)
+            document_types = {}
+            total_size = 0
+
+            for doc in user_documents:
+                if isinstance(doc, dict):
+                    doc_type = doc.get('type', 'unknown')
+                    document_types[doc_type] = document_types.get(doc_type, 0) + 1
+                    total_size += doc.get('size', 0)
+
+            analytics["document_stats"] = {
+                "total_documents": len(user_documents),
+                "document_types": document_types,
+                "total_size_bytes": total_size,
+                "total_size_mb": total_size / (1024 * 1024)
+            }
+        except Exception as e:
+            logger.warning(f"โ ๏ธ ๋ฌธ์ ํต๊ณ ์กฐํ ์คํจ: {e}")
+
+        # memory usage summary
+        analytics["memory_usage"] = {
+            "session_memory_mb": analytics["session_stats"].get("total_tokens", 0) * 4 / (1024 * 1024),  # roughly 4 bytes per token
+            "document_memory_mb": analytics["document_stats"].get("total_size_mb", 0),
+            "total_memory_mb": (analytics["session_stats"].get("total_tokens", 0) * 4 / (1024 * 1024)) + analytics["document_stats"].get("total_size_mb", 0)
+        }
+
+        return {
+            "status": "success",
+            "analytics": analytics
+        }
+
+    except Exception as e:
+        logger.error(f"โ ์ฌ์ฉ์ ๋ฉ๋ชจ๋ฆฌ ๋ถ์ ์คํจ: {e}")
+        return {"status": "error", "message": str(e)}
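For reference, a minimal sketch of exercising the endpoints above, assuming the app is served on localhost:8000 under the /api/v2 prefix used in app_v2_modular.py below, and a hypothetical user id "alice". The POST bodies are form-encoded to match the Form(...) parameters:

    import requests

    BASE = "http://localhost:8000/api/v2"

    # read the current settings
    print(requests.get(f"{BASE}/user/memory/settings/alice").json())

    # ask the server to drop memory when this user changes rooms
    print(requests.post(f"{BASE}/user/memory/settings/alice",
                        data={"keep_memory_on_room_change": "false"}).json())

    # trigger the room-change handling itself
    print(requests.post(f"{BASE}/user/memory/room-change/alice",
                        data={"new_room_id": "room-42"}).json())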
lily_llm_api/app.py
CHANGED
The diff for this file is too large to render.
lily_llm_api/app_v2.py
DELETED
The diff for this file is too large to render.
lily_llm_api/app_v2_modular.py
ADDED
@@ -0,0 +1,34 @@
+#!/usr/bin/env python3
+"""
+Lily LLM API server v2 - modularized version
+"""
+import uvicorn
+import logging
+import warnings
+
+# ๐ silence RoPE warnings
+warnings.filterwarnings("ignore", message="The attention layers in this model are transitioning")
+warnings.filterwarnings("ignore", message="rotary_pos_emb will be removed")
+warnings.filterwarnings("ignore", message="position_embeddings will be mandatory")
+
+# logging setup
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+from .core.app_factory import create_app
+
+# create the FastAPI application
+app = create_app()
+
+def run_server():
+    """Run the server"""
+    uvicorn.run(
+        "app_v2_modular:app",
+        host="0.0.0.0",
+        port=8000,
+        reload=False,
+        workers=1
+    )
+
+if __name__ == "__main__":
+    run_server()
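Note that run_server() hands uvicorn the bare module string "app_v2_modular:app", which only resolves when the process starts from inside the lily_llm_api directory, while the relative import `from .core.app_factory import create_app` requires the file to be loaded as a package module. A minimal alternative sketch, assuming the repository root is on sys.path:

    import uvicorn

    # fully qualified module path instead of the bare "app_v2_modular:app"
    uvicorn.run("lily_llm_api.app_v2_modular:app", host="0.0.0.0", port=8000)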
lily_llm_api/core/__init__.py
ADDED
@@ -0,0 +1,3 @@
+"""
+Core package for Lily LLM API
+"""
lily_llm_api/core/app_factory.py
ADDED
@@ -0,0 +1,125 @@
+"""
+FastAPI app factory for Lily LLM API
+"""
+import logging
+import warnings
+from contextlib import asynccontextmanager
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+logger = logging.getLogger(__name__)
+
+# ๐ silence RoPE warnings (internal to the Kanana model implementation)
+warnings.filterwarnings("ignore", message="The attention layers in this model are transitioning")
+warnings.filterwarnings("ignore", message="rotary_pos_emb will be removed")
+warnings.filterwarnings("ignore", message="position_embeddings will be mandatory")
+
+@asynccontextmanager
+async def create_lifespan_handler(app):
+    """Create the server lifespan handler"""
+    # on server startup
+    logger.info("๐ ์๋ฒ ์์ ์ด๋ฒคํธ ์คํ ์ค...")
+
+    # apply CPU thread optimization
+    try:
+        from ..utils.system_utils import configure_cpu_threads, select_model_interactive
+        configure_cpu_threads()
+        logger.info("โ
 CPU ์ค๋ ๋ ์ต์ ํ ์๋ฃ")
+    except Exception as e:
+        logger.error(f"โ CPU ์ค๋ ๋ ์ค์  ์คํจ: {e}")
+
+    # ๐ restore model selection: let the user pick a model
+    try:
+        selected_model_id = select_model_interactive()
+        logger.info(f"๐ ์๋ฒ ์์ ์ ์ ํ๋ ๋ชจ๋ธ: {selected_model_id}")
+
+        from ..services.model_service import load_model_async
+        await load_model_async(selected_model_id)
+
+        from ..services.model_service import is_model_loaded
+        model_loaded = is_model_loaded()
+        logger.info(f"โ
 ์๋ฒ๊ฐ ๋ชจ๋ธ๋ก ์ค๋น๋์์ต๋๋ค.")
+        logger.info(f"โ
 model_loaded ์ํ: {model_loaded}")
+
+        # ๐ production: configure the advanced context manager
+        try:
+            from lily_llm_core.context_manager import context_manager
+            # use the "smart" summary method (the most balanced summarization)
+            context_manager.set_summary_method("smart")
+            logger.info("โ
 ๊ณ ๊ธ ์ปจํ
์คํธ ๊ด๋ฆฌ์ ์ค์  ์๋ฃ: smart ์์ฝ ๋ฐฉ๋ฒ ํ์ฑํ")
+
+            # tune the auto-cleanup settings
+            context_manager.set_auto_cleanup_config(
+                enabled=True,
+                interval_turns=5,  # clean every 5 turns
+                interval_time=180,  # clean every 3 minutes
+                strategy="aggressive"  # aggressive cleanup to optimize memory
+            )
+            logger.info("โ
 ์๋ ์ ๋ฆฌ ์ค์  ์ต์ ํ ์๋ฃ")
+
+        except Exception as e:
+            logger.warning(f"โ ๏ธ ๊ณ ๊ธ ์ปจํ
์คํธ ๊ด๋ฆฌ์ ์ค์  ์คํจ: {e}")
+
+    except Exception as e:
+        logger.error(f"โ ๋ชจ๋ธ ๋ก๋์ ์คํจํ์ต๋๋ค: {e}", exc_info=True)
+
+    logger.info("โ
 ์๋ฒ ์์ ์ด๋ฒคํธ ์๋ฃ")
+
+    yield  # server is running
+
+    # on server shutdown
+    logger.info("๐ ์๋ฒ ์ข
๋ฃ ์ด๋ฒคํธ ์คํ ์ค...")
+
+    # shut down the thread pool executor
+    try:
+        from ..services.model_service import shutdown_executor
+        shutdown_executor()
+        logger.info("โ
 ์ค๋ ๋ ํ ์คํ๊ธฐ ์ข
๋ฃ ์๋ฃ")
+    except Exception as e:
+        logger.warning(f"โ ๏ธ ์ค๋ ๋ ํ ์คํ๊ธฐ ์ข
๋ฃ ์คํจ: {e}")
+
+    logger.info("โ
 ์๋ฒ ์ข
๋ฃ ์ด๋ฒคํธ ์๋ฃ")
+
+def create_app() -> FastAPI:
+    """Create the FastAPI app"""
+    # create the FastAPI app
+    app = FastAPI(
+        title="Lily LLM API v2",
+        description="๋ค์ค ๋ชจ๋ธ ์ง์ LLM API ์๋ฒ",
+        version="2.0.0",
+        lifespan=create_lifespan_handler
+    )
+
+    # CORS setup
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=[
+            "http://localhost:8001",
+            "http://127.0.0.1:8001",
+            "http://localhost:3000",
+            "http://127.0.0.1:3000",
+            "*"  # allow every origin during development
+        ],
+        allow_credentials=True,
+        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+        allow_headers=["*"],
+    )
+
+    # register routers
+    from ..api.routers import (
+        model_router, generation_router, lora_router, context_router,
+        document_router, ocr_router, advanced_context_router,
+        multimodal_rag_router, user_memory_router
+    )
+
+    app.include_router(model_router.router, prefix="/api/v2", tags=["models"])
+    app.include_router(generation_router.router, prefix="/api/v2", tags=["generation"])
+    app.include_router(lora_router.router, prefix="/api/v2", tags=["lora"])
+    app.include_router(context_router.router, prefix="/api/v2", tags=["context"])
+    app.include_router(document_router.router, prefix="/api/v2", tags=["document"])
+    app.include_router(ocr_router.router, prefix="/api/v2", tags=["ocr"])
+    app.include_router(advanced_context_router.router, prefix="/api/v2", tags=["advanced-context"])
+    app.include_router(multimodal_rag_router.router, prefix="/api/v2", tags=["multimodal-rag"])
+    app.include_router(user_memory_router.router, prefix="/api/v2", tags=["user-memory"])
+
+    return app
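For reference, a minimal sketch of inspecting the factory output without triggering the lifespan handler (which would start the interactive model selection), assuming the package is importable:

    from lily_llm_api.core.app_factory import create_app

    app = create_app()
    # every router is mounted under /api/v2, so the route table reflects that prefix
    print(sorted({route.path for route in app.routes if route.path.startswith("/api/v2")}))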
lily_llm_api/models/back/configuration.py
DELETED
@@ -1,125 +0,0 @@
-import logging
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.models.llama.configuration_llama import LlamaConfig
-from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-
-logger = logging.getLogger("kanana-1.5-v")
-
-
-class KananaVVisionConfig(PretrainedConfig):
-    model_type = "kanana-1.5-v-visual-encoder"
-    base_config_key = "vision_config"
-
-    def __init__(
-        self,
-        depth=32,
-        embed_dim=1280,
-        mlp_ratio=4,
-        num_heads=16,
-        in_chans=3,
-        hidden_size=1280,
-        patch_size=14,
-        spatial_merge_size=2,
-        spatial_patch_size=14,
-        temporal_patch_size=2,
-        initializer_range=0.02,
-        image_size="dynamic",
-        image_mean=OPENAI_CLIP_MEAN,
-        image_std=OPENAI_CLIP_STD,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.depth = depth
-        self.embed_dim = embed_dim
-        self.mlp_ratio = mlp_ratio
-        self.num_heads = num_heads
-        self.in_chans = in_chans
-        self.hidden_size = hidden_size
-        self.patch_size = patch_size
-        self.spatial_merge_size = spatial_merge_size
-        self.spatial_patch_size = spatial_patch_size
-        self.temporal_patch_size = temporal_patch_size
-        self.initializer_range = initializer_range
-        self.image_size = image_size
-        self.image_mean = image_mean
-        self.image_std = image_std
-
-
-class KananaVVisualProjectorConfig(PretrainedConfig):
-    model_type = "kanana-1.5-v-visual_projector"
-    base_config_key = "projector_config"
-
-    def __init__(
-        self,
-        depth=2,
-        encoder_hidden_size=1280,
-        feature_layer_index=-1,
-        hidden_size=1024,
-        merge_size=2,
-        mlp_depth=2,
-        num_eos_tokens=0,
-        output_hidden_size=2048,
-        pos_emb=True,
-        pos_emb_size=576,
-        prenorm=False,
-        projector_type="dynamic-c-abs",
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.depth = depth
-        self.encoder_hidden_size = encoder_hidden_size
-        self.feature_layer_index = feature_layer_index
-        self.hidden_size = hidden_size
-        self.merge_size = merge_size
-        self.mlp_depth = mlp_depth
-        self.num_eos_tokens = num_eos_tokens
-        self.output_hidden_size = output_hidden_size
-        self.pos_emb = pos_emb
-        self.pos_emb_size = pos_emb_size
-        self.prenorm = prenorm
-        self.projector_type = projector_type
-
-
-class KananaLanguageConfig(LlamaConfig):
-    model_type = "kanana-1.5-3b-instruct"
-    base_config_key = "text_config"
-
-    def __init__(
-        self,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-
-class KananaVConfig(PretrainedConfig):
-    model_type = "kanana-1.5-v"
-    is_composition = True
-
-    def __init__(
-        self,
-        vision_config: dict = {},
-        projector_config: dict = {},
-        text_config: dict = {},
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        # Vision config
-        self.vision_config = KananaVVisionConfig(**vision_config)
-
-        # Visual projector config
-        self.projector_config = KananaVVisualProjectorConfig(**projector_config)
-
-        # Language model config
-        self.text_config = KananaLanguageConfig(**text_config)
-
-    @property
-    def num_visual_tokens(self):
-        return "dynamic"
-
-    @property
-    def hidden_size(self):
-        return self.text_config.hidden_size
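For reference, a minimal sketch of how the removed composite config was assembled: KananaVConfig built each sub-config from a plain dict and delegated hidden_size to the text config (the override values below are illustrative):

    config = KananaVConfig(
        vision_config={"depth": 32, "embed_dim": 1280},
        projector_config={"output_hidden_size": 2048},
        text_config={},
    )
    print(config.num_visual_tokens)  # always the string "dynamic"
    print(config.hidden_size)        # delegates to config.text_config.hidden_size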
lily_llm_api/models/back/modeling.py
DELETED
@@ -1,973 +0,0 @@
-from functools import partial
-import logging
-import re
-from typing import Optional, Tuple, Union, List
-
-from einops import rearrange
-from timm.layers import LayerNorm, LayerNorm2d
-from timm.layers.pos_embed import resample_abs_pos_embed
-from timm.models.regnet import RegStage
-import torch
-from torch import nn
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from transformers import LlamaForCausalLM
-from transformers.modeling_outputs import BaseModelOutput
-from transformers.modeling_utils import PreTrainedModel
-from transformers.models.auto import AutoModelForCausalLM
-from transformers.models.qwen2_vl.configuration_qwen2_vl import (
-    Qwen2VLVisionConfig,
-)
-from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-    PatchEmbed,
-    Qwen2VLPreTrainedModel,
-    Qwen2VisionTransformerPretrainedModel,
-    Qwen2VLVisionBlock,
-    VisionRotaryEmbedding
-)
-
-from configuration import KananaVVisualProjectorConfig, KananaVConfig
-
-logger = logging.getLogger("kanana-1.5-v")
-
-
-def build_pos_embeds(
-    config: KananaVVisualProjectorConfig, num_input_tokens: int, vision_hidden_size: int
-):
-    # pos emb
-    if config.pos_emb:
-        # โจ fix: fall back to a default when num_input_tokens is negative
-        if num_input_tokens <= 0:
-            num_input_tokens = config.pos_emb_size if hasattr(config, 'pos_emb_size') else 576
-        pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
-        nn.init.trunc_normal_(pos_emb, mean=0.0, std=0.02)
-    else:
-        pos_emb = None
-
-    return pos_emb
-
-
-def build_eos_tokens(config: KananaVVisualProjectorConfig, output_hidden_size: int):
-    # think tokens
-    num_eos_tokens = config.num_eos_tokens
-    if num_eos_tokens:
-        eos_tokens = torch.nn.Parameter(torch.randn(1, num_eos_tokens, output_hidden_size))
-        nn.init.trunc_normal_(eos_tokens, mean=0.0, std=config.initializer_range)
-    else:
-        eos_tokens = None
-
-    return eos_tokens
-
-
-def build_prenorm(config: KananaVVisualProjectorConfig):
-    if getattr(config, "prenorm", False):
-        prenorm = LayerNorm(config.encoder_hidden_size)
-    else:
-        prenorm = None
-    return prenorm
-
-
-def build_mlp(depth: int, hidden_size: int, output_hidden_size: int):
-    layers = [nn.Linear(hidden_size, output_hidden_size)]
-    for _ in range(1, depth):
-        layers.append(nn.SiLU())
-        layers.append(nn.Linear(output_hidden_size, output_hidden_size))
-    return nn.Sequential(*layers)
-
-
-class PatchMerge(nn.Module):
-    def __init__(self, merge_size):
-        super().__init__()
-        self.merge_size = merge_size
-
-    def forward(self, x, channel_last=False):
-        if channel_last:
-            x = rearrange(x, "B H W D -> B D H W")
-        _, D, H, W = x.shape
-
-        # add padding so odd dimensions can be handled
-        pad_h = (self.merge_size - H % self.merge_size) % self.merge_size
-        pad_w = (self.merge_size - W % self.merge_size) % self.merge_size
-
-        if pad_h > 0 or pad_w > 0:
-            print(f"๐ PatchMerge - ํจ๋ฉ ์ถ๊ฐ: H={H}->{H+pad_h}, W={W}->{W+pad_w}")
-            x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
-            H, W = H + pad_h, W + pad_w
-
-        merged_x = rearrange(
-            x, "B D (H h2) (W w2) -> B (D h2 w2) H W", h2=self.merge_size, w2=self.merge_size
-        )
-        return merged_x
-
-
-class DynamicCAbstractor(nn.Module):
-    """Dynamic C-Abstractor based on RegBlock"""
-
-    def __init__(self, config: KananaVVisualProjectorConfig, num_input_tokens: int):
-        super().__init__()
-        self.config = config
-
-        # โจ fix: fall back to a default when num_input_tokens is negative
-        if num_input_tokens <= 0:
-            num_input_tokens = config.pos_emb_size if hasattr(config, 'pos_emb_size') else 576
-        self.num_input_tokens = num_input_tokens
-
-        # โจ added: set attributes that were missing
-        self.merge_size = getattr(config, 'merge_size', 2)
-        self.pos_emb_size = getattr(config, 'pos_emb_size', 576)
-
-        # โจ optimization: initialize every layer in bfloat16
-        self.pos_emb = build_pos_embeds(config, num_input_tokens, config.encoder_hidden_size)
-        if self.pos_emb is not None:
-            self.pos_emb.data = self.pos_emb.data.to(torch.bfloat16)
-
-        self.eos_tokens = build_eos_tokens(config, config.output_hidden_size)
-        if self.eos_tokens is not None:
-            self.eos_tokens.data = self.eos_tokens.data.to(torch.bfloat16)
-
-        self.prenorm = build_prenorm(config)
-        if self.prenorm is not None:
-            self.prenorm = self.prenorm.to(torch.bfloat16)
-
-        # โจ fix: build_net sets self.net and self.readout
-        self.build_net()
-
-        # โจ optimization: convert the net layers to bfloat16
-        if hasattr(self, 'net'):
-            if isinstance(self.net, nn.ModuleList):
-                for layer in self.net:
-                    layer = layer.to(torch.bfloat16)
-                    for module in layer.modules():
-                        if hasattr(module, 'weight'):
-                            module.weight.data = module.weight.data.to(torch.bfloat16)
-                        if hasattr(module, 'bias') and module.bias is not None:
-                            module.bias.data = module.bias.data.to(torch.bfloat16)
-            else:
-                # when self.net is a single module
-                self.net = self.net.to(torch.bfloat16)
-                for module in self.net.modules():
-                    if hasattr(module, 'weight'):
-                        module.weight.data = module.weight.data.to(torch.bfloat16)
-                    if hasattr(module, 'bias') and module.bias is not None:
-                        module.bias.data = module.bias.data.to(torch.bfloat16)
-
-        # โจ optimization: convert the readout layer to bfloat16
-        if hasattr(self, 'readout'):
-            self.readout = self.readout.to(torch.bfloat16)
-            for module in self.readout.modules():
-                if hasattr(module, 'weight'):
-                    module.weight.data = module.weight.data.to(torch.bfloat16)
-                if hasattr(module, 'bias') and module.bias is not None:
-                    module.bias.data = module.bias.data.to(torch.bfloat16)
-
-    def build_net(self):
-        encoder_hidden_size = self.config.encoder_hidden_size
-        hidden_size = self.config.hidden_size
-        output_hidden_size = self.config.output_hidden_size
-        depth = self.config.depth
-        mlp_depth = self.config.mlp_depth
-
-        RegBlock = partial(
-            RegStage,
-            stride=1,
-            dilation=1,
-            act_layer=nn.SiLU,
-            norm_layer=LayerNorm2d,
-        )
-
-        s1 = RegBlock(
-            depth,
-            encoder_hidden_size,
-            hidden_size,
-        )
-        sampler = PatchMerge(merge_size=self.merge_size)
-        s2 = RegBlock(
-            depth,
-            self.merge_size**2 * hidden_size,
-            hidden_size,
-        )
-
-        if depth:
-            self.net = nn.ModuleList([s1, sampler, s2])
-            self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)
-        else:
-            self.net = sampler
-            self.readout = build_mlp(mlp_depth, encoder_hidden_size, output_hidden_size)
-
-    def forward(self, flattened_visual_embeds, grid_thw, **unused_kwargs):
-        n_token_loc = torch.prod(grid_thw, dim=1)
-        split_visual_embeds = torch.split(flattened_visual_embeds, n_token_loc.tolist())
-
-        flattened_visual_embeds = []
-        for _visual_embeds, _grid_thw in zip(split_visual_embeds, grid_thw):
-            T, H, W = _grid_thw
-            assert T == 1, "T must be 1. Video is not supported yet."
-            reshaped_visual_embeds = rearrange(
-                _visual_embeds, "(t h w) d -> 1 t h w d", t=T, h=H, w=W
-            )
-            # remove temporal dim
-            reshaped_visual_embeds = reshaped_visual_embeds[:, 0]
-
-            if self.prenorm is not None:
-                reshaped_visual_embeds = self.prenorm(reshaped_visual_embeds)
-
-            if self.pos_emb is not None:
-                # interpolate pos emb and add to visual embeds
-                print(f"๐ abstractor - pos_emb ํํ: {self.pos_emb.shape}")
-                print(f"๐ abstractor - reshaped_visual_embeds ํํ: {reshaped_visual_embeds.shape}")
-
-                _local_pos_emb = resample_abs_pos_embed(
-                    posemb=self.pos_emb,
-                    old_size=tuple([int(self.pos_emb_size**0.5)] * 2),
-                    new_size=(H, W),
-                    num_prefix_tokens=0,
-                )
-                _local_pos_emb = rearrange(
-                    _local_pos_emb,
-                    "1 (h w) d -> 1 h w d",
-                    h=H,
-                    w=W,
-                )
-                print(f"๐ abstractor - _local_pos_emb ํํ: {_local_pos_emb.shape}")
-
-                # handle dimension mismatches
-                if reshaped_visual_embeds.shape[-1] != _local_pos_emb.shape[-1]:
-                    print(f"๐ abstractor - ์ฐจ์ ๋ถ์ผ์น ๊ฐ์ง, pos_emb ๊ฑด๋๋ฐ๊ธฐ")
-                    # skip pos_emb and use visual_embeds only
-                else:
-                    reshaped_visual_embeds = reshaped_visual_embeds + _local_pos_emb
-
-            reshaped_visual_embeds = self._forward(
-                reshaped_visual_embeds,
-                input_size=(H, W),
-            )
-            flattened_visual_embeds.append(reshaped_visual_embeds)
-        reshaped_visual_embeds = torch.cat(flattened_visual_embeds, dim=0)
-        output = BaseModelOutput(last_hidden_state=reshaped_visual_embeds)
-        return output
-
-    def _forward(self, x, input_size):
-        h, w = input_size
-
-        x = rearrange(x, "1 h w d -> 1 d h w", h=h, w=w)
-
-        # handle mismatched input channel counts
-        # check the channel count in the first RegStage block
-        try:
-            if hasattr(self.net[0], 'conv'):
-                expected_channels = self.net[0].conv.in_channels
-            elif hasattr(self.net[0], 'blocks') and len(self.net[0].blocks) > 0:
-                expected_channels = self.net[0].blocks[0].conv1.in_channels
-            else:
-                # fall back to the default
-                expected_channels = 1280
-        except:
-            expected_channels = 1280
-
-        actual_channels = x.shape[1]
-
-        if actual_channels != expected_channels:
-            # adjust the channel count with a linear projection
-            if not hasattr(self, 'channel_adapter'):
-                # create channel_adapter in bfloat16
-                self.channel_adapter = nn.Linear(actual_channels, expected_channels, dtype=torch.bfloat16).to(x.device)
-
-            x = x.permute(0, 2, 3, 1)  # (1, d, h, w) -> (1, h, w, d)
-            # convert the input to bfloat16 (only once)
-            if x.dtype != torch.bfloat16:
-                x = x.to(torch.bfloat16)
-            x = self.channel_adapter(x)  # adjust the channel count
-            x = x.permute(0, 3, 1, 2)  # (1, h, w, d) -> (1, d, h, w)
-
-        # โจ optimization: the layers are already initialized in bfloat16
-        x = self.net[0](x)
-        x = self.net[1](x)
-        x = self.net[2](x)
-        x = rearrange(x, "1 d h w -> (h w) d")
-
-        # โจ optimization: readout is already initialized in bfloat16
-        x = self.readout(x)
-
-        return x
-
-
-class CustomQwen2VLVE(Qwen2VisionTransformerPretrainedModel):
-    config_class = Qwen2VLVisionConfig
-    _no_split_modules = ["Qwen2VLVisionBlock"]
-
-    def __init__(self, config) -> None:
-        Qwen2VLPreTrainedModel.__init__(self, config)
-        self.spatial_merge_size = config.spatial_merge_size
-        self.gradient_checkpointing = False
-
-        self.patch_embed = PatchEmbed(
-            patch_size=config.patch_size,
-            temporal_patch_size=config.temporal_patch_size,
-            in_channels=config.in_channels,
-            embed_dim=config.embed_dim,
-        )
-
-        head_dim = config.embed_dim // config.num_heads
-        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
-
-        self.blocks = nn.ModuleList(
-            [Qwen2VLVisionBlock(config, config._attn_implementation) for _ in range(config.depth)]
-        )
-
-    def forward(
-        self,
-        pixel_values: torch.Tensor,
-        grid_thw: torch.Tensor,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutput]:
-        assert return_dict, "Only return_dict=True is supported."
-
-        encoder_states = () if output_hidden_states else None
-
-        hidden_states = self.patch_embed(pixel_values)
-        rotary_pos_emb = self.rot_pos_emb(grid_thw)
-        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        position_embeddings = emb.cos(), emb.sin()
-
-        cu_seqlens = torch.repeat_interleave(
-            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
-        ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
-
-        for blk in self.blocks:
-            if output_hidden_states:
-                encoder_states = encoder_states + (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    blk.__call__,
-                    hidden_states=hidden_states,
-                    cu_seqlens=cu_seqlens,
-                    position_embeddings=position_embeddings,
-                    use_reentrant=False,
-                )
-            else:
-                layer_outputs = blk(
-                    hidden_states=hidden_states,
-                    cu_seqlens=cu_seqlens,
-                    position_embeddings=position_embeddings,
-                )
-            hidden_states = layer_outputs
-        if output_hidden_states:
-            encoder_states = encoder_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
-        return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states)
-
-    def get_num_tokens(self):
-        return -1
-
-
-class KananaVPreTrainedModel(PreTrainedModel):
-    """
-    An abstract class to handle weights initialization and
-    a simple interface for downloading and loading pretrained models.
-    """
-
-    config_class = KananaVConfig
-    base_model_prefix = "kanana-1.5-v"
-    supports_gradient_checkpointing = True
-    _skip_keys_device_placement = "past_key_values"
-    _supports_flash_attn_2 = True
-    _supports_sdpa = True
-    _supports_cache_class = True
-    _supports_static_cache = False
-
-    _keys_to_ignore_on_load_missing = [
-        r"position_ids",
-        r"language_model.encoder.embed_tokens.weight",
-        r"language_model.decoder.embed_tokens.weight",
-        r"language_model.lm_head.weight",
-    ]
-    _no_split_modules = [
-        "CustomQwen2VLVE",
-        "DynamicCAbstractor",
-        "LlamaForCausalLM",
-        "Parameter",
-    ]
-
-    def _init_weights(self, module):
-        """Initialize the weights"""
-        if (
-            isinstance(module, nn.Conv2d)
-            or isinstance(module, nn.Embedding)
-            or isinstance(module, nn.Linear)
-        ):
-            module.weight.data.normal_(mean=0.0, std=0.02)
-            if hasattr(module, "bias") and module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, nn.Parameter):
-            raise ValueError()
-
-
-class KananaVForConditionalGeneration(KananaVPreTrainedModel):
-    config_class = KananaVConfig
-
-    def __init__(self, config: KananaVConfig):
-        super().__init__(config)
-
-        logger.info("Build vision model ...")
-        self.vision_model = CustomQwen2VLVE._from_config(config.vision_config)
-
-        logger.info("Build projector ...")
-        self.abstractor = DynamicCAbstractor(config.projector_config,
-                                             num_input_tokens=self.vision_model.get_num_tokens())
-
-        logger.info("Build language model ...")
-        self.language_model = LlamaForCausalLM._from_config(config=config.text_config)
-
-        self.post_init()
-
-    def forward_vision(self, pixel_values: Union[torch.Tensor, List[torch.Tensor]], image_metas: Optional[dict] = None):
-
-        # โจ key fix: handle both the list case and the tensor case for pixel_values
-        if isinstance(pixel_values, list):
-            # multiple images: process each image and merge the results
-            visual_features_list = []
-            for i, pv in enumerate(pixel_values):
-                single_image_metas = {k: v[i] for k, v in image_metas.items()}
-
-                # grid_thw handling
-                vision_grid_thw = single_image_metas["vision_grid_thw"]
-                if isinstance(vision_grid_thw, (list, tuple)):
-                    # convert the tuple to a list so a tensor can be built
-                    grid_thw = torch.tensor([list(vision_grid_thw)]).to(pv.device)
-                else:
-                    grid_thw = torch.tensor([vision_grid_thw]).to(pv.device)
-
-                # โจ optimization: removed unnecessary dtype conversions
-                v_outputs = self.vision_model(
-                    pixel_values=pv.unsqueeze(0),
-                    grid_thw=grid_thw,
-                    return_dict=True, output_hidden_states=True
-                )
-                layer_index = self.config.projector_config.feature_layer_index
-                visual_features_list.append(self._get_visual_feature_at(v_outputs.hidden_states, layer_index))
-            # lightweight multimodal: use only the first result; downstream expects a single tensor, so return the tensor
-            return visual_features_list[0] if len(visual_features_list) > 0 else visual_features_list
-        else:
-            # single image - already a separated feature tensor
-
-            # if grid_thw is a list, use its first element
-            grid_thw = image_metas["vision_grid_thw"]
-            if isinstance(grid_thw, list):
-                grid_thw = grid_thw[0]
-
-            # convert grid_thw to a tensor
-            if not isinstance(grid_thw, torch.Tensor):
-                if isinstance(grid_thw, (list, tuple)):
-                    # convert the tuple to a list so a tensor can be built
-                    grid_thw = torch.tensor([list(grid_thw)])
-                else:
-                    grid_thw = torch.tensor([grid_thw])
-
-            # move to the right device
-            if hasattr(pixel_values, 'device'):
-                grid_thw = grid_thw.to(pixel_values.device)
-
-            # if pixel_values is a 2D feature tensor, run it through vision_model
-            if len(pixel_values.shape) == 2:
-                # convert the 2D feature tensor into a form vision_model can process
-                # handle it the same way as multiple images, but in the right shape
-
-                # reshape pixel_values into a (1, 3, H, W) form
-                # same approach as the multi-image path
-                if len(pixel_values.shape) == 2:
-                    # convert the 2D tensor into a form vision_model can process
-                    # the multi-image path already passes the right shape, so handle it the same way
-
-                    # โจ optimization: removed unnecessary dtype conversions
-                    v_outputs = self.vision_model(
-                        pixel_values=pixel_values,
-                        grid_thw=grid_thw,
-                        return_dict=True, output_hidden_states=True
-                    )
-                    layer_index = self.config.projector_config.feature_layer_index
-                    return self._get_visual_feature_at(v_outputs.hidden_states, layer_index)
-                else:
-                    return pixel_values
-
-            # โจ optimization: removed unnecessary dtype conversions
-            v_outputs = self.vision_model(
-                pixel_values=pixel_values,
-                grid_thw=grid_thw,
-                return_dict=True, output_hidden_states=True
-            )
-            layer_index = self.config.projector_config.feature_layer_index
-            return self._get_visual_feature_at(v_outputs.hidden_states, layer_index)
-
-    def forward_projector(self, visual_features: Union[torch.Tensor, List[torch.Tensor]], image_metas: Optional[dict] = None):
-        print(f"๐ forward_projector - visual_features ํํ: {visual_features.shape if hasattr(visual_features, 'shape') else type(visual_features)}")
-
-        # โจ key fix: handle the case where visual_features is a list
-        if isinstance(visual_features, list):
-            visual_embeds_list = []
-            for i, vf in enumerate(visual_features):
-                single_image_metas = {k: v[i] for k, v in image_metas.items()}
-                vision_grid_thw = single_image_metas["vision_grid_thw"]
-                if isinstance(vision_grid_thw, (list, tuple)):
-                    grid_thw = torch.tensor([list(vision_grid_thw)]).to(vf.device)
-                else:
-                    grid_thw = torch.tensor([vision_grid_thw]).to(vf.device)
-                visual_embeds = self.abstractor(vf, grid_thw=grid_thw)["last_hidden_state"]
-                visual_embeds_list.append(visual_embeds)
-            return torch.cat(visual_embeds_list, dim=0)
-        else:
-            # single image
-            print(f"๐ forward_projector - ๋จ์ผ ํ
์ ์ฒ๋ฆฌ")
-
-            # when visual_features is already a processed feature tensor
-            if len(visual_features.shape) == 2:
-                print(f"๐ forward_projector - ์ด๋ฏธ ์ฒ๋ฆฌ๋ ํน์ง ํ
์ ๊ฐ์ง")
-                print(f"๐ forward_projector - ํน์ง ํ
์ ํํ: {visual_features.shape}")
-
-                # if grid_thw is a list, use its first element
-                grid_thw = image_metas["vision_grid_thw"]
-                if isinstance(grid_thw, list):
-                    grid_thw = grid_thw[0]
-
-                # convert grid_thw to a tensor
-                if not isinstance(grid_thw, torch.Tensor):
-                    if isinstance(grid_thw, (list, tuple)):
-                        # convert the tuple to a list so a tensor can be built
-                        grid_thw = torch.tensor([list(grid_thw)])
-                    else:
-                        grid_thw = torch.tensor([grid_thw])
-
-                # move to the right device
-                if hasattr(visual_features, 'device'):
-                    grid_thw = grid_thw.to(visual_features.device)
-
-                print(f"๐ forward_projector - grid_thw: {grid_thw}")
-                print(f"๐ forward_projector - grid_thw ๊ณ์ฐ๋ ํ ํฐ ์: {torch.prod(grid_thw, dim=1)}")
-                print(f"๐ forward_projector - ์ค์  ํน์ง ํ
์ ํ ํฐ ์: {visual_features.shape[0]}")
-
-                # fix grid_thw when it does not match the actual token count
-                actual_tokens = visual_features.shape[0]
-                if torch.prod(grid_thw, dim=1).item() != actual_tokens:
-                    print(f"๐ forward_projector - grid_thw ์์  ํ์")
-                    # compute a grid_thw that matches the actual token count,
-                    # taking the image's real aspect ratio into account
-                    T = 1
-
-                    # use the real size info from the image metadata
-                    if 'hw_orig_resolution' in image_metas:
-                        orig_h, orig_w = image_metas['hw_orig_resolution']
-                        if isinstance(orig_h, list):
-                            orig_h = orig_h[0] if isinstance(orig_h[0], (int, float)) else orig_h[0][0]
-                        if isinstance(orig_w, list):
-                            orig_w = orig_w[0] if isinstance(orig_w[0], (int, float)) else orig_w[0][0]
-
-                        # adjust to the token count while keeping the real aspect ratio
-                        aspect_ratio = orig_w / orig_h
-                        H = int((actual_tokens / aspect_ratio) ** 0.5)
-                        W = int(actual_tokens / H)
-
-                        # adjust further if the token counts still do not match
-                        while H * W != actual_tokens and H > 1 and W > 1:
-                            if H * W > actual_tokens:
-                                H -= 1
-                                W = int(actual_tokens / H)
-                            else:
-                                W += 1
-                                H = int(actual_tokens / W)
-                    else:
-                        # fall back to the default
-                        H = int(actual_tokens ** 0.5)
-                        W = actual_tokens // H
-                        if actual_tokens % H != 0:
-                            W += 1
-
-                    corrected_grid_thw = torch.tensor([[T, H, W]])
-                    print(f"๐ forward_projector - ์์ ๋ grid_thw: {corrected_grid_thw}")
-                    print(f"๐ forward_projector - ์์ ๋ ํ ํฐ ์: {torch.prod(corrected_grid_thw, dim=1)}")
-
-                    # pad or truncate when the token counts do not match
-                    expected_tokens = torch.prod(corrected_grid_thw, dim=1).item()
-                    if expected_tokens > actual_tokens:
-                        # pad
-                        padding_size = expected_tokens - actual_tokens
-                        padding = torch.zeros(padding_size, visual_features.shape[1], device=visual_features.device)
-                        visual_features = torch.cat([visual_features, padding], dim=0)
-                        print(f"๐ forward_projector - ํจ๋ฉ ์ถ๊ฐ: {padding_size}๊ฐ ํ ํฐ")
-                    elif expected_tokens < actual_tokens:
-                        # truncate
-                        visual_features = visual_features[:expected_tokens]
-                        print(f"๐ forward_projector - ํ ํฐ ์๋ฅด๊ธฐ: {expected_tokens}๊ฐ๋ก")
-
-                    grid_thw = corrected_grid_thw
-
-                # pass the feature tensor straight to the abstractor
-                visual_embeds = self.abstractor(visual_features, grid_thw=grid_thw)["last_hidden_state"]
-                print(f"๐ forward_projector - abstractor ์ถ๋ ฅ ํํ: {visual_embeds.shape}")
-                return visual_embeds
-            else:
-                # regular vision model output
-                grid_thw = image_metas["vision_grid_thw"]
-                return self.abstractor(visual_features, grid_thw=grid_thw)["last_hidden_state"]
-
-    def forward_and_project_vision(self, pixel_values, image_metas: Optional[dict] = None):
-        visual_features = self.forward_vision(pixel_values, image_metas=image_metas)
-        visual_embeds = self.forward_projector(visual_features, image_metas=image_metas)
-        return visual_embeds
-
-    def _get_visual_feature_at(self, v_output, layer_index):
-        if isinstance(layer_index, list):
-            visual_features = torch.stack(v_output, dim=1)[:, layer_index]  # [B, n_scales, L, dim]
-        else:
-            visual_features = v_output[layer_index]  # [B, L, dim]
-        return visual_features
-
-    def embed_text_tokens(self, input_ids):
-        """Embed input_ids into text_embeds, ignoring media tokens (negative values)."""
-        input_ids = input_ids.clone()
-        input_ids[input_ids < 0] = 0
-
-        text_embeds = self.language_model.get_input_embeddings()(input_ids)
-        if hasattr(self.language_model, "transformer") and hasattr(
-            self.language_model.transformer, "word_embeddings_layernorm"
-        ):
-            text_embeds = self.language_model.transformer.word_embeddings_layernorm(text_embeds)
-
-        return text_embeds
-
-    def prepare_mm_inputs(
-        self,
-        input_ids: torch.FloatTensor,
-        pixel_values: Optional[list[torch.FloatTensor]] = None,
-        image_metas: Optional[dict] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-    ):
-        """Prepare multimodal inputs from input_ids and pixel_values."""
-        if pixel_values is not None:
-            # if pixel_values is a list, convert each element
-            if isinstance(pixel_values, list):
-                pixel_values = [pv.to(self._get_input_dtype()) for pv in pixel_values]
-            else:
-                pixel_values = pixel_values.to(self._get_input_dtype())
-
-        if attention_mask is None:
-            attention_mask = input_ids.new_ones(*input_ids.shape)
-
-        # Get Text Embeddings
-        text_embeds = self.embed_text_tokens(input_ids)
-        flattened_text_embeds = rearrange(text_embeds, "b l d -> (b l) d")
-        flattened_input_ids = rearrange(input_ids, "b l -> (b l)")
-
-        # Get Visual Embeddings
-        if pixel_values is not None:
-            print(f"๐ prepare_mm_inputs - pixel_values ํ์
: {type(pixel_values)}")
-            if hasattr(pixel_values, 'shape'):
-                print(f"๐ prepare_mm_inputs - pixel_values ํํ: {pixel_values.shape}")
-            if isinstance(pixel_values, list):
-                print(f"๐ prepare_mm_inputs - pixel_values ๊ธธ์ด: {len(pixel_values)}")
-
-            # multi-image handling: process each image individually
-            if isinstance(pixel_values, list) and len(pixel_values) > 1:
-                print(f"๐ prepare_mm_inputs - ๋ค์ค ์ด๋ฏธ์ง ์ฒ๋ฆฌ ์์")
-                visual_embeds_list = []
-                for i, single_pixel_values in enumerate(pixel_values):
-                    print(f"๐ prepare_mm_inputs - ์ด๋ฏธ์ง {i} ์ฒ๋ฆฌ ์ค")
-                    # build per-image image_metas
-                    single_image_metas = {}
-                    for key, value_list in image_metas.items():
-                        if isinstance(value_list, list):
-                            single_image_metas[key] = value_list[i]
-                        else:
-                            single_image_metas[key] = value_list
-
-                    # process each image individually
-                    single_visual_embeds = self.forward_and_project_vision(
-                        single_pixel_values.unsqueeze(0), single_image_metas
-                    )
-                    visual_embeds_list.append(single_visual_embeds)
-
-                # concatenate the visual embeds of all images
-                flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0)
-                print(f"๐ prepare_mm_inputs - ๋ค์ค ์ด๋ฏธ์ง ์ฒ๋ฆฌ ์๋ฃ, ์ฐ๊ฒฐ๋ embeds ํฌ๊ธฐ: {flattened_visual_embeds.shape}")
-            else:
-                # single-image handling (original path)
-                print(f"๐ prepare_mm_inputs - ๋จ์ผ ์ด๋ฏธ์ง ์ฒ๋ฆฌ")
-
-                # when pixel_values is already a processed feature tensor (multiple images combined)
-                if hasattr(pixel_values, 'shape') and len(pixel_values.shape) == 2:
-                    print(f"๐ prepare_mm_inputs - ์ฒ๋ฆฌ๋ ํน์ง ํ
์ ๊ฐ์ง, ๋ค์ค ์ด๋ฏธ์ง๋ก ๋ถ๋ฆฌ ์๋")
-
-                    # check the image count from image_metas
-                    num_images = 0
-                    if isinstance(image_metas, dict) and "image_token_thw" in image_metas:
-                        num_images = len(image_metas["image_token_thw"])
-                        print(f"๐ prepare_mm_inputs - ๊ฐ์ง๋ ์ด๋ฏธ์ง ๊ฐ์: {num_images}")
-
-                    if num_images > 1:
-                        print(f"๐ prepare_mm_inputs - {num_images}๊ฐ ์ด๋ฏธ์ง๋ก ๋ถ๋ฆฌ ์ฒ๋ฆฌ")
-                        visual_embeds_list = []
-
-                        # compute the actual token count of each image
-                        current_idx = 0
-                        for i in range(num_images):
-                            print(f"๐ prepare_mm_inputs - ์ด๋ฏธ์ง {i} ์ฒ๋ฆฌ ์ค")
-
-                            # build per-image image_metas
-                            single_image_metas = {}
-                            for key, value_list in image_metas.items():
-                                if isinstance(value_list, list):
-                                    single_image_metas[key] = value_list[i]
-                                else:
-                                    single_image_metas[key] = value_list
-
-                            # compute the actual token count from image_token_thw
-                            if "image_token_thw" in single_image_metas:
-                                token_thw = single_image_metas["image_token_thw"]
-                                if isinstance(token_thw, (list, tuple)):
-                                    tokens_per_image = int(token_thw[0]) * int(token_thw[1]) * int(token_thw[2])
-                                elif hasattr(token_thw, 'tolist'):
-                                    tlist = token_thw.tolist()
-                                    tokens_per_image = int(tlist[0]) * int(tlist[1]) * int(tlist[2])
-                                else:
-                                    tokens_per_image = int(token_thw)
-                                print(f"๐ prepare_mm_inputs - ์ด๋ฏธ์ง {i} ์ค์  ํ ํฐ ์: {tokens_per_image}")
-                            else:
-                                # fall back to the default
-                                tokens_per_image = pixel_values.shape[0] // num_images
-                                print(f"๐ prepare_mm_inputs - ์ด๋ฏธ์ง {i} ๊ธฐ๋ณธ ํ ํฐ ์: {tokens_per_image}")
-
-                            # slice out this image's part of pixel_values
-                            start_idx = current_idx
-                            end_idx = current_idx + tokens_per_image
-                            single_pixel_values = pixel_values[start_idx:end_idx]
-
-                            print(f"๐ prepare_mm_inputs - ์ด๋ฏธ์ง {i} ํน์ง ํํ: {single_pixel_values.shape}")
-
-                            # process each image individually
-                            single_visual_embeds = self.forward_and_project_vision(
-                                single_pixel_values, single_image_metas
-                            )
-                            visual_embeds_list.append(single_visual_embeds)
-
-                            current_idx += tokens_per_image
-
-                        # concatenate the visual embeds of all images
-                        flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0)
-                        print(f"๐ prepare_mm_inputs - ๋ค์ค ์ด๋ฏธ์ง ์ฒ๋ฆฌ ์๋ฃ, ์ฐ๊ฒฐ๋ embeds ํฌ๊ธฐ: {flattened_visual_embeds.shape}")
-                    else:
-                        # single-image handling
-                        print(f"๐ prepare_mm_inputs - ๋จ์ผ ์ด๋ฏธ์ง๋ก ์ฒ๋ฆฌ")
-                        flattened_visual_embeds = self.forward_and_project_vision(
-                            pixel_values, image_metas
-                        )
-                # split into individual images when pixel_values is batched
-                elif hasattr(pixel_values, 'shape') and len(pixel_values.shape) == 4 and pixel_values.shape[0] > 1:
|
| 770 |
-
print(f"๐ prepare_mm_inputs - ๋ฐฐ์น ํํ ๊ฐ์ง, ๊ฐ๋ณ ์ด๋ฏธ์ง๋ก ๋ถ๋ฆฌ")
|
| 771 |
-
visual_embeds_list = []
|
| 772 |
-
for i in range(pixel_values.shape[0]):
|
| 773 |
-
print(f"๐ prepare_mm_inputs - ๋ฐฐ์น ์ด๋ฏธ์ง {i} ์ฒ๋ฆฌ ์ค")
|
| 774 |
-
# ๊ฐ ์ด๋ฏธ์ง์ ๋ํ ๊ฐ๋ณ image_metas ์์ฑ
|
| 775 |
-
single_image_metas = {}
|
| 776 |
-
for key, value_list in image_metas.items():
|
| 777 |
-
if isinstance(value_list, list):
|
| 778 |
-
single_image_metas[key] = value_list[i]
|
| 779 |
-
else:
|
| 780 |
-
single_image_metas[key] = value_list
|
| 781 |
-
|
| 782 |
-
# ๊ฐ๋ณ ์ด๋ฏธ์ง ์ฒ๋ฆฌ
|
| 783 |
-
if isinstance(pixel_values, list):
|
| 784 |
-
single_pixel_values = pixel_values[i:i+1]
|
| 785 |
-
else:
|
| 786 |
-
# pixel_values๊ฐ ํ
์์ธ ๊ฒฝ์ฐ
|
| 787 |
-
single_pixel_values = pixel_values[i:i+1]
|
| 788 |
-
|
| 789 |
-
single_visual_embeds = self.forward_and_project_vision(
|
| 790 |
-
single_pixel_values, single_image_metas
|
| 791 |
-
)
|
| 792 |
-
visual_embeds_list.append(single_visual_embeds)
|
| 793 |
-
|
| 794 |
-
# ๋ชจ๋ ์ด๋ฏธ์ง์ visual embeds๋ฅผ ์ฐ๊ฒฐ
|
| 795 |
-
flattened_visual_embeds = torch.cat(visual_embeds_list, dim=0)
|
| 796 |
-
print(f"๐ prepare_mm_inputs - ๋ค์ค ์ด๋ฏธ์ง ์ฒ๋ฆฌ ์๋ฃ, ์ฐ๊ฒฐ๋ embeds ํฌ๊ธฐ: {flattened_visual_embeds.shape}")
|
| 797 |
-
|
| 798 |
-
# ๊ฐ ์ด๋ฏธ์ง์ embeds ํฌ๊ธฐ ์ถ๋ ฅ
|
| 799 |
-
for i, embeds in enumerate(visual_embeds_list):
|
| 800 |
-
print(f"๐ prepare_mm_inputs - ์ด๋ฏธ์ง {i} embeds ํฌ๊ธฐ: {embeds.shape}")
|
| 801 |
-
else:
|
| 802 |
-
# ๋จ์ผ ์ด๋ฏธ์ง ์ฒ๋ฆฌ
|
| 803 |
-
# image_metas๊ฐ ๋ค์ค ์ด๋ฏธ์ง ์ ๋ณด๋ฅผ ํฌํจํ๋ ๊ฒฝ์ฐ ์ฒซ ๋ฒ์งธ ์ด๋ฏธ์ง ์ ๋ณด๋ง ์ฌ์ฉ
|
| 804 |
-
if isinstance(image_metas, dict):
|
| 805 |
-
single_image_metas = {}
|
| 806 |
-
for key, value_list in image_metas.items():
|
| 807 |
-
if isinstance(value_list, list):
|
| 808 |
-
single_image_metas[key] = value_list[0] # ์ฒซ ๋ฒ์งธ ์ด๋ฏธ์ง ์ ๋ณด ์ฌ์ฉ
|
| 809 |
-
else:
|
| 810 |
-
single_image_metas[key] = value_list
|
| 811 |
-
print(f"๐ prepare_mm_inputs - ๋จ์ผ ์ด๋ฏธ์ง ์ฒ๋ฆฌ, ์ฒซ ๋ฒ์งธ ์ด๋ฏธ์ง ์ ๋ณด ์ฌ์ฉ")
|
| 812 |
-
else:
|
| 813 |
-
single_image_metas = image_metas
|
| 814 |
-
|
| 815 |
-
# ๋จ์ผ ์ด๋ฏธ์ง ์ฒ๋ฆฌ ์ pixel_values๊ฐ ๋ฆฌ์คํธ์ธ์ง ํ์ธ
|
| 816 |
-
if isinstance(pixel_values, list):
|
| 817 |
-
single_pixel_values = pixel_values[0] # ์ฒซ ๋ฒ์งธ ์ด๋ฏธ์ง๋ง ์ฌ์ฉ
|
| 818 |
-
else:
|
| 819 |
-
single_pixel_values = pixel_values
|
| 820 |
-
|
| 821 |
-
flattened_visual_embeds = self.forward_and_project_vision(
|
| 822 |
-
single_pixel_values, single_image_metas
|
| 823 |
-
)
|
| 824 |
-
|
| 825 |
-
# dtype ์ผ์น๋ฅผ ์ํด visual_embeds๋ฅผ text_embeds์ ๊ฐ์ dtype์ผ๋ก ๋ณํ
|
| 826 |
-
flattened_visual_embeds = flattened_visual_embeds.to(flattened_text_embeds.dtype)
|
| 827 |
-
|
| 828 |
-
# visual embeds์ -1 ํ ํฐ ๊ฐ์ ํ์ธ ๋ฐ ์กฐ์
|
| 829 |
-
num_visual_tokens = flattened_visual_embeds.shape[0]
|
| 830 |
-
num_neg_one_tokens = (flattened_input_ids == -1).sum().item()
|
| 831 |
-
if num_neg_one_tokens == 0:
|
| 832 |
-
# -1 ํ ํฐ์ด ์์ผ๋ฉด ๋ฌธ์ฅ ์์๋ถ์ ์๊ฐ ํ ํฐ์ ๊ฐ์ ์ฝ์
ํ๊ธฐ ์ํด ๊ฐ์ง -1 ํ ํฐ ํ๋ ์ถ๊ฐ
|
| 833 |
-
fake_neg = torch.full_like(flattened_input_ids[:1], -1)
|
| 834 |
-
flattened_input_ids = torch.cat([fake_neg, flattened_input_ids], dim=0)
|
| 835 |
-
num_neg_one_tokens = 1
|
| 836 |
-
print(f"๐ prepare_mm_inputs - visual embeds ๊ฐ์: {num_visual_tokens}")
|
| 837 |
-
print(f"๐ prepare_mm_inputs - -1 ํ ํฐ ๊ฐ์: {num_neg_one_tokens}")
|
| 838 |
-
|
| 839 |
-
if num_visual_tokens != num_neg_one_tokens:
|
| 840 |
-
print(f"๐ prepare_mm_inputs - ํ ํฐ ๊ฐ์ ๋ถ์ผ์น, ์กฐ์ ํ์")
|
| 841 |
-
if num_visual_tokens > num_neg_one_tokens:
|
| 842 |
-
# visual embeds๊ฐ ๋ง์ผ๋ฉด ์๋ฅด๊ธฐ
|
| 843 |
-
flattened_visual_embeds = flattened_visual_embeds[:num_neg_one_tokens]
|
| 844 |
-
print(f"๐ prepare_mm_inputs - visual embeds ์๋ฅด๊ธฐ: {num_visual_tokens} -> {num_neg_one_tokens}")
|
| 845 |
-
else:
|
| 846 |
-
# visual embeds๊ฐ ์ ์ผ๋ฉด ๋ฐ๋ณตํด์ ์ฌ์ฉ
|
| 847 |
-
repeat_times = num_neg_one_tokens // num_visual_tokens
|
| 848 |
-
remainder = num_neg_one_tokens % num_visual_tokens
|
| 849 |
-
|
| 850 |
-
if repeat_times > 0:
|
| 851 |
-
# visual embeds๋ฅผ ๋ฐ๋ณต
|
| 852 |
-
repeated_embeds = flattened_visual_embeds.repeat(repeat_times, 1)
|
| 853 |
-
if remainder > 0:
|
| 854 |
-
# ๋๋จธ์ง ๋ถ๋ถ ์ถ๊ฐ
|
| 855 |
-
remainder_embeds = flattened_visual_embeds[:remainder]
|
| 856 |
-
repeated_embeds = torch.cat([repeated_embeds, remainder_embeds], dim=0)
|
| 857 |
-
flattened_visual_embeds = repeated_embeds
|
| 858 |
-
else:
|
| 859 |
-
# visual embeds๊ฐ ๋๋ฌด ์ ์ผ๋ฉด ์ฒซ ๋ฒ์งธ ํ ํฐ์ ๋ฐ๋ณต
|
| 860 |
-
# ์ต์ 1๊ฐ๋ผ๋ ์ ์ง
|
| 861 |
-
base = flattened_visual_embeds[0:1]
|
| 862 |
-
flattened_visual_embeds = base.repeat(max(1, num_neg_one_tokens), 1)
|
| 863 |
-
|
| 864 |
-
print(f"๐ prepare_mm_inputs - visual embeds ๋ฐ๋ณต: {num_visual_tokens} -> {num_neg_one_tokens}")
|
| 865 |
-
|
| 866 |
-
flattened_text_embeds[flattened_input_ids == -1] = flattened_visual_embeds
|
| 867 |
-
|
| 868 |
-
input_embeds = rearrange(
|
| 869 |
-
flattened_text_embeds, "(b l) d -> b l d", b=input_ids.shape[0]
|
| 870 |
-
)
|
| 871 |
-
|
| 872 |
-
return_inputs = {
|
| 873 |
-
"inputs_embeds": input_embeds,
|
| 874 |
-
"attention_mask": attention_mask,
|
| 875 |
-
}
|
| 876 |
-
return return_inputs
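
The splice above hinges on negative placeholder ids: embed_text_tokens zeroes them out, and prepare_mm_inputs then overwrites exactly those flattened positions with rows of visual embeddings. A minimal sketch of that mechanic, with made-up shapes (dim=4, four tokens) standing in for the real vision-tower output:

import torch

dim = 4
input_ids = torch.tensor([[5, -1, -1, 7]])    # two image-placeholder tokens
text_embeds = torch.zeros(1, 4, dim)          # stand-in for embed_text_tokens output
visual_embeds = torch.ones(2, dim)            # one row per placeholder

flat_ids = input_ids.reshape(-1)              # "b l -> (b l)"
flat_embeds = text_embeds.reshape(-1, dim)    # "b l d -> (b l) d"
flat_embeds[flat_ids == -1] = visual_embeds   # overwrite placeholder rows
inputs_embeds = flat_embeds.reshape(1, 4, dim)
print(inputs_embeds[0, 1])                    # tensor([1., 1., 1., 1.])
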
-
-    def forward(
-        self,
-        pixel_values: list[torch.FloatTensor],
-        image_metas: dict[list],
-        input_ids: torch.FloatTensor,
-        seq_length: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        return_dict: Optional[bool] = None,
-    ):
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        inputs = self.prepare_mm_inputs(
-            input_ids=input_ids,
-            pixel_values=pixel_values,
-            image_metas=image_metas,
-            attention_mask=attention_mask,
-        )
-
-        outputs = self.language_model(
-            **inputs,
-            labels=labels,
-            position_ids=None,
-            return_dict=return_dict,
-            output_attentions=self.config.output_attentions,
-        )
-
-        return outputs
-
-
-    @torch.no_grad()
-    def generate(
-        self,
-        pixel_values: torch.FloatTensor = None,
-        image_metas: dict[list] = None,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        seq_length: Optional[torch.LongTensor] = None,
-        **generate_kwargs,
-    ) -> torch.LongTensor:
-        """
-        [Final fix] generate function that handles text-only and multimodal requests in one path
-        """
-        # --- 1. Prepare the input embeddings ---
-        # input_ids is always required (the text prompt).
-        if input_ids is None:
-            # Guard for requests with no text prompt, such as image captioning
-            # (not hit by the current use cases)
-            if pixel_values is not None:
-                # Here we might have to build input_ids from just a start token (BOS).
-                # For now, simply create an empty tensor.
-                input_ids = torch.tensor([[]], dtype=torch.long, device=self.device)
-            else:
-                raise ValueError("Both input_ids and pixel_values are missing.")
-
-        # For multimodal requests, build embeddings that combine text and images via prepare_mm_inputs.
-        if pixel_values is not None:
-            # Multimodal path
-            if (
-                image_metas is not None
-                and image_metas.get("vision_grid_thw") is not None
-                and isinstance(image_metas.get("vision_grid_thw"), torch.Tensor)
-            ):
-                image_metas["vision_grid_thw"] = image_metas["vision_grid_thw"].to(input_ids.device)
-
-            inputs = self.prepare_mm_inputs(
-                input_ids=input_ids,
-                pixel_values=pixel_values,
-                image_metas=image_metas,
-                attention_mask=attention_mask,
-            )
-            # The arguments used in the end are inputs_embeds and attention_mask
-            final_model_kwargs = {
-                "inputs_embeds": inputs.get("inputs_embeds"),
-                "attention_mask": inputs.get("attention_mask")
-            }
-        else:
-            # Text-only path
-            # The arguments used in the end are input_ids and attention_mask
-            final_model_kwargs = {
-                "input_ids": input_ids,
-                "attention_mask": attention_mask
-            }
-
-        # --- 2. Final generation ---
-        # Pass the prepared arguments (**final_model_kwargs) along with the extra generation options (**generate_kwargs).
-        outputs = self.language_model.generate(
-            **final_model_kwargs,
-            **generate_kwargs,
-        )
-
-        return outputs
-
-
-    def _get_input_dtype(self):
-        dtype = next(self.vision_model.parameters()).dtype
-        return dtype
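
As removed, generate dispatched on the presence of pixel_values: the text-only path forwards input_ids untouched, while the multimodal path swaps in precomputed inputs_embeds. A hypothetical caller, assuming model, tokenizer, pixel_values, and image_metas already exist (they are not defined in this diff):

enc = tokenizer("Describe the image.", return_tensors="pt")

# Text-only path: input_ids flows straight into language_model.generate.
text_out = model.generate(input_ids=enc["input_ids"],
                          attention_mask=enc["attention_mask"],
                          max_new_tokens=64)

# Multimodal path: pixel_values triggers prepare_mm_inputs, so the language
# model receives inputs_embeds instead of input_ids.
mm_out = model.generate(input_ids=enc["input_ids"],
                        attention_mask=enc["attention_mask"],
                        pixel_values=pixel_values,
                        image_metas=image_metas,
                        max_new_tokens=64)
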
lily_llm_api/models/schemas.py
ADDED
@@ -0,0 +1,184 @@
+"""
+Pydantic schemas for Lily LLM API
+"""
+from pydantic import BaseModel
+from typing import Optional, List
+
+class GenerateRequest(BaseModel):
+    prompt: str
+    model_id: Optional[str] = None  # default removed - use the currently loaded model
+    max_length: Optional[int] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    do_sample: Optional[bool] = None
+
+class GenerateResponse(BaseModel):
+    generated_text: str
+    processing_time: float
+    model_name: str
+    image_processed: bool
+
+class MultimodalGenerateResponse(BaseModel):
+    generated_text: str
+    processing_time: float
+    model_name: str
+    model_id: Optional[str] = None
+    image_processed: bool = False
+
+class HealthResponse(BaseModel):
+    status: str
+    model_loaded: bool
+    current_model: str
+    available_models: List[dict]
+
+class DocumentUploadResponse(BaseModel):
+    success: bool
+    document_id: str
+    message: str
+    chunks: Optional[int] = None
+    latex_count: Optional[int] = None  # added field: LaTeX formula count
+    error: Optional[str] = None
+    auto_response: Optional[str] = None  # added field: automatic response
+
+class RAGResponse(BaseModel):
+    success: bool
+    response: str
+    context: str
+    sources: List[dict]
+    search_results: int
+    processing_time: float
+
+# User-related response models
+class UserResponse(BaseModel):
+    success: bool
+    user_id: str
+    username: Optional[str] = None
+    email: Optional[str] = None
+    created_at: Optional[str] = None
+    error: Optional[str] = None
+
+class SessionResponse(BaseModel):
+    success: bool
+    session_id: str
+    session_name: Optional[str] = None
+    created_at: Optional[str] = None
+    error: Optional[str] = None
+
+class ChatMessageResponse(BaseModel):
+    success: bool
+    message_id: int
+    content: str
+    message_type: str
+    timestamp: str
+    error: Optional[str] = None
+
+# Authentication-related response models
+class LoginResponse(BaseModel):
+    success: bool
+    access_token: Optional[str] = None
+    refresh_token: Optional[str] = None
+    token_type: Optional[str] = None
+    user_id: Optional[str] = None
+    username: Optional[str] = None
+    error: Optional[str] = None
+
+class TokenResponse(BaseModel):
+    success: bool
+    access_token: Optional[str] = None
+    token_type: Optional[str] = None
+    error: Optional[str] = None
+
+# LoRA-related response models
+class LoRAStatusResponse(BaseModel):
+    status: str
+    lora_available: bool
+    current_adapter: Optional[str] = None
+    base_model_loaded: bool
+    device: str
+    message: Optional[str] = None
+
+# Context-related response models
+class ContextStatusResponse(BaseModel):
+    status: str
+    context_manager_available: bool
+    total_sessions: int
+    sessions: dict
+    max_tokens: int
+    max_turns: int
+    strategy: str
+    message: Optional[str] = None
+
+class ContextHistoryResponse(BaseModel):
+    status: str
+    session_id: Optional[str] = None
+    context: str
+    history_length: int
+    session_summary: Optional[dict] = None
+    all_sessions: Optional[bool] = None
+    message: Optional[str] = None
+
+class AutoCleanupConfigResponse(BaseModel):
+    status: str
+    auto_cleanup_config: dict
+    message: Optional[str] = None
+
+class AutoCleanupConfigRequest(BaseModel):
+    enabled: bool = True
+    interval_turns: int = 8
+    interval_time: int = 300
+    strategy: str = "smart"
+
+# Document-processing response models
+class DocumentProcessResponse(BaseModel):
+    success: bool
+    document_id: str
+    chunks: int
+    processing_time: float
+    document_type: str
+    page_count: int
+    error: Optional[str] = None
+
+class RAGQueryRequest(BaseModel):
+    query: str
+    user_id: str = "anonymous"
+    room_id: str = "default"
+    max_results: int = 5
+    include_sources: bool = True
+
+class RAGQueryResponse(BaseModel):
+    success: bool
+    response: str
+    sources: List[dict]
+    search_results: int
+    processing_time: float
+    error: Optional[str] = None
+
+# Multimodal RAG response models
+class MultimodalRAGResponse(BaseModel):
+    success: bool
+    response: str
+    image_processed: bool
+    processing_time: float
+    error: Optional[str] = None
+
+# Performance-monitoring response models
+class PerformanceMetricsResponse(BaseModel):
+    status: str
+    metrics: dict
+    timestamp: str
+    error: Optional[str] = None
+
+# WebSocket response models
+class WebSocketMessage(BaseModel):
+    type: str
+    content: str
+    user_id: str
+    room_id: str
+    timestamp: str
+
+# Celery task response models
+class TaskStatusResponse(BaseModel):
+    task_id: str
+    status: str
+    result: Optional[dict] = None
+    error: Optional[str] = None
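
Since the routers validate requests and responses through these models, a quick sketch of how one behaves at runtime (field values are made up; model_dump() is the Pydantic v2 spelling, dict() on v1):

from lily_llm_api.models.schemas import GenerateRequest

req = GenerateRequest(prompt="Hello", temperature=0.7)
print(req.model_dump())
# {'prompt': 'Hello', 'model_id': None, 'max_length': None,
#  'temperature': 0.7, 'top_p': None, 'do_sample': None}
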
lily_llm_api/services/__init__.py
ADDED
@@ -0,0 +1,3 @@
+"""
+Services package for Lily LLM API
+"""
lily_llm_api/services/generation_service.py
ADDED
@@ -0,0 +1,583 @@
+"""
+Generation service for Lily LLM API
+"""
+import logging
+import time
+from typing import Optional, List
+from PIL import Image
+import io
+import torch
+
+logger = logging.getLogger(__name__)
+
+def generate_sync(prompt: str, image_data_list: Optional[List[bytes]], max_length: Optional[int] = None,
+                  temperature: Optional[float] = None, top_p: Optional[float] = None,
+                  do_sample: Optional[bool] = None, use_context: bool = True, session_id: str = None,
+                  user_id: str = "anonymous", room_id: str = "default") -> dict:
+    """[Optimized] Unified synchronous function that handles model generation"""
+    try:
+        from .model_service import get_current_profile, get_current_model
+        from .model_service import tokenizer, processor
+
+        current_profile = get_current_profile()
+        current_model = get_current_model()
+
+        print(f"🔍 [DEBUG] generate_sync start - prompt length: {len(prompt)}")
+        print(f"🔍 [DEBUG] currently loaded model: {current_profile.display_name if current_profile else 'None'}")
+        print(f"🔍 [DEBUG] model type: {type(current_profile) if current_profile else 'None'}")
+
+        if current_profile is None:
+            print("❌ [DEBUG] no model loaded")
+            return {"error": "No model loaded"}
+
+        print(f"🔍 [DEBUG] model name: {getattr(current_profile, 'model_name', 'Unknown')}")
+        print(f"🔍 [DEBUG] multimodal support: {getattr(current_profile, 'multimodal', False)}")
+        print(f"🔍 [DEBUG] input prompt: {prompt}")
+        print(f"🔍 [DEBUG] input prompt length: {len(prompt)}")
+        print(f"🔍 [DEBUG] image data present: {image_data_list is not None}")
+        print(f"🔍 [DEBUG] image data count: {len(image_data_list) if image_data_list else 0}")
+        print(f"🔍 [DEBUG] actual image data count: {len([img for img in image_data_list if img]) if image_data_list else 0}")
+
+        image_processed = False
+        all_pixel_values = []
+        combined_image_metas = None
+
+        # --- 1. Image processing (official approach) ---
+        all_image_data = []
+        if image_data_list and len([img for img in image_data_list if img]) > 0:
+            all_image_data.extend(image_data_list)
+            print(f"🔍 [DEBUG] added {len(image_data_list)} directly passed images")
+
+        if all_image_data and len([img for img in all_image_data if img]) > 0 and getattr(current_profile, 'multimodal', False):
+            print(f"🔍 [DEBUG] image processing start - total images: {len([img for img in all_image_data if img])}")
+
+            # 🆕 Official approach: simple image processing
+            max_images = min(len(all_image_data), 4)
+            logger.info(f"🖼️ multimodal processing start... ({max_images} images)")
+
+            try:
+                metas_list = []
+                for idx, image_bytes in enumerate(all_image_data[:max_images]):
+                    if image_bytes:
+                        try:
+                            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+                            # 🆕 Use the official image processor
+                            if processor and hasattr(processor, 'image_processor'):
+                                processed = processor.image_processor(pil_image)
+                                all_pixel_values.append(processed["pixel_values"])
+                                metas_list.append(processed.get("image_meta", {}))
+                            else:
+                                logger.warning(f"⚠️ image processor not found")
+                        except Exception as e:
+                            logger.warning(f"⚠️ image {idx} processing failed: {e}")
+
+                # 🆕 Merge the metadata (official approach)
+                if metas_list:
+                    combined_image_metas = {}
+                    for key in metas_list[0].keys():
+                        combined_image_metas[key] = [meta[key] for meta in metas_list if key in meta]
+                    print(f"🔍 [DEBUG] image metadata: {combined_image_metas}")
+                else:
+                    combined_image_metas = {}
+            except Exception as e:
+                logger.error(f"❌ image preprocessing failed: {e}")
+                combined_image_metas = {}
+
+        # --- 2. Prompt construction ---
+        print(f"🔍 [DEBUG] prompt construction start")
+
+        # Merge context (conversation history + RAG search results) - optimized per model
+        context_prompt = ""
+        if use_context and session_id:
+            try:
+                # Fetch the context from the context manager
+                try:
+                    from lily_llm_core.context_manager import context_manager
+                    context = context_manager.get_context_for_model(
+                        current_profile.model_name,
+                        session_id
+                    )
+                    if context and len(context.strip()) > 0:
+                        context_prompt = context + "\n\n"
+                        print(f"🔍 [DEBUG] conversation context included - length: {len(context_prompt)} (session: {session_id})")
+                except Exception as e:
+                    print(f"⚠️ [DEBUG] context load failed: {e}")
+                    context_prompt = ""
+
+            except Exception as e:
+                print(f"⚠️ [DEBUG] context load failed: {e} (session: {session_id})")
+                context_prompt = ""
+
+        # Initialize formatted_prompt
+        formatted_prompt = None
+
+        # 🆕 Multimodal prompt construction (official approach)
+        if all_pixel_values and len(all_pixel_values) > 0:
+            # 🆕 Official Kanana format: Human: <image> text
+            formatted_prompt = f"Human: <image>{prompt}"
+            print(f"🔍 [DEBUG] multimodal prompt built (official format): {formatted_prompt}")
+            image_processed = True
+        else:
+            image_processed = False
+            print(f"🔍 [DEBUG] no image - text-only mode")
+
+            # Prompt construction for text-only models (context included)
+            if hasattr(current_profile, 'format_prompt'):
+                # For Polyglot models, use the format_prompt method (context supported)
+                if "polyglot" in current_profile.model_name.lower():
+                    # Pass the context and the prompt together
+                    formatted_prompt = current_profile.format_prompt(prompt, context_prompt)
+                else:
+                    # Other models use the existing approach
+                    base_prompt = current_profile.format_prompt(prompt)
+                    if context_prompt:
+                        formatted_prompt = context_prompt + base_prompt
+                    else:
+                        formatted_prompt = base_prompt
+                print(f"🔍 [DEBUG] profile format_prompt used (context included): {formatted_prompt}")
+            else:
+                # Default prompt (fallback) - context included
+                if "polyglot" in current_profile.model_name.lower():
+                    base_prompt = f"### 사용자:\n{prompt}\n\n### 챗봇:\n"
+                else:
+                    base_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+
+                if context_prompt:
+                    formatted_prompt = context_prompt + base_prompt
+                else:
+                    formatted_prompt = base_prompt
+                print(f"🔍 [DEBUG] default prompt used (context included): {formatted_prompt}")
+
+        print(f"🔍 [DEBUG] prompt construction done - length: {len(formatted_prompt) if formatted_prompt else 0}")
+        print(f"🔍 [DEBUG] final prompt: {formatted_prompt}")
+
+        # --- 3. Tokenization ---
+        print(f"🔍 [DEBUG] tokenization start")
+        t_tok_start = time.time()
+
+        if not all_image_data or len([img for img in all_image_data if img]) == 0:
+            # Text-only fast path (faster)
+            print(f"🔍 [DEBUG] text-only tokenization path")
+            print(f"🔍 [DEBUG] prompt to use: {formatted_prompt}")
+
+            inputs = tokenizer(
+                formatted_prompt,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=2048,
+            )
+            if 'token_type_ids' in inputs:
+                del inputs['token_type_ids']
+                print(f"🔍 [DEBUG] token_type_ids removed")
+
+            input_ids = inputs['input_ids']
+            attention_mask = inputs['attention_mask']
+            print(f"🔍 [DEBUG] tokenizer outputs: {list(inputs.keys())}")
+        else:
+            # Multimodal handling
+            print(f"🔍 [DEBUG] multimodal tokenization path")
+
+            if hasattr(tokenizer, 'encode_prompt'):
+                print(f"🔍 [DEBUG] using the encode_prompt method")
+
+                # Build safe metadata
+                safe_image_meta = {
+                    'image_token_thw': [[1, 1, 1]] * len(all_pixel_values),
+                    'vision_grid_thw': [[1, 1, 1]] * len(all_pixel_values)
+                }
+
+                try:
+                    inputs = tokenizer.encode_prompt(
+                        prompt=formatted_prompt,
+                        max_length=2048,
+                        image_meta=safe_image_meta
+                    )
+
+                    if 'seq_length' in inputs:
+                        del inputs['seq_length']
+
+                    input_ids = inputs['input_ids']
+                    attention_mask = inputs['attention_mask']
+
+                    # If a tuple, use the first element
+                    if isinstance(input_ids, tuple):
+                        input_ids = input_ids[0]
+                    if isinstance(attention_mask, tuple):
+                        attention_mask = attention_mask[0]
+
+                except Exception as e:
+                    print(f"❌ [DEBUG] encode_prompt failed: {e}, using fallback")
+                    # Fallback: use the default tokenizer
+                    inputs = tokenizer(
+                        formatted_prompt,
+                        return_tensors="pt",
+                        padding=True,
+                        truncation=True,
+                        max_length=2048,
+                    )
+                    if 'token_type_ids' in inputs:
+                        del inputs['token_type_ids']
+                    input_ids = inputs['input_ids']
+                    attention_mask = inputs['attention_mask']
+            else:
+                # Safe fallback
+                print(f"🔍 [DEBUG] using the default tokenizer (fallback)")
+                inputs = tokenizer(
+                    formatted_prompt,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True,
+                    max_length=2048,
+                )
+                if 'token_type_ids' in inputs:
+                    del inputs['token_type_ids']
+                input_ids = inputs['input_ids']
+                attention_mask = inputs['attention_mask']
+
+        t_tok_end = time.time()
+        print(f"🔍 [DEBUG] tokenization done - elapsed: {t_tok_end - t_tok_start:.3f}s")
+
+        # 🆕 Handle input_ids safely
+        if isinstance(input_ids, tuple):
+            print(f"🔍 [DEBUG] input_ids is a tuple: {len(input_ids)} elements")
+            input_ids = input_ids[0]  # use the first element
+            print(f"🔍 [DEBUG] extracted the first element from the input_ids tuple: {input_ids.shape}")
+
+        # 🆕 Reshape a 1-D tensor to 2-D
+        if len(input_ids.shape) == 1:
+            print(f"🔍 [DEBUG] reshaping 1-D tensor to 2-D: {input_ids.shape} -> (1, {input_ids.shape[0]})")
+            input_ids = input_ids.unsqueeze(0)  # (seq_len,) -> (1, seq_len)
+
+        # 🆕 Handle attention_mask the same way
+        if len(attention_mask.shape) == 1:
+            print(f"🔍 [DEBUG] reshaping attention_mask from 1-D to 2-D: {attention_mask.shape} -> (1, {attention_mask.shape[0]})")
+            attention_mask = attention_mask.unsqueeze(0)  # (seq_len,) -> (1, seq_len)
+
+        print(f"🔍 [DEBUG] final input_ids shape: {input_ids.shape}")
+        print(f"🔍 [DEBUG] input token count: {input_ids.shape[1]}")
+
+        # --- 4. Generation settings ---
+        print(f"🔍 [DEBUG] building the generation config")
+        gen_config = current_profile.get_generation_config()
+
+        # Fill the eos, pad, bos token ids declared in the config file in as defaults
+        if 'eos_token_id' not in gen_config or gen_config['eos_token_id'] is None:
+            gen_config['eos_token_id'] = tokenizer.eos_token_id
+
+        if 'pad_token_id' not in gen_config or gen_config['pad_token_id'] is None:
+            gen_config['pad_token_id'] = tokenizer.pad_token_id or tokenizer.eos_token_id
+
+        # Set bos_token_id too when needed (varies by generate function)
+        if 'bos_token_id' not in gen_config and hasattr(tokenizer, 'bos_token_id'):
+            gen_config['bos_token_id'] = tokenizer.bos_token_id
+
+        # Override when API arguments such as max_new_tokens and temperature are given
+        if max_length is not None:
+            gen_config['max_new_tokens'] = max_length
+
+        if temperature is not None:
+            gen_config['temperature'] = temperature
+
+        if top_p is not None:
+            gen_config['top_p'] = top_p
+
+        if do_sample is not None:
+            gen_config['do_sample'] = do_sample
+
+        print(f"🔍 [DEBUG] generation config: {gen_config}")
+
+        # --- 5. Run inference ---
+        print(f"🔍 [DEBUG] model inference start")
+        t_gen_start = time.time()
+
+        try:
+            # Check the model state
+            print(f"🔍 [DEBUG] model device: {current_model.device}")
+            print(f"🔍 [DEBUG] input tensor device: {input_ids.device}")
+            print(f"🔍 [DEBUG] model type: {type(current_model)}")
+            print(f"🔍 [DEBUG] model state: {'eval' if current_model.training == False else 'training'}")
+            print(f"🔍 [DEBUG] input tensor shape: {input_ids.shape}")
+            print(f"🔍 [DEBUG] attention_mask shape: {attention_mask.shape}")
+            print(f"🔍 [DEBUG] all_pixel_values present: {all_pixel_values is not None}")
+            print(f"🔍 [DEBUG] all_pixel_values length: {len(all_pixel_values) if all_pixel_values else 0}")
+
+            # Move the input tensors to the model device
+            if input_ids.device != current_model.device:
+                print(f"🔍 [DEBUG] moving input tensors to the model device: {input_ids.device} -> {current_model.device}")
+                input_ids = input_ids.to(current_model.device)
+                attention_mask = attention_mask.to(current_model.device)
+
+            # 🆕 Work around the torch import issue
+            import torch
+            with torch.no_grad():
+                if all_pixel_values and len(all_pixel_values) > 0:
+                    # Multimodal: process image and text together
+                    print(f"🔍 [DEBUG] running multimodal inference")
+                    print(f"🔍 [DEBUG] image tensor count: {len(all_pixel_values)}")
+
+                    # Check the image tensors' device too
+                    pixel_values = torch.cat(all_pixel_values, dim=0)
+                    print(f"🔍 [DEBUG] combined image tensor shape: {pixel_values.shape}")
+                    print(f"🔍 [DEBUG] image tensor dtype: {pixel_values.dtype}")
+
+                    # 🆕 Cast to the same dtype as the model (performance optimization)
+                    if hasattr(current_model, 'dtype'):
+                        target_dtype = current_model.dtype
+                        if pixel_values.dtype != target_dtype:
+                            print(f"🔍 [DEBUG] image tensor dtype cast: {pixel_values.dtype} -> {target_dtype}")
+                            pixel_values = pixel_values.to(dtype=target_dtype)
+                    else:
+                        # 🆕 If the model dtype is unknown, use bfloat16 (Kanana model default)
+                        target_dtype = torch.bfloat16
+                        if pixel_values.dtype != target_dtype:
+                            print(f"🔍 [DEBUG] image tensor dtype cast: {pixel_values.dtype} -> {target_dtype}")
+                            pixel_values = pixel_values.to(dtype=target_dtype)
+
+                    if pixel_values.device != current_model.device:
+                        print(f"🔍 [DEBUG] moving image tensors to the model device: {pixel_values.device} -> {current_model.device}")
+                        pixel_values = pixel_values.to(current_model.device)
+
+                    print(f"🔍 [DEBUG] final image tensor device: {pixel_values.device}")
+                    print(f"🔍 [DEBUG] final image tensor dtype: {pixel_values.dtype}")
+                    print(f"🔍 [DEBUG] model generation start - multimodal")
+
+                    # Check whether a LoRA adapter is applied to the model
+                    try:
+                        from lily_llm_core.lora_manager import lora_manager
+                        if lora_manager and hasattr(lora_manager, 'current_adapter_name') and lora_manager.current_adapter_name:
+                            print(f"🔍 [DEBUG] LoRA adapter applied (multimodal): {lora_manager.current_adapter_name}")
+                            # Use the LoRA-applied model
+                            lora_model = lora_manager.get_model()
+                            if lora_model:
+                                print(f"🔍 [DEBUG] running multimodal generation with the LoRA model")
+                                # 🆕 Add the image_metas parameter (official approach)
+                                processed_image_metas = {}
+
+                                # 🆕 Official approach: convert vision_grid_thw to a tensor
+                                if 'vision_grid_thw' in combined_image_metas:
+                                    vision_grid = combined_image_metas['vision_grid_thw']
+                                    if isinstance(vision_grid, list):
+                                        # 🆕 Kanana model requirement: a 3-D tensor in (T, H, W) form
+                                        if len(vision_grid) == 1 and len(vision_grid[0]) == 3:
+                                            # Convert [(1, 34, 52)] -> a (1, 34, 52) tensor
+                                            t, h, w = vision_grid[0]
+                                            # 🆕 Convert to a 3-D tensor: (1, H, W) form
+                                            processed_image_metas['vision_grid_thw'] = torch.tensor([[t, h, w]], dtype=torch.long)
+                                            print(f"🔍 [DEBUG] vision_grid_thw tensor conversion: {vision_grid} -> {processed_image_metas['vision_grid_thw'].shape}")
+                                        else:
+                                            # 🆕 Keep the original for other shapes
+                                            processed_image_metas['vision_grid_thw'] = torch.tensor(vision_grid, dtype=torch.long)
+                                            print(f"🔍 [DEBUG] vision_grid_thw tensor conversion (default): {vision_grid} -> {processed_image_metas['vision_grid_thw'].shape}")
+                                    else:
+                                        processed_image_metas['vision_grid_thw'] = vision_grid
+
+                                # 🆕 Keep the other metadata as-is
+                                for key, value in combined_image_metas.items():
+                                    if key != 'vision_grid_thw':
+                                        processed_image_metas[key] = value
+
+                                generate_kwargs = {
+                                    'input_ids': input_ids,
+                                    'attention_mask': attention_mask,
+                                    'pixel_values': pixel_values,
+                                    'image_metas': processed_image_metas,  # 🆕 processed image metadata
+                                    **gen_config
+                                }
+                                print(f"🔍 [DEBUG] LoRA model generation parameters: {list(generate_kwargs.keys())}")
+                                print(f"🔍 [DEBUG] processed image_metas: {list(processed_image_metas.keys())}")
+                                print(f"🔍 [DEBUG] model generation start... (no timeout)")
+
+                                generated_ids = lora_model.generate(**generate_kwargs)
+                            else:
+                                print(f"⚠️ [DEBUG] could not get the LoRA model, using the base model")
+                                generated_ids = current_model.generate(
+                                    input_ids=input_ids,
+                                    attention_mask=attention_mask,
+                                    pixel_values=pixel_values,
+                                    **gen_config
+                                )
+                        else:
+                            print(f"🔍 [DEBUG] no LoRA adapter (multimodal), using the base model")
+                            generated_ids = current_model.generate(
+                                input_ids=input_ids,
+                                attention_mask=attention_mask,
+                                pixel_values=pixel_values,
+                                **gen_config
+                            )
+                    except ImportError:
+                        print(f"🔍 [DEBUG] LoRA not supported, using the base model")
+                        generated_ids = current_model.generate(
+                            input_ids=input_ids,
+                            attention_mask=attention_mask,
+                            pixel_values=pixel_values,
+                            **gen_config
+                        )
+
+                else:
+                    # Text-only: existing approach
+                    print(f"🔍 [DEBUG] running text-only inference")
+                    print(f"🔍 [DEBUG] generation config: {gen_config}")
+
+                    # Extra performance-optimization settings
+                    gen_config['use_cache'] = True  # caching speeds things up
+
+                    # PAD token setting - the model profile's setting takes precedence
+                    if 'pad_token_id' not in gen_config:
+                        # Use the default only when the profile has no setting
+                        if tokenizer.pad_token_id is not None:
+                            gen_config['pad_token_id'] = tokenizer.pad_token_id
+                            print(f"🔍 [DEBUG] PAD token setting: using the tokenizer default (ID: {tokenizer.pad_token_id})")
+                        else:
+                            gen_config['pad_token_id'] = None
+                            print(f"🔍 [DEBUG] PAD token setting: None (tokenizer has no PAD token)")
+
+                    # Token settings - values set in the profile take precedence
+                    if 'eos_token_id' not in gen_config or gen_config['eos_token_id'] is None:
+                        if tokenizer.eos_token_id is not None:
+                            gen_config['eos_token_id'] = tokenizer.eos_token_id
+                            print(f"🔍 [DEBUG] EOS token setting: {tokenizer.eos_token_id}")
+                        else:
+                            gen_config['eos_token_id'] = None
+                            print(f"🔍 [DEBUG] EOS token setting: None (handled automatically)")
+
+                    if 'pad_token_id' not in gen_config or gen_config['pad_token_id'] is None:
+                        if tokenizer.pad_token_id is not None:
+                            gen_config['pad_token_id'] = tokenizer.pad_token_id
+                        else:
+                            gen_config['pad_token_id'] = None
+
+                    if 'bos_token_id' not in gen_config or gen_config['bos_token_id'] is None:
+                        if hasattr(tokenizer, 'bos_token_id') and tokenizer.bos_token_id is not None:
+                            gen_config['bos_token_id'] = tokenizer.bos_token_id
+                        else:
+                            gen_config['bos_token_id'] = None
+
+                    print(f"🔍 [DEBUG] final token settings: EOS={gen_config['eos_token_id']}, PAD={gen_config['pad_token_id']}, BOS={gen_config.get('bos_token_id')}")
+
+                    # Final check of the generation config
+                    print(f"🔍 [DEBUG] final generation config: {gen_config}")
+
+                    print(f"🔍 [DEBUG] model generation start - text only")
+                    print(f"🔍 [DEBUG] final input tensor device: {input_ids.device}")
+                    print(f"🔍 [DEBUG] final attention_mask device: {attention_mask.device}")
+
+                    # Added to monitor the model's generation progress
+                    print(f"🔍 [DEBUG] model generation start time: {time.time()}")
+
+                    # Check whether a LoRA adapter is applied to the model
+                    try:
+                        from lily_llm_core.lora_manager import lora_manager
+                        if lora_manager and hasattr(lora_manager, 'current_adapter_name') and lora_manager.current_adapter_name:
+                            print(f"🔍 [DEBUG] LoRA adapter applied: {lora_manager.current_adapter_name}")
+                            # Use the LoRA-applied model
+                            lora_model = lora_manager.get_model()
+                            if lora_model:
+                                print(f"🔍 [DEBUG] running generation with the LoRA model")
+                                # Inputs for the LoRA model (token_type_ids removed)
+                                lora_inputs = {
+                                    'input_ids': input_ids,
+                                    'attention_mask': attention_mask
+                                }
+
+                                generated_ids = lora_model.generate(
+                                    **lora_inputs,
+                                    **gen_config
+                                )
+                            else:
+                                print(f"⚠️ [DEBUG] could not get the LoRA model, using the base model")
+                                generated_ids = current_model.generate(
+                                    input_ids=input_ids,
+                                    attention_mask=attention_mask,
+                                    **gen_config
+                                )
+                        else:
+                            print(f"🔍 [DEBUG] no LoRA adapter, using the base model")
+                            generated_ids = current_model.generate(
+                                input_ids=input_ids,
+                                attention_mask=attention_mask,
+                                **gen_config
+                            )
+                    except ImportError:
+                        print(f"🔍 [DEBUG] LoRA not supported, using the base model")
+                        generated_ids = current_model.generate(
+                            input_ids=input_ids,
+                            attention_mask=attention_mask,
+                            **gen_config
+                        )
+
+            print(f"🔍 [DEBUG] model generation end time: {time.time()}")
+
+            t_gen_end = time.time()
+            print(f"🔍 [DEBUG] model inference done - elapsed: {t_gen_end - t_gen_start:.3f}s")
+            print(f"🔍 [DEBUG] generated token count: {generated_ids.shape[1] - input_ids.shape[1]}")
+            print(f"🔍 [DEBUG] final generated_ids shape: {generated_ids.shape}")
+            print(f"🔍 [DEBUG] final generated_ids device: {generated_ids.device}")
+            print(f"🔍 [DEBUG] final generated_ids dtype: {generated_ids.dtype}")
+
+        except Exception as e:
+            print(f"❌ [DEBUG] error during model inference: {str(e)}")
+            print(f"❌ [DEBUG] error type: {type(e).__name__}")
+            print(f"❌ [DEBUG] error detail: {str(e)}")
+            import traceback
+            traceback.print_exc()
+            return {"error": f"Generation failed: {str(e)}"}
+
+        # --- 6. Response extraction ---
+        print(f"🔍 [DEBUG] response extraction start")
+        t_decode_start = time.time()
+
+        try:
+            # Decode the generated text
+            full_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+            print(f"🔍 [DEBUG] full text length: {len(full_text)}")
+            print(f"🔍 [DEBUG] full generated text (raw): \n---\n{full_text}\n---")
+            print(f"🔍 [DEBUG] prompt used: {formatted_prompt}")
+
+            # Per-profile response extraction (safe approach)
+            if hasattr(current_profile, 'extract_response'):
+                try:
+                    response = current_profile.extract_response(full_text, formatted_prompt)
+                    print(f"🔍 [DEBUG] profile extract_response succeeded")
+                except Exception as extract_error:
+                    print(f"⚠️ [DEBUG] profile extract_response failed: {extract_error}")
+                    # Fallback: default response extraction
+                    response = full_text.replace(formatted_prompt, "").strip() if formatted_prompt else full_text
+                    print(f"🔍 [DEBUG] default response extraction used (fallback)")
+            else:
+                # Default response extraction
+                response = full_text.replace(formatted_prompt, "").strip() if formatted_prompt else full_text
+                print(f"🔍 [DEBUG] default response extraction used")
+
+            print(f"🔍 [DEBUG] extracted response length: {len(response)}")
+            print(f"🔍 [DEBUG] final response: {response}")
+
+            t_decode_end = time.time()
+            print(f"🔍 [DEBUG] response extraction done - elapsed: {t_decode_end - t_decode_start:.3f}s")
+
+        except Exception as e:
+            print(f"❌ [DEBUG] error during response extraction: {str(e)}")
+            import traceback
+            traceback.print_exc()
+            return {"error": f"Response extraction failed: {str(e)}"}
+
+        # --- 7. Return the result ---
+        total_time = time.time() - t_tok_start
+        print(f"🔍 [DEBUG] all processing done - total elapsed: {total_time:.3f}s")
+
+        return {
+            "generated_text": response,
+            "processing_time": total_time,
+            "model_name": current_profile.display_name,
+            "image_processed": image_processed,
+            "tokens_generated": generated_ids.shape[1] - input_ids.shape[1],
+            "total_tokens": generated_ids.shape[1]
+        }
+
+    except Exception as e:
+        print(f"❌ [DEBUG] generate_sync overall error: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        return {"error": str(e)}
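
A hedged sketch of how a caller might drive generate_sync without blocking the event loop; the handler name is illustrative, but executor mirrors the shared thread pool defined in model_service below, and a model is assumed to have been loaded first:

import asyncio
from lily_llm_api.services.generation_service import generate_sync
from lily_llm_api.services.model_service import executor

async def handle_generate(prompt: str) -> dict:
    loop = asyncio.get_event_loop()
    # Run the blocking generation in the shared thread pool.
    return await loop.run_in_executor(
        executor,
        lambda: generate_sync(prompt, image_data_list=None, max_length=128),
    )

result = asyncio.run(handle_generate("Hello"))
print(result.get("generated_text") or result.get("error"))
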
lily_llm_api/services/model_service.py
ADDED
@@ -0,0 +1,91 @@
| 1 |
+
"""
|
| 2 |
+
Model service for Lily LLM API
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
import asyncio
|
| 6 |
+
import concurrent.futures
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
# ์ ์ญ ๋ณ์๋ค
|
| 12 |
+
current_model = None # ๐ ํ์ฌ ๋ก๋๋ ๋ชจ๋ธ ์ธ์คํด์ค
|
| 13 |
+
current_profile = None # ๐ ํ์ฌ ์ ํ๋ ๋ชจ๋ธ ํ๋กํ
|
| 14 |
+
model_loaded = False # ๐ ๋ชจ๋ธ ๋ก๋ ์ํ
|
| 15 |
+
model = None
|
| 16 |
+
tokenizer = None
|
| 17 |
+
processor = None
|
| 18 |
+
executor = concurrent.futures.ThreadPoolExecutor()
|
| 19 |
+
|
| 20 |
+
def get_current_model():
|
| 21 |
+
"""ํ์ฌ ๋ก๋๋ ๋ชจ๋ธ ๋ฐํ"""
|
| 22 |
+
return current_model
|
| 23 |
+
|
| 24 |
+
def get_current_profile():
|
| 25 |
+
"""ํ์ฌ ์ ํ๋ ๋ชจ๋ธ ํ๋กํ ๋ฐํ"""
|
| 26 |
+
return current_profile
|
| 27 |
+
|
| 28 |
+
def is_model_loaded():
|
| 29 |
+
"""๋ชจ๋ธ ๋ก๋ ์ํ ๋ฐํ"""
|
| 30 |
+
return model_loaded
|
| 31 |
+
|
| 32 |
+
async def load_model_async(model_id: str):
|
| 33 |
+
"""๋ชจ๋ธ์ ๋น๋๊ธฐ์ ์ผ๋ก ๋ก๋ฉ"""
|
| 34 |
+
loop = asyncio.get_event_loop()
|
| 35 |
+
await loop.run_in_executor(executor, load_model_sync, model_id)
|
| 36 |
+
|
| 37 |
+
def load_model_sync(model_id: str):
|
| 38 |
+
"""๋ชจ๋ธ ๋ฐ ๊ด๋ จ ํ๋ก์ธ์๋ฅผ ๋๊ธฐ์ ์ผ๋ก ๋ก๋ฉ (์ต์ข
์์ ๋ณธ)"""
|
| 39 |
+
global model, tokenizer, processor, current_profile, current_model, model_loaded
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
if model is not None:
|
| 43 |
+
logger.info("๐๏ธ ๊ธฐ์กด ๋ชจ๋ธ ์ธ๋ก๋ ์ค...")
|
| 44 |
+
del model
|
| 45 |
+
del tokenizer
|
| 46 |
+
del processor
|
| 47 |
+
model, tokenizer, processor = None, None, None
|
| 48 |
+
import gc
|
| 49 |
+
gc.collect()
|
| 50 |
+
logger.info("โ
๊ธฐ์กด ๋ชจ๋ธ ์ธ๋ก๋ ์๋ฃ")
|
| 51 |
+
|
| 52 |
+
logger.info(f"๐ฅ '{model_id}' ๋ชจ๋ธ ๋ก๋ฉ ์์...")
|
| 53 |
+
from ..models import get_model_profile
|
| 54 |
+
current_profile = get_model_profile(model_id)
|
| 55 |
+
|
| 56 |
+
# ์ด์ load_model์ (model, processor)๋ฅผ ๋ฐํํฉ๋๋ค.
|
| 57 |
+
model, processor = current_profile.load_model()
|
| 58 |
+
|
| 59 |
+
# ๐ ์ ์ญ ๋ณ์์ ๋ชจ๋ธ ์ค์ (LoRA์์ ์ฌ์ฉ)
|
| 60 |
+
current_model = model
|
| 61 |
+
|
| 62 |
+
# processor์์ tokenizer๋ฅผ ๊บผ๋ด ์ ์ญ ๋ณ์์ ํ ๋นํฉ๋๋ค.
|
| 63 |
+
if hasattr(processor, 'tokenizer'):
|
| 64 |
+
tokenizer = processor.tokenizer
|
| 65 |
+
else:
|
| 66 |
+
# processor ์์ฒด๊ฐ tokenizer ์ญํ ๋ ํ ์ ์๋ ๊ฒฝ์ฐ
|
| 67 |
+
tokenizer = processor
|
| 68 |
+
|
| 69 |
+
logger.info(f"โ
'{current_profile.display_name}' ๋ชจ๋ธ ๋ก๋ฉ ์๋ฃ!")
|
| 70 |
+
|
| 71 |
+
# ๐ LoRA ๊ธฐ๋ณธ ๋ชจ๋ธ ์๋ ๋ก๋ (๊ณตํต ํจ์ ์ฌ์ฉ)
|
| 72 |
+
try:
|
| 73 |
+
from lily_llm_core.lora_manager import get_lora_manager, lora_manager
|
| 74 |
+
if lora_manager:
|
| 75 |
+
from ..utils.lora_utils import setup_lora_for_model
|
| 76 |
+
setup_lora_for_model(current_profile, lora_manager)
|
| 77 |
+
except ImportError:
|
| 78 |
+
logger.warning("โ ๏ธ LoRA ๊ด๋ฆฌ์ import ์คํจ")
|
| 79 |
+
|
| 80 |
+
model_loaded = True
|
| 81 |
+
|
| 82 |
+
except Exception as e:
|
| 83 |
+
logger.error(f"โ load_model_sync ์คํจ: {e}")
|
| 84 |
+
import traceback
|
| 85 |
+
logger.error(f"๐ ์ ์ฒด ์๋ฌ: {traceback.format_exc()}")
|
| 86 |
+
model_loaded = False
|
| 87 |
+
raise
|
| 88 |
+
|
| 89 |
+
def shutdown_executor():
|
| 90 |
+
"""์ค๋ ๋ ํ ์คํ๊ธฐ ์ข
๋ฃ"""
|
| 91 |
+
executor.shutdown(wait=True)
|
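Note: nothing in model_service.py itself registers these hooks; presumably app_factory.py does. A minimal sketch of how the service could be wired into a FastAPI app (this wiring is illustrative, not taken from the diff):

# Illustrative wiring only -- the actual hookup lives in lily_llm_api/core/app_factory.py.
from fastapi import FastAPI
from lily_llm_api.services import model_service

app = FastAPI()

@app.on_event("startup")
async def load_default_model():
    # load_model_async runs load_model_sync in the module's ThreadPoolExecutor,
    # so the event loop is not blocked while weights load.
    await model_service.load_model_async("polyglot-ko-1.3b-chat")

@app.on_event("shutdown")
def stop_workers():
    model_service.shutdown_executor()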
lily_llm_api/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
+"""
+Utilities package for Lily LLM API
+"""
lily_llm_api/utils/lora_utils.py
ADDED
@@ -0,0 +1,124 @@
+"""
+LoRA utilities for Lily LLM API
+"""
+import logging
+
+logger = logging.getLogger(__name__)
+
+def setup_lora_for_model(profile, lora_manager):
+    """Configure LoRA for the given model profile (shared helper)."""
+    if not lora_manager:
+        logger.warning("⚠️ LoRA is unavailable; skipping auto setup")
+        return False
+
+    try:
+        logger.info("Starting LoRA auto setup...")
+
+        # Pull path and type information from the model profile.
+        current_model_path = None
+        model_type = "causal_lm"  # default
+
+        if hasattr(profile, 'local_path') and profile.local_path:
+            # Local environment: use the local path.
+            current_model_path = profile.local_path
+            # model_type must still be resolved when local_path is used.
+            if hasattr(profile, 'model_id') and profile.model_id:
+                model_id = profile.model_id
+                if model_id == "kanana-1.5-v-3b-instruct":
+                    model_type = "vision2seq"  # kanana is a vision2seq model
+                else:
+                    model_type = "causal_lm"  # default
+            logger.info(f"Using local path from the model profile: {current_model_path}")
+            logger.info(f"Resolved model type: {model_type}")
+        elif hasattr(profile, 'model_id') and profile.model_id:
+            # Resolve the path from the model ID.
+            model_id = profile.model_id
+            logger.info(f"Resolving path from model ID: {model_id}")
+
+            # The path depends on the environment.
+            if hasattr(profile, 'is_local') and profile.is_local:
+                # Local environment: use the local path.
+                if model_id == "polyglot-ko-1.3b-chat":
+                    current_model_path = "./lily_llm_core/models/polyglot_ko_1_3b_chat"
+                    model_type = "causal_lm"
+                elif model_id == "kanana-1.5-v-3b-instruct":
+                    current_model_path = "./lily_llm_core/models/kanana_1_5_v_3b_instruct"
+                    model_type = "vision2seq"  # kanana is a vision2seq model
+                elif model_id == "polyglot-ko-5.8b-chat":
+                    current_model_path = "./lily_llm_core/models/polyglot_ko_5_8b_chat"
+                    model_type = "causal_lm"
+            else:
+                # Deployed environment: the HF model name is used (no local path).
+                current_model_path = None
+                logger.info(f"Deployed environment: skipping LoRA setup (HF model)")
+                return False
+
+        logger.info(f"Resolved model path: {current_model_path}")
+        logger.info(f"Resolved model type: {model_type}")
+
+        if not current_model_path:
+            logger.warning("⚠️ Could not find a path for the current model; skipping LoRA auto-load")
+            return False
+
+        logger.info(f"LoRA model path: {current_model_path}")
+        logger.info(f"LoRA model type: {model_type}")
+
+        # Apply LoRA directly to the already-loaded main model (avoids a duplicate load).
+        logger.info("Applying LoRA directly to the existing main model...")
+
+        # Point the lora_manager at the existing main model.
+        if hasattr(lora_manager, 'base_model') and lora_manager.base_model is None:
+            # Fetch the main model from the global state.
+            from ..services.model_service import get_current_model
+            current_model = get_current_model()
+            if current_model is not None:
+                lora_manager.base_model = current_model
+                logger.info("✅ Existing main model registered with the LoRA manager")
+            else:
+                logger.warning("⚠️ Main model not found; skipping LoRA setup")
+                return False
+
+        # Create the LoRA config.
+        logger.info("Creating LoRA config...")
+
+        # Per-model target modules.
+        if model_type == "vision2seq" and "kanana" in profile.model_id:
+            # Kanana: Llama-based language model (first layer only).
+            target_modules = [
+                "language_model.model.layers.0.self_attn.q_proj",
+                "language_model.model.layers.0.self_attn.k_proj",
+                "language_model.model.layers.0.self_attn.v_proj",
+                "language_model.model.layers.0.self_attn.o_proj",
+                "language_model.model.layers.0.mlp.gate_proj",
+                "language_model.model.layers.0.mlp.up_proj",
+                "language_model.model.layers.0.mlp.down_proj"
+            ]
+        else:
+            # Remaining models: GPTNeoX-based.
+            target_modules = ["query_key_value", "mlp.dense_h_to_4h", "mlp.dense_4h_to_h"]
+
+        lora_config = lora_manager.create_lora_config(
+            r=16,
+            lora_alpha=32,
+            lora_dropout=0.1,
+            bias="none",
+            task_type="CAUSAL_LM" if model_type == "causal_lm" else "VISION_2_SEQ",
+            target_modules=target_modules
+        )
+        logger.info("✅ LoRA config created")
+
+        # Apply the LoRA adapter (directly to the existing main model).
+        logger.info("Applying LoRA adapter...")
+        adapter_success = lora_manager.apply_lora_to_model("auto_adapter")
+        if adapter_success:
+            logger.info("✅ LoRA adapter applied: auto_adapter")
+            logger.info("LoRA auto setup complete!")
+            return True
+        else:
+            logger.error("❌ Failed to apply the LoRA adapter")
+            return False
+
+    except Exception as e:
+        logger.error(f"❌ Error during LoRA auto setup: {e}")
+        return False
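Note: create_lora_config and apply_lora_to_model are lora_manager's own wrappers, whose internals are not shown in this commit. Assuming they sit on top of PEFT (an assumption), the GPTNeoX branch above would correspond to roughly the following; note also that "VISION_2_SEQ" is not a stock peft TaskType, so the kanana path presumably gets special handling inside the manager:

# Assumed PEFT equivalent of the GPTNeoX branch above -- a sketch, not the project's code.
from peft import LoraConfig, TaskType, get_peft_model

config = LoraConfig(
    r=16,                    # rank of the LoRA update matrices
    lora_alpha=32,           # scaling factor applied to the update
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules=["query_key_value", "mlp.dense_h_to_4h", "mlp.dense_4h_to_h"],
)
# peft_model = get_peft_model(base_model, config)  # base_model: the already-loaded GPTNeoX model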
lily_llm_api/utils/system_utils.py
ADDED
@@ -0,0 +1,65 @@
+"""
+System utilities for Lily LLM API
+"""
+import os
+import torch
+import logging
+
+logger = logging.getLogger(__name__)
+
+def configure_cpu_threads():
+    """Tune the CPU thread environment (scaled to the vCPU count)."""
+    print(f"[DEBUG] configure_cpu_threads start")
+    try:
+        # Default: use the env var or the system CPU count, avoiding excessive threads.
+        env_threads = os.getenv("CPU_THREADS")
+        if env_threads is not None:
+            threads = max(1, int(env_threads))
+        else:
+            detected = os.cpu_count() or 2
+            # Cap at 16 so container/server vCPU counts are not used verbatim.
+            threads = max(1, min(detected, 16))
+
+        # OpenMP/MKL/numexpr
+        os.environ["OMP_NUM_THREADS"] = str(threads)
+        os.environ["MKL_NUM_THREADS"] = str(threads)
+        os.environ.setdefault("NUMEXPR_NUM_THREADS", str(threads))
+        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+
+        # PyTorch internal thread settings
+        try:
+            torch.set_num_threads(threads)
+        except Exception:
+            pass
+        try:
+            # An inter-op pool of 1-2 is recommended (cuts context-switching cost).
+            torch.set_num_interop_threads(1 if threads <= 4 else 2)
+        except Exception:
+            pass
+
+        logger.info(f"CPU thread config -> OMP/MKL/numexpr={threads}, torch_threads={threads}")
+    except Exception as e:
+        logger.warning(f"⚠️ Failed to configure CPU threads: {e}")
+    print(f"[DEBUG] configure_cpu_threads end")
+
+def select_model_interactive():
+    """Interactive model selection."""
+    from ..models import list_available_models
+
+    available_models = list_available_models()
+
+    print("\n" + "="*60 + "\n🤖 Lily LLM API v2 - model selection\n" + "="*60)
+    for i, model_info in enumerate(available_models, 1):
+        print(f"{i:2d}. {model_info['name']} ({model_info['model_id']})")
+    while True:
+        try:
+            # choice = input(f"\nSelect a model number to use (1-{len(available_models)}): ")
+            # selected_model = available_models[int(choice) - 1]
+            selected_model = available_models[1]
+            print(f"\n✅ Selected the '{selected_model['name']}' model.")
+            return selected_model['model_id']
+        except (ValueError, IndexError):
+            print(f"❌ Please enter a number between 1 and {len(available_models)}.")
+        except KeyboardInterrupt:
+            import sys
+            sys.exit("\n\nExiting the program.")
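Note: OMP_NUM_THREADS and MKL_NUM_THREADS generally only take effect if they are set before the corresponding runtimes spin up their thread pools, so configure_cpu_threads() is meant to run once at process start, before any model work. A hypothetical entrypoint:

# Hypothetical entrypoint usage; the CPU_THREADS value here is an example override.
import os
os.environ.setdefault("CPU_THREADS", "8")

from lily_llm_api.utils.system_utils import configure_cpu_threads
configure_cpu_threads()  # caps OMP/MKL/numexpr and torch thread counts, then logs them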
lily_llm_core/document_processor.py
CHANGED
@@ -195,7 +195,7 @@ class DocumentProcessor:
             logger.warning(f"⚠️ Formula extraction engine {formula_ocr_engine} unavailable; falling back to EasyOCR")
         else:
             self.formula_extractor_available = False
-
+
         logger.info(f"DocumentProcessor initialized (OCR: {'EasyOCR' if self.ocr_reader else 'None'}, formulas: {formula_ocr_engine})")

     def get_file_type(self, file_path: str) -> str:

@@ -222,7 +222,7 @@ class DocumentProcessor:
             documents = loader.load()
             logger.info(f"Document loaded: {len(documents)} chunks")
             return documents
-
+
         except Exception as e:
             logger.error(f"❌ Document load failed: {e}")
             return []

@@ -298,7 +298,7 @@ class DocumentProcessor:
         except Exception as e:
             logger.error(f"❌ Document processing failed: {e}")
             return []
-
+
     def _process_pdf_hybrid(self, pdf_path: str) -> List[Document]:
         """
         Production-grade PDF processing (structure analysis + spatial relationship mapping)

@@ -427,17 +427,17 @@
                 # fallback: try extracting directly from the page
                 pix = page.get_pixmap()
                 continue  # skip this case
-
+
             if pix.n - pix.alpha < 4:  # GRAY or RGB
                 if pix.colorspace and pix.colorspace.n > 3:
                     pix = fitz.Pixmap(fitz.csRGB, pix)
-
-
-
-
-
+
+            img_data = pix.tobytes("png")
+            img_pil = Image.open(io.BytesIO(img_data))
+
+            if self._is_valid_image(img_pil):
                 # Extract the image's position info (important!)
-
+                img_rect = self._get_image_rect(page, xref)

             if img_rect:
                 bbox = BoundingBox(

@@ -562,7 +562,7 @@
             font_info["fonts"] = list(set(font_info["fonts"]))
             font_info["sizes"] = list(set(font_info["sizes"]))
             font_info["flags"] = list(set(font_info["flags"]))
-
+
         except Exception as e:
             logger.debug(f"⚠️ Font info extraction failed: {e}")

@@ -687,9 +687,9 @@

         # Create the Document object
         doc = Document(
-
-
-
+            page_content=page_content,
+            metadata=metadata
+        )
         documents.append(doc)

         logger.info(f"Page {page.page_num} Document created: "

@@ -767,7 +767,7 @@
         content_parts.append(f"Page size: {page.width:.1f} x {page.height:.1f}")

         return "\n".join(content_parts)
-
+
     def _is_valid_image(self, img: Image.Image) -> bool:
         """Image validity check"""
         try:

@@ -882,4 +882,4 @@
 document_processor = DocumentProcessor(formula_ocr_engine='latexocr')
 # Switch to another engine if needed:
 # document_processor = DocumentProcessor(formula_ocr_engine='easyocr')  # use EasyOCR
-# document_processor = DocumentProcessor(formula_ocr_engine='mathpix')  # use the MathPix API
+# document_processor = DocumentProcessor(formula_ocr_engine='mathpix')  # use the MathPix API
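Note: the restored lines in the @@ -427 hunk are the standard PyMuPDF round-trip from a rendered Pixmap to a PIL image. A self-contained sketch of the same pattern ("sample.pdf" is a placeholder path):

# Standalone sketch of the Pixmap -> PIL round-trip restored in the hunk above.
import io
import fitz  # PyMuPDF
from PIL import Image

doc = fitz.open("sample.pdf")  # placeholder path
page = doc[0]
pix = page.get_pixmap()

# Colorspaces with more than 3 components (e.g. CMYK) must be converted to RGB first.
if pix.colorspace and pix.colorspace.n > 3:
    pix = fitz.Pixmap(fitz.csRGB, pix)

img = Image.open(io.BytesIO(pix.tobytes("png")))
print(img.size, img.mode)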
run_server.py
CHANGED
@@ -19,7 +19,7 @@ if __name__ == "__main__":

     try:
         uvicorn.run(
-            "lily_llm_api.
+            "lily_llm_api.app_v2_modular:app",
             host="0.0.0.0",
             port=8001,
             reload=False,
run_server_v2.py
CHANGED
@@ -11,7 +11,7 @@ import uvicorn

 # Add the project root to the Python path
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))

-from lily_llm_api.
+from lily_llm_api.app import app

 def main():
     """Main function"""