|
import asyncio |
|
import json |
|
import logging |
|
from fastapi import APIRouter, Depends, HTTPException |
|
from httpx import AsyncClient |
|
from jinja2 import Environment |
|
from litellm.router import Router |
|
from dependencies import INSIGHT_FINDER_BASE_URL, get_http_client, get_llm_router, get_prompt_templates |
|
from typing import Awaitable, Callable, TypeVar |
|
from schemas import _RefinedSolutionModel, _SearchedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, InsightFinderConstraintsList, ReqGroupingCategory, ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse, SolutionCriticism, SolutionModel, SolutionSearchResponse, SolutionSearchV2Request, TechnologyData |
|
|
|
|
|
# Router collecting the solution search, critique and refinement endpoints
# declared in this module; all of them share the same OpenAPI tag.
router = APIRouter(
    tags=["solution generation and critique"],
)
|
|
|
|
|
|
|
# Generic placeholders used by the ``retry_until`` signature:
#   A - type of the single argument handed to the retried coroutine function
#   T - type of the value that coroutine function produces
A = TypeVar("A")
T = TypeVar("T")
|
|
|
|
|
async def retry_until(
    func: Callable[[A], Awaitable[T]],
    arg: A,
    predicate: Callable[[T], bool],
    max_retries: int,
) -> T:
    """Call ``func(arg)`` repeatedly until ``predicate`` accepts its result.

    ``func`` is always awaited once. While the latest result is rejected by
    ``predicate`` it is awaited again, at most ``max_retries`` additional
    times (so ``func`` runs at most ``1 + max_retries`` times).

    Args:
        func: Async callable taking ``arg`` and producing a candidate value.
        arg: The argument passed unchanged to every call of ``func``.
        predicate: Validation callback; a truthy return stops the retries.
        max_retries: Maximum number of extra attempts after the first call.

    Returns:
        The first result accepted by ``predicate``. When the retry budget is
        used up, the result of the final attempt is returned as-is, without
        being checked against ``predicate``.
    """
    result = await func(arg)
    for _attempt in range(max_retries):
        if predicate(result):
            # Accepted: stop retrying and hand this value back.
            break
        result = await func(arg)
    return result
|
|
|
|
|
|
|
@router.post("/search_solutions")
async def search_solutions_if(
    req: SolutionSearchV2Request,
    prompt_env: Environment = Depends(get_prompt_templates),
    llm_router: Router = Depends(get_llm_router),
    http_client: AsyncClient = Depends(get_http_client),
) -> SolutionSearchResponse:
    """Build one solution per requirement category in ``req.categories``.

    For every category, concurrently:

    1. Render ``format_requirements.txt`` and ask the ``gemini-v2`` model to
       turn the category into an ``InsightFinderConstraintsList``.
    2. POST those constraints (as ``{"constraints": {title: description}}``)
       to the ``process-constraints`` endpoint under
       ``INSIGHT_FINDER_BASE_URL`` and parse the reply as ``TechnologyData``.
    3. Render ``synthesize_solution.txt`` with the category, the returned
       technologies and ``req.user_constraints`` and ask the model for a
       ``_SearchedSolutionModel``.
    4. Assemble a ``SolutionModel`` from that output.

    The endpoint is best-effort: a category whose pipeline raises is logged
    and left out of the response instead of failing the whole request.

    Args:
        req: Categories to solve plus the user supplied constraints.
        prompt_env: Jinja2 environment providing the prompt templates.
        llm_router: LiteLLM router used for the completions.
        http_client: HTTP client used to reach the Insight Finder service.

    Returns:
        A ``SolutionSearchResponse`` holding the solutions of every category
        that was processed successfully (possibly none).
    """
    logger = logging.getLogger(__name__)

    async def _search_solution_inner(cat: ReqGroupingCategory) -> SolutionModel:
        """Run the full constraint -> technologies -> solution pipeline for one category."""
        # Step 1: have the LLM rewrite the category as a list of constraints.
        constraints_prompt = await prompt_env.get_template("format_requirements.txt").render_async(
            category=cat.model_dump(),
            response_schema=InsightFinderConstraintsList.model_json_schema(),
        )
        fmt_completion = await llm_router.acompletion(
            "gemini-v2",
            messages=[{"role": "user", "content": constraints_prompt}],
            response_format=InsightFinderConstraintsList,
        )
        fmt_model = InsightFinderConstraintsList.model_validate_json(
            fmt_completion.choices[0].message.content
        )

        # Step 2: look up technologies for those constraints.
        # NOTE: constraints sharing a title collapse into one entry (last wins).
        formatted_constraints = {
            "constraints": {cons.title: cons.description for cons in fmt_model.constraints}
        }
        # ``json=`` lets httpx serialise the payload and set the JSON
        # Content-Type header, which a raw ``content=`` string does not do.
        technologies_req = await http_client.post(
            INSIGHT_FINDER_BASE_URL + "process-constraints",
            json=formatted_constraints,
        )
        # Surface HTTP failures as such instead of as a confusing validation
        # error on an error body.
        technologies_req.raise_for_status()
        technologies = TechnologyData.model_validate(technologies_req.json())

        # Step 3: have the LLM synthesize a solution from the technologies.
        solution_prompt = await prompt_env.get_template("synthesize_solution.txt").render_async(
            category=cat.model_dump(),
            technologies=technologies.model_dump()["technologies"],
            user_constraints=req.user_constraints,
            response_schema=_SearchedSolutionModel.model_json_schema(),
        )
        format_solution = await llm_router.acompletion(
            "gemini-v2",
            messages=[{"role": "user", "content": solution_prompt}],
            response_format=_SearchedSolutionModel,
        )
        format_solution_model = _SearchedSolutionModel.model_validate_json(
            format_solution.choices[0].message.content
        )

        # Step 4: map the LLM output back onto the public solution model.
        # ``requirement_ids`` are used as indexes into ``cat.requirements``;
        # an out-of-range id raises IndexError and this category is skipped.
        return SolutionModel(
            Context="",
            Requirements=[
                cat.requirements[i].requirement
                for i in format_solution_model.requirement_ids
            ],
            Problem_Description=format_solution_model.problem_description,
            Solution_Description=format_solution_model.solution_description,
            References=[],
            Category_Id=cat.id,
        )

    # gather() keeps the input order, so results line up with req.categories.
    results = await asyncio.gather(
        *(_search_solution_inner(cat) for cat in req.categories),
        return_exceptions=True,
    )

    final_solutions = []
    for cat, result in zip(req.categories, results):
        # BaseException (not just Exception) so that a cancelled sub-task's
        # CancelledError is filtered out too instead of leaking into the response.
        if isinstance(result, BaseException):
            logger.error(
                "Solution search failed for category %s", cat.id, exc_info=result
            )
            continue
        final_solutions.append(result)

    return SolutionSearchResponse(solutions=final_solutions)
|
|
|
|
|
@router.post("/criticize_solution", response_model=CritiqueResponse)
async def criticize_solution(
    params: CriticizeSolutionsRequest,
    prompt_env: Environment = Depends(get_prompt_templates),
    llm_router: Router = Depends(get_llm_router),
) -> CritiqueResponse:
    """Produce an LLM critique for each solution in ``params.solutions``.

    Each solution is critiqued on its own, with all requests running
    concurrently. Any failure in a single critique propagates and fails the
    whole request.

    Args:
        params: Request body carrying the solutions to critique.
        prompt_env: Jinja2 environment providing the ``criticize.txt`` template.
        llm_router: LiteLLM router used for the completions.

    Returns:
        A ``CritiqueResponse`` pairing every input solution with its critique,
        in the same order as ``params.solutions``.
    """

    async def _critique_one(candidate: SolutionModel):
        """Ask the model to critique a single solution."""
        template = prompt_env.get_template("criticize.txt")
        # The template takes a list of solutions; exactly one is sent per call.
        prompt_text = await template.render_async(
            solutions=[candidate.model_dump()],
            response_schema=_SolutionCriticismOutput.model_json_schema(),
        )

        completion = await llm_router.acompletion(
            model="gemini-v2",
            messages=[{"role": "user", "content": prompt_text}],
            response_format=_SolutionCriticismOutput,
        )

        raw_json = completion.choices[0].message.content
        parsed = _SolutionCriticismOutput.model_validate_json(raw_json)

        # Only one solution was submitted, so only the first criticism is used.
        return SolutionCriticism(solution=candidate, criticism=parsed.criticisms[0])

    pending = [_critique_one(sol) for sol in params.solutions]
    critiques = await asyncio.gather(*pending, return_exceptions=False)
    return CritiqueResponse(critiques=critiques)
|
|
|
|
|
|
|
|
|
@router.post("/refine_solutions", response_model=SolutionSearchResponse)
async def refine_solutions(
    params: CritiqueResponse,
    prompt_env: Environment = Depends(get_prompt_templates),
    llm_router: Router = Depends(get_llm_router),
) -> SolutionSearchResponse:
    """Rewrite previously critiqued solutions so they take the critique into account.

    Every (solution, criticism) pair in ``params.critiques`` is sent to the
    model concurrently. Any failure in a single refinement propagates and
    fails the whole request.

    Args:
        params: Critiques, each holding the original solution and its criticism.
        prompt_env: Jinja2 environment providing the ``refine_solution.txt`` template.
        llm_router: LiteLLM router used for the completions.

    Returns:
        A ``SolutionSearchResponse`` with one refined solution per critique,
        in the same order as ``params.critiques``.
    """

    async def _apply_critique(item: SolutionCriticism):
        """Ask the model for new problem/solution descriptions for one critique."""
        template = prompt_env.get_template("refine_solution.txt")
        prompt_text = await template.render_async(
            solution=item.solution.model_dump(),
            criticism=item.criticism,
            response_schema=_RefinedSolutionModel.model_json_schema(),
        )

        completion = await llm_router.acompletion(
            model="gemini-v2",
            messages=[{"role": "user", "content": prompt_text}],
            response_format=_RefinedSolutionModel,
        )

        rewritten = _RefinedSolutionModel.model_validate_json(
            completion.choices[0].message.content
        )

        # Work on a deep copy so the solution inside the request stays untouched;
        # only the two description fields are replaced, everything else is kept.
        updated = item.solution.model_copy(deep=True)
        updated.Problem_Description = rewritten.problem_description
        updated.Solution_Description = rewritten.solution_description
        return updated

    pending = [_apply_critique(crit) for crit in params.critiques]
    refined_solutions = await asyncio.gather(*pending, return_exceptions=False)

    return SolutionSearchResponse(solutions=refined_solutions)
|
|
|
|