Spaces:
Sleeping
Sleeping
import asyncio | |
import io | |
from fastapi import APIRouter, Depends, UploadFile, File | |
from pydantic import BaseModel | |
from typing import List, Optional | |
from app.api.dto.kg_query import KGQueryRequest, QueryContext | |
from app.services.predict import PredictService, get_predict_service | |
# APIRouter for this module's endpoints.
# NOTE(review): no @router.* decorator lines are visible on the handlers in
# this copy of the file, so nothing appears to register them here — confirm.
router = APIRouter()
class QueryRequest(BaseModel):
    """Request body holding a question string and an optional list of strings.

    NOTE(review): not referenced by any handler visible in this file — confirm
    it is still used elsewhere before relying on or removing it.
    """

    question: str  # required question text
    context: Optional[List[str]] = None  # optional list of strings; None when omitted
class QueryResponse(BaseModel):
    """Response body holding an answer string and a list of source strings.

    NOTE(review): not referenced by any handler visible in this file — confirm
    it is still used elsewhere before relying on or removing it.
    """

    answer: str  # required answer text
    sources: List[str]  # required list of strings (may be empty)
# Predicted labels with confidence at or below this value are dropped from the
# /analyze response (unless that would leave none; see _filter_labels).
_MIN_LABEL_CONFIDENCE = 0.05


def _filter_labels(labels, threshold=_MIN_LABEL_CONFIDENCE):
    """Return the labels whose ``confidence`` is strictly above ``threshold``.

    If no label passes but at least one was predicted, fall back to a
    one-element list holding the single most confident label, so a non-empty
    prediction always yields at least one label. ``None`` is treated as "no
    labels".
    """
    # Materialise once: the input is iterated, truth-tested and passed to
    # max(), which would silently misbehave on a one-shot iterable.
    candidates = list(labels or [])
    kept = [label for label in candidates if label.confidence > threshold]
    if not kept and candidates:
        kept = [max(candidates, key=lambda label: label.confidence)]
    return kept


# NOTE(review): no @router.* decorator is visible on this handler in this copy
# of the file; presumably it is registered with @router.post(...) — confirm,
# otherwise the endpoint is unreachable.
async def analyze(
    image: UploadFile = File(None),
    predict_service: PredictService = Depends(get_predict_service),
):
    """Classify an uploaded image and caption it, running both concurrently.

    Args:
        image: The uploaded image. Declared optional (``File(None)``) but
            required in practice; a request without it is rejected with 400.
        predict_service: Injected service providing ``predict_image`` and
            ``get_caption``.

    Returns:
        A dict with ``crop_id`` (taken from the first filtered label, or
        ``None`` when there are no labels), ``predicted_labels`` (labels above
        the confidence threshold, or the single best label if none pass) and
        ``caption``.

    Raises:
        HTTPException: 400 when no image was uploaded.
    """
    # HTTPException is not among the file's top-level fastapi imports.
    from fastapi import HTTPException

    if image is None:
        # Without this guard, image.read() raises AttributeError -> HTTP 500.
        raise HTTPException(status_code=400, detail="An image file is required.")

    # Read the upload once, then give each consumer its own in-memory stream
    # so the two concurrent calls never share one file cursor.
    image_content = await image.read()

    def _fresh_upload():
        """Build an independent UploadFile over the already-read bytes."""
        return UploadFile(
            file=io.BytesIO(image_content),
            filename=image.filename,
            headers=image.headers,
        )

    # gather() wraps each coroutine in a task, so both run concurrently.
    predicted_label, caption = await asyncio.gather(
        predict_service.predict_image(_fresh_upload()),
        predict_service.get_caption(_fresh_upload()),
    )

    filtered_labels = _filter_labels(predicted_label)
    return {
        "crop_id": filtered_labels[0].crop_id if filtered_labels else None,
        "predicted_labels": filtered_labels,
        "caption": caption,
    }
# NOTE(review): no @router.* decorator is visible on this handler in this copy
# of the file; presumably it is registered with @router.post(...) — confirm.
async def query_kg(
    request: KGQueryRequest,
    predict_service: PredictService = Depends(get_predict_service),
):
    """Forward a knowledge-graph query to the prediction service.

    Returns exactly what ``predict_service.retrieve_kg(request)`` produces.
    """
    kg_result = await predict_service.retrieve_kg(request)
    return kg_result
# NOTE(review): no @router.* decorator is visible on this handler in this copy
# of the file; presumably it is registered with @router.post(...) — confirm.
async def query_kg_text(
    request: KGQueryRequest,
    predict_service: PredictService = Depends(get_predict_service),
):
    """Forward a knowledge-graph query to the service's text retrieval path.

    Returns exactly what ``predict_service.retrieve_kg_text(request)`` produces.
    """
    text_result = await predict_service.retrieve_kg_text(request)
    return text_result
# @router.post("/get-all-nodes")
# async def get_all_nodes(
#     predict_service: PredictService = Depends(get_predict_service),
# ):
#     return await predict_service.get_all_nodes()