from typing import List, Optional

from fastapi import APIRouter, File, Form, HTTPException, UploadFile
| |
|
| | try: |
| | import infiagent |
| | from infiagent.services.chat_complete_service import predict |
| | except ImportError: |
| | print("import infiagent failed, please install infiagent by 'pip install .' in the pipeline directory of ADA-Agent") |
| | from ..services.chat_complete_service import predict |
| |
|
# Router exposing the chat-completion prediction endpoint; the importing
# application is expected to mount it (e.g. app.include_router(predict_router)).
predict_router = APIRouter()
|
@predict_router.post("/predict")
async def chat_predict(
    prompt: str = Form(...),
    model_name: str = Form(...),
    psm: Optional[str] = Form(None),
    dc: Optional[str] = Form(None),
    temperature: Optional[str] = Form(None),
    top_p: Optional[str] = Form(None),
    top_k: Optional[str] = Form(None),
    files: List[UploadFile] = File(...)
):
    """Run a chat-completion prediction over the uploaded files.

    Form fields:
        prompt: user prompt, forwarded verbatim to ``predict``.
        model_name: identifier of the model to run.
        psm / dc: optional routing hints, forwarded as-is when non-empty.
        temperature / top_p / top_k: optional sampling parameters, received
            as strings and converted to float when non-empty.
        files: one or more uploaded files (required).

    Returns:
        A dict with a single ``answer`` key holding the model response.

    Raises:
        HTTPException: 422 when a numeric form field is not a valid number
            (previously this surfaced as an opaque 500 from bare ``float()``).
    """
    kwargs = {}
    if psm:
        kwargs['psm'] = psm
    if dc:
        kwargs['dc'] = dc

    # Convert the optional numeric fields, rejecting malformed values with a
    # clear 422 instead of letting float() raise an unhandled ValueError.
    for field_name, raw_value in (('temperature', temperature),
                                  ('top_p', top_p),
                                  ('top_k', top_k)):
        if raw_value:
            try:
                # NOTE(review): top_k is conventionally an integer, but the
                # original code converted it with float() — behavior preserved;
                # confirm what ``predict`` expects before tightening to int.
                kwargs[field_name] = float(raw_value)
            except ValueError:
                raise HTTPException(
                    status_code=422,
                    detail=f"Invalid value for '{field_name}': "
                           f"{raw_value!r} is not a number",
                )

    response = await predict(prompt, model_name, files, **kwargs)

    return {
        "answer": response
    }
| |
|