import logging
from fastapi import APIRouter, Depends, HTTPException
from jinja2 import Environment
from litellm.router import Router
from dependencies import get_llm_router, get_prompt_templates
from schemas import _ReqGroupingCategory, _ReqGroupingOutput, ReqGroupingCategory, ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse

# Router for requirement processing
router = APIRouter(tags=["requirement processing"])


@router.post("/get_reqs_from_query", response_model=ReqSearchResponse)
def find_requirements_from_problem_description(req: ReqSearchRequest, llm_router: Router = Depends(get_llm_router)):
    """Finds the requirements that adress a given problem description from an extracted list"""

    requirements = req.requirements
    query = req.query

    requirements_text = "\n".join(
        [f"[Selection ID: {r.req_id} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]" for r in requirements])
    resp_ai = llm_router.completion(
        model="gemini-v2",
        messages=[{"role": "user", "content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of 'Selection ID' for the most relevant corresponding requirements that reference or best cover the problem. If none of the requirements covers the problem, simply return an empty list"}],
        response_format=ReqSearchLLMResponse
    )

    out_llm = ReqSearchLLMResponse.model_validate_json(
        resp_ai.choices[0].message.content).selected

    logging.info(f"Found {len(out_llm)} reqs matching case.")

    if max(out_llm) > len(requirements) - 1:
        raise HTTPException(
            status_code=500, detail="LLM error : Generated a wrong index, please try again.")

    return ReqSearchResponse(requirements=[requirements[i] for i in out_llm])
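
# Illustrative request/response shape for /get_reqs_from_query (a sketch, not part
# of the original module; the exact field definitions live in schemas.py, and the
# attribute names below simply mirror the ones accessed in the endpoint above):
#
#   POST /get_reqs_from_query
#   {
#     "query": "Users must be locked out after repeated failed logins",
#     "requirements": [
#       {"req_id": 0, "document": "auth_spec", "context": "...", "requirement": "..."}
#     ]
#   }
#
#   -> {"requirements": [ ...the subset the LLM judged relevant... ]}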


@router.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories"""

    MAX_ATTEMPTS = 5

    categories: list[_ReqGroupingCategory] = []
    messages = []

    # render the classification prompt; requirements are referenced by their indices
    req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    # add the rendered classification prompt as the first user message
    messages.append({"role": "user", "content": req_prompt})

    # retry until every requirement is assigned to at least one category
    for attempt in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(model="gemini-v2", messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        # sanity check: every requirement index must appear in at least one
        # category, otherwise the LLM left some requirements out
        valid_ids_universe = set(range(len(params.requirements)))
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}

        # keep only non-hallucinated, valid assigned ids
        valid_assigned_ids = assigned_ids.intersection(valid_ids_universe)

        # check for remaining requirements assigned to none of the categories
        unassigned_ids = valid_ids_universe - valid_assigned_ids

        if len(unassigned_ids) == 0:
            categories.extend(output.categories)
            break
        else:
            messages.append(req_completion.choices[0].message)
            messages.append(
                {"role": "user", "content": f"You haven't categorized the following requirements in at least one category {unassigned_ids}. Please do so."})

            if attempt == MAX_ATTEMPTS - 1:
                raise HTTPException(
                    status_code=500, detail="LLM error: failed to categorize all requirements after multiple attempts.")

    # build the final category objects, dropping any invalid (likely hallucinated)
    # requirement indices that slipped through
    final_categories = []
    for idx, cat in enumerate(categories):
        final_categories.append(ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if 0 <= i < len(params.requirements)]
        ))

    return ReqGroupingResponse(categories=final_categories)
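
# --- Usage sketch (illustrative, not part of the original service) ---
# One possible way to mount this router and serve it locally for manual testing.
# Building the app here and running it under uvicorn are assumptions, not something
# the source prescribes; in the real project the application is presumably
# assembled elsewhere.
if __name__ == "__main__":  # pragma: no cover
    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()
    app.include_router(router)

    # The LLM router and Jinja environment come from dependencies.py as imported
    # above; for local experiments they can be swapped out, e.g.:
    #   app.dependency_overrides[get_llm_router] = lambda: my_configured_router
    uvicorn.run(app, host="127.0.0.1", port=8000)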