Upload 26 files
Browse files- healthagent/agents/__init__.py +0 -0
- healthagent/agents/analyzer.py +41 -0
- healthagent/agents/summarization.py +24 -0
- healthagent/agents/threshold.py +19 -0
- healthagent/config/model_config.yaml +17 -0
- healthagent/predictive/__init__.py +0 -0
- healthagent/predictive/analysis.py +8 -0
- healthagent/predictive/custom_model.py +20 -0
- healthagent/slm/__init__.py +0 -0
- healthagent/slm/context_state.py +12 -0
- healthagent/slm/core.py +24 -0
- healthagent/slm/operation.py +73 -0
- healthagent/slm/prompts.py +41 -0
- healthagent/slm/query.py +35 -0
- healthagent/slm/retrieval.py +114 -0
- healthagent/slm/routing.py +25 -0
- healthagent/tools/__init__.py +0 -0
- healthagent/tools/document_nodes.py +17 -0
- healthagent/tools/explainer.py +51 -0
- healthagent/types.py +11 -0
- healthagent/utils/__init__.py +0 -0
- healthagent/utils/data_lib.py +27 -0
- healthagent/utils/logger.py +13 -0
- healthagent/utils/parser.py +20 -0
- healthagent/utils/setup.py +18 -0
- pyproject.toml +48 -0
healthagent/agents/__init__.py
ADDED
File without changes
|
healthagent/agents/analyzer.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
+
from healthagent.predictive.custom_model import ModelAPI
|
4 |
+
from healthagent.tools.explainer import XiaExplainer
|
5 |
+
|
6 |
+
modelAPI = ModelAPI()
|
7 |
+
xiaAPI = XiaExplainer()
|
8 |
+
|
9 |
+
class PredictionInput(BaseModel):
|
10 |
+
"""Schema for prediction input parameters."""
|
11 |
+
driver_temperature: float = Field(..., description="Driver temperature in Celsius")
|
12 |
+
temperature: float = Field(..., description="Temperature in Celsius")
|
13 |
+
speed_rpm: float = Field(..., description="Speed in RPM")
|
14 |
+
current: float = Field(..., description="Current in Amperes")
|
15 |
+
voltage: float = Field(..., description="Voltage in Volts")
|
16 |
+
power: float = Field(..., description="Power in Watts")
|
17 |
+
thruster_id_encoded: float = Field(default=1.0, description="Encoded thruster ID (default=1.0)")
|
18 |
+
|
19 |
+
|
20 |
+
async def prediction_causes_tool(input_data: Dict[str, Any]) -> Dict[str, Any]:
|
21 |
+
"""
|
22 |
+
Useful for identifying the causes of the model predictions.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
input_data (Dict[str, Any]): Raw input values for prediction.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
Dict[str, Any]: Mapping of feature -> value.
|
29 |
+
"""
|
30 |
+
|
31 |
+
input_obj = PredictionInput(**input_data)
|
32 |
+
input_dict = input_obj.model_dump()
|
33 |
+
input_dict["thruster_id_encoded"] = "1.0"
|
34 |
+
|
35 |
+
result: List[Dict[str, str]] = xiaAPI.explain_prediction(modelAPI.model, input_dict, True)
|
36 |
+
|
37 |
+
formatted_result: Dict[str, Any] = dict(
|
38 |
+
map(lambda item: (item["feature"], str(item["value"])), result)
|
39 |
+
)
|
40 |
+
|
41 |
+
return formatted_result
|
healthagent/agents/summarization.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llama_index.core import Settings
|
2 |
+
from llama_index.core import PromptTemplate
|
3 |
+
from typing import Dict, Any
|
4 |
+
|
5 |
+
async def writing_tool(input_data: Dict[str, Any]) -> str:
|
6 |
+
"""
|
7 |
+
Useful for writing a summary of the given context.
|
8 |
+
|
9 |
+
"""
|
10 |
+
PROMPT = PromptTemplate(
|
11 |
+
template="""
|
12 |
+
Write a detailed information of the given context.
|
13 |
+
The context contains information about causes of prediction, the state of the threshold condtions and how to troubleshoot.
|
14 |
+
|
15 |
+
|
16 |
+
Here are the context:
|
17 |
+
{component}
|
18 |
+
"""
|
19 |
+
)
|
20 |
+
query = PROMPT.format(component=input_data)
|
21 |
+
|
22 |
+
#TODO: Verify Content Loop
|
23 |
+
response = Settings.llm.complete(query)
|
24 |
+
return response
|
healthagent/agents/threshold.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
async def threshold_tool(causes: str) -> str:
|
2 |
+
"""Useful for writing the threshold information of the causes."""
|
3 |
+
"""For each cause state whether the value of the causes are greater than the high threshold values or lesser than the threshold values only."""
|
4 |
+
"""The threshold values given below. Return only the threshold values whose causes were identified """
|
5 |
+
|
6 |
+
threshold = {
|
7 |
+
"voltage": {"low": 20, "high": 28},
|
8 |
+
"current": {"low": 0, "high": 15},
|
9 |
+
"temperature": {"low": 0, "high": 75},
|
10 |
+
"driver_temperature": {"low": 0, "high": 80},
|
11 |
+
"speed": {"low": 1000, "high": 5000},
|
12 |
+
"power": {"low": 200, "high": 1000},
|
13 |
+
}
|
14 |
+
|
15 |
+
|
16 |
+
#TODO: Select only the influencing threshold values
|
17 |
+
|
18 |
+
return str(threshold)
|
19 |
+
|
healthagent/config/model_config.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
language_name: qwen2.5:7b
|
3 |
+
embedding: bge-m3
|
4 |
+
predictive_model: healthagent/model/predictive_model_RandomForest.pkl
|
5 |
+
xia_model: healthagent/model/xia_config.pkl
|
6 |
+
|
7 |
+
chunk:
|
8 |
+
size: 512
|
9 |
+
overlap: 64
|
10 |
+
|
11 |
+
url:
|
12 |
+
server: http://api:8080/
|
13 |
+
|
14 |
+
documents:
|
15 |
+
breakpoints: ./healthagent/data/documents/
|
16 |
+
|
17 |
+
storage: ./healthagent/storage/
|
healthagent/predictive/__init__.py
ADDED
File without changes
|
healthagent/predictive/analysis.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
## This function is Deprecated and will be removed in future version
|
3 |
+
def get_status(data=None, model=None) -> str:
|
4 |
+
|
5 |
+
STATUS_LABELS = ['FAULTY', 'NORMAL', 'WARNING']
|
6 |
+
prediction = model.predict(data)[0]
|
7 |
+
|
8 |
+
return STATUS_LABELS[int(prediction.item())]
|
healthagent/predictive/custom_model.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import joblib
|
2 |
+
from healthagent.utils.setup import PREDICTIVE_MODEL
|
3 |
+
from healthagent.utils.data_lib import data_seralize
|
4 |
+
from healthagent.utils.logger import setup_logger
|
5 |
+
|
6 |
+
logger = setup_logger()
|
7 |
+
STATUS_LABEL = ['FAULTY', 'NORMAL', 'WARNING'] #TODO: joblib for the label
|
8 |
+
|
9 |
+
class ModelAPI:
|
10 |
+
def __init__(self):
|
11 |
+
logger.info("Configuring predictive model.")
|
12 |
+
self.model = joblib.load(PREDICTIVE_MODEL)
|
13 |
+
self.STATUS_LABEL = STATUS_LABEL
|
14 |
+
logger.info("Predictive model configured successfully.")
|
15 |
+
|
16 |
+
def predict(self, x):
|
17 |
+
sample_input= data_seralize(x)
|
18 |
+
prediction = self.model.predict(sample_input)[0]
|
19 |
+
value = int(prediction.item())
|
20 |
+
return self.STATUS_LABEL[value]
|
healthagent/slm/__init__.py
ADDED
File without changes
|
healthagent/slm/context_state.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llama_index.core.workflow import Context
|
2 |
+
|
3 |
+
async def record_notes(ctx: Context, title:str, content: str) -> str:
|
4 |
+
"""Useful for recording notes on a given results. Your content should be the causes of predictions."""
|
5 |
+
|
6 |
+
#TODO: Perform code struct
|
7 |
+
async with ctx.store.edit_state() as ctx_state:
|
8 |
+
if "_notes" not in ctx_state["state"]:
|
9 |
+
ctx_state["state"]["_notes"] = {}
|
10 |
+
ctx_state["state"]["_notes"][title] = content
|
11 |
+
|
12 |
+
return "Notes recorded."
|
healthagent/slm/core.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llama_index.core import Settings
|
2 |
+
from llama_index.llms.ollama import Ollama
|
3 |
+
from llama_index.embeddings.ollama import OllamaEmbedding
|
4 |
+
|
5 |
+
from healthagent.utils.setup import MODEL_NAME, EMBED_MODEL, CHUNK_SIZE, CHUNK_OVERLAP
|
6 |
+
from healthagent.utils.logger import setup_logger
|
7 |
+
|
8 |
+
logger = setup_logger()
|
9 |
+
|
10 |
+
def configure_llm():
|
11 |
+
logger.info("Setting up language model.")
|
12 |
+
Settings.llm = language_model()
|
13 |
+
Settings.embed_model = OllamaEmbedding(model_name=EMBED_MODEL,ollama_additional_kwargs={"mirostat": 0})
|
14 |
+
Settings.chunk_size = CHUNK_SIZE
|
15 |
+
Settings.chunk_overlap = CHUNK_OVERLAP
|
16 |
+
logger.info("Language model configured sucessfully.")
|
17 |
+
|
18 |
+
|
19 |
+
def language_model():
|
20 |
+
return Ollama(model=MODEL_NAME,
|
21 |
+
request_timeout=360,
|
22 |
+
temperature=1.0,
|
23 |
+
context_window=8000
|
24 |
+
)
|
healthagent/slm/operation.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llama_index.core.agent.workflow import FunctionAgent, AgentWorkflow
|
2 |
+
from llama_index.core.tools import QueryEngineTool
|
3 |
+
from healthagent.agents.analyzer import prediction_causes_tool
|
4 |
+
from healthagent.agents.threshold import threshold_tool
|
5 |
+
from llama_index.core.workflow import Context
|
6 |
+
from healthagent.slm.context_state import record_notes, write_report
|
7 |
+
from healthagent.agents.summarization import writing_tool
|
8 |
+
from healthagent.slm.retrieval import Retrieval
|
9 |
+
|
10 |
+
|
11 |
+
def causes_agents():
|
12 |
+
return FunctionAgent (
|
13 |
+
name="PredictionCausesAgent",
|
14 |
+
system_prompt=(
|
15 |
+
"You are the PredictionCausesAgent that provides the causes of the model prediction."
|
16 |
+
"Record note of the results with the title Causes of model predictions."
|
17 |
+
),
|
18 |
+
tools=[prediction_causes_tool, record_notes],
|
19 |
+
description="Useful for identifying the causes of the model predictions and record notes on the results.",
|
20 |
+
can_handoff_to=["ThresholdAgent"],
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
def threshold_agents():
|
25 |
+
return FunctionAgent(
|
26 |
+
name="ThresholdAgent",
|
27 |
+
system_prompt=(
|
28 |
+
"You are the ThresholdAgent that deterimine if the causes values are low or high "
|
29 |
+
"Record note of the results with the title troubleshoot information."
|
30 |
+
|
31 |
+
),
|
32 |
+
tools=[threshold_tool, record_notes],
|
33 |
+
description="Useful for deterimine if the causes values are low or high and record notes on the results with the title threshold information.",
|
34 |
+
|
35 |
+
can_handoff_to=["TroubleshootDocumentAgent"],
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
def write_agent():
|
42 |
+
return FunctionAgent(
|
43 |
+
name="ReportingAgent",
|
44 |
+
description="Useful for summaring all the above information as a a report.",
|
45 |
+
system_prompt=(
|
46 |
+
"You are the WriteAgent that summaries all the above information as a a report "
|
47 |
+
"Write your responses in the write note tool"
|
48 |
+
"Your report should be in a markdown format and use the write_report tool to compile."
|
49 |
+
|
50 |
+
),
|
51 |
+
tools=[writing_tool, write_report],
|
52 |
+
can_handoff_to=[]
|
53 |
+
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
def multi_workflow():
|
58 |
+
causesAgent = causes_agents()
|
59 |
+
thresholdAgent = threshold_agents()
|
60 |
+
|
61 |
+
writeAgent = write_agent()
|
62 |
+
|
63 |
+
|
64 |
+
workflow = AgentWorkflow(
|
65 |
+
agents=[causesAgent, thresholdAgent, writeAgent],
|
66 |
+
root_agent=causesAgent.name,
|
67 |
+
initial_state={
|
68 |
+
"_notes": {},
|
69 |
+
"_content": "Not written yet.",
|
70 |
+
},
|
71 |
+
|
72 |
+
)
|
73 |
+
return workflow
|
healthagent/slm/prompts.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llama_index.core import PromptTemplate
|
2 |
+
from typing import List, Dict
|
3 |
+
|
4 |
+
#Deprecated: This function will be removed in future version.
|
5 |
+
def analysis_prompt(data: Dict[str, str], context_info) -> str:
|
6 |
+
|
7 |
+
PROMPT = PromptTemplate(
|
8 |
+
template="""
|
9 |
+
You are a Remote Operating Vehicle (ROV) assistant. Based on the sensor data provided, write a brief summary sentence of the component's health condition based on the status and the causes provided.
|
10 |
+
|
11 |
+
Your response should:
|
12 |
+
Begin the sentence with the component name.
|
13 |
+
Be brief and simple.
|
14 |
+
Indicate the status.
|
15 |
+
Mention any potential problem if sensor readings indicate a possible issue based on the causes.
|
16 |
+
Do not prefix the result with any labels like "Summary:".
|
17 |
+
|
18 |
+
Here are the information:
|
19 |
+
Component Name: {component}
|
20 |
+
Driver Temperature: {driver_temperature} in degrees Celsius
|
21 |
+
Temperature: {temperature} in degrees Celsius
|
22 |
+
Speed: {speed} in RPM
|
23 |
+
Current: {current} in Amperes
|
24 |
+
Power: {power}
|
25 |
+
Voltage: {voltage}
|
26 |
+
Status: {status}
|
27 |
+
Causes: {causes}
|
28 |
+
""")
|
29 |
+
|
30 |
+
return PROMPT.format(
|
31 |
+
component=data["component"],
|
32 |
+
driver_temperature=data["driver_temperature"],
|
33 |
+
temperature=data["temperature"],
|
34 |
+
speed=data["speed_rpm"],
|
35 |
+
current=data["current"],
|
36 |
+
power=data["power"],
|
37 |
+
voltage=data["voltage"],
|
38 |
+
status=data["status"],
|
39 |
+
causes=data["causes"],
|
40 |
+
context=context_info
|
41 |
+
)
|
healthagent/slm/query.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from healthagent.types import ThrusterData
|
2 |
+
from llama_index.core.prompts import RichPromptTemplate
|
3 |
+
|
4 |
+
|
5 |
+
def get_cause_query(thruster_data):
|
6 |
+
template_str = "Identify the cause of the model prediction with the following information: {{ thrusterData }}\n"
|
7 |
+
prompt = RichPromptTemplate(template_str)
|
8 |
+
data = ThrusterData(**thruster_data)
|
9 |
+
|
10 |
+
return prompt.format(thrusterData=data)
|
11 |
+
|
12 |
+
#TODO: Deprecated function. Remove this in next version
|
13 |
+
def xia_query(data):
|
14 |
+
voltage = data["voltage"]
|
15 |
+
current = data["current"]
|
16 |
+
power = data["power"]
|
17 |
+
temperature = data["temperature"]
|
18 |
+
driver_temperature = data["driver_temperature"]
|
19 |
+
speed_rpm = data["speed_rpm"]
|
20 |
+
thruster_id = data["thruster_id_encoded"]
|
21 |
+
|
22 |
+
query = f"Identify and record the causes of the model predictions with driver_temperature of {driver_temperature}, current of {current}, speed_rpm of {speed_rpm}, voltage of {voltage}, temperature of {temperature}, power of {power}."
|
23 |
+
return query
|
24 |
+
|
25 |
+
|
26 |
+
def support_query():
|
27 |
+
threshold = "Identify and record if the model predictions causes values are low or high."
|
28 |
+
troubleshoot = "Write and record the troubleshooting steps for the threshold results."
|
29 |
+
report = "Summaries all the above information as a a report"
|
30 |
+
|
31 |
+
return f"""
|
32 |
+
{threshold}
|
33 |
+
{troubleshoot}
|
34 |
+
{report}
|
35 |
+
"""
|
healthagent/slm/retrieval.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from llama_index.core.node_parser import SemanticSplitterNodeParser, SentenceSplitter
|
3 |
+
from llama_index.core.response_synthesizers import CompactAndRefine
|
4 |
+
from llama_index.core.schema import NodeWithScore
|
5 |
+
from llama_index.postprocessor.rankgpt_rerank import RankGPTRerank
|
6 |
+
from workflows import Context, Workflow, step
|
7 |
+
from workflows.events import Event, StartEvent, StopEvent
|
8 |
+
from llama_index.core import (VectorStoreIndex, SimpleDirectoryReader, StorageContext, load_index_from_storage)
|
9 |
+
from llama_index.core import Settings
|
10 |
+
from healthagent.utils.logger import setup_logger
|
11 |
+
from healthagent.utils.setup import DATA_DIR, STORAGE_DIR
|
12 |
+
|
13 |
+
|
14 |
+
logger = setup_logger()
|
15 |
+
|
16 |
+
|
17 |
+
class RetrieverEvent(Event):
|
18 |
+
"""Result of running retrieval"""
|
19 |
+
|
20 |
+
nodes: list[NodeWithScore]
|
21 |
+
|
22 |
+
|
23 |
+
class RerankEvent(Event):
|
24 |
+
"""Result of running reranking on retrieved nodes"""
|
25 |
+
|
26 |
+
nodes: list[NodeWithScore]
|
27 |
+
|
28 |
+
|
29 |
+
class RAGWorkflow(Workflow):
|
30 |
+
def __init__(self, index: VectorStoreIndex, *args, **kwargs):
|
31 |
+
super().__init__(*args, **kwargs)
|
32 |
+
self.index = index
|
33 |
+
|
34 |
+
@step
|
35 |
+
async def retrieve(self, ctx: Context, ev: StartEvent) -> RetrieverEvent | None:
|
36 |
+
"Entry point for RAG, triggered by a StartEvent with `query`."
|
37 |
+
|
38 |
+
logger.info(f"Retrieving nodes for query: {ev.get('query')}")
|
39 |
+
query = ev.get("query")
|
40 |
+
top_k = ev.get("top_k", 5)
|
41 |
+
top_n = ev.get("top_n", 3)
|
42 |
+
|
43 |
+
if not query:
|
44 |
+
raise ValueError("Query is required!")
|
45 |
+
|
46 |
+
|
47 |
+
await ctx.store.set("query", query)
|
48 |
+
await ctx.store.set("top_k", top_k)
|
49 |
+
await ctx.store.set("top_n", top_n)
|
50 |
+
|
51 |
+
|
52 |
+
retriever = self.index.as_retriever(similarity_top_k=top_k)
|
53 |
+
nodes = await retriever.aretrieve(query)
|
54 |
+
return RetrieverEvent(nodes=nodes)
|
55 |
+
|
56 |
+
@step
|
57 |
+
async def rerank(self, ctx: Context, ev: RetrieverEvent) -> RerankEvent:
|
58 |
+
top_n = await ctx.store.get("top_n")
|
59 |
+
query = await ctx.store.get("query")
|
60 |
+
|
61 |
+
ranker = RankGPTRerank(top_n=top_n, llm=Settings.llm)
|
62 |
+
|
63 |
+
try:
|
64 |
+
new_nodes = ranker.postprocess_nodes(ev.nodes, query_str=query)
|
65 |
+
except Exception:
|
66 |
+
new_nodes = ev.nodes
|
67 |
+
return RerankEvent(nodes=new_nodes)
|
68 |
+
|
69 |
+
@step
|
70 |
+
async def synthesize(self, ctx: Context, ev: RerankEvent) -> StopEvent:
|
71 |
+
"""Return a response using reranked nodes."""
|
72 |
+
|
73 |
+
synthesizer = CompactAndRefine(llm=Settings.llm)
|
74 |
+
query = await ctx.store.get("query", default=None)
|
75 |
+
|
76 |
+
response = await synthesizer.asynthesize(query, nodes=ev.nodes)
|
77 |
+
|
78 |
+
return StopEvent(result=str(response))
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
class Retrieval:
|
83 |
+
|
84 |
+
def __init__(self):
|
85 |
+
if not os.path.exists(STORAGE_DIR):
|
86 |
+
logger.info("Building index.")
|
87 |
+
self.index = self.build_index()
|
88 |
+
else:
|
89 |
+
logger.info("Loading index.")
|
90 |
+
self.index = self.load_index()
|
91 |
+
|
92 |
+
|
93 |
+
#TODO: Change to mass index scheme
|
94 |
+
def build_index(self):
|
95 |
+
documents = SimpleDirectoryReader(DATA_DIR).load_data()
|
96 |
+
index = VectorStoreIndex.from_documents(documents)
|
97 |
+
index.storage_context.persist(STORAGE_DIR)
|
98 |
+
|
99 |
+
return index
|
100 |
+
|
101 |
+
|
102 |
+
#TODO: Change to mass index scheme
|
103 |
+
def load_index(self):
|
104 |
+
storage_context = StorageContext.from_defaults(persist_dir=STORAGE_DIR)
|
105 |
+
return load_index_from_storage(storage_context)
|
106 |
+
|
107 |
+
|
108 |
+
def query_context(self, text: str) -> str:
|
109 |
+
query_engine = self.index.as_query_engine()
|
110 |
+
result = query_engine.query(text)
|
111 |
+
return str(result)
|
112 |
+
|
113 |
+
def workflow(self) -> RAGWorkflow:
|
114 |
+
return RAGWorkflow(index=self.index, timeout=120.0)
|
healthagent/slm/routing.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langgraph.graph import StateGraph, MessagesState, START, END
|
2 |
+
|
3 |
+
class RouteDecision:
|
4 |
+
def __init__(self):
|
5 |
+
self.graph = StateGraph(MessagesState)
|
6 |
+
self.graph.add_node("reverse", self._reversal)
|
7 |
+
self.graph.add_node("verify", self._verify)
|
8 |
+
self.graph.add_edge(START, "_reversal")
|
9 |
+
#self.graph.add_conditional_edges("reverse", "verify") #TODO: Improve on the conditional edges
|
10 |
+
self.graph.add_edge("_reversal", END)
|
11 |
+
|
12 |
+
def _compile(self):
|
13 |
+
return self.graph.compile()
|
14 |
+
|
15 |
+
def _reversal(self, state: MessagesState):
|
16 |
+
"""Responsible for reversing the prediction decision"""
|
17 |
+
return state["reverse"] = "Prediction Decision"
|
18 |
+
|
19 |
+
def _verify(self, state: MessagesState):
|
20 |
+
"""Responsible for verifying the prediction from reversal"""
|
21 |
+
return state["_reversal"] = True
|
22 |
+
|
23 |
+
#TODO: Read the docs nodes into routing graphs
|
24 |
+
def _nodes_graph(self):
|
25 |
+
pass
|
healthagent/tools/__init__.py
ADDED
File without changes
|
healthagent/tools/document_nodes.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.utils.setup import BREAKPOINTS
|
2 |
+
from llama_index.core import SimpleDirectoryReader
|
3 |
+
from llama_index.core.node_parser import SentenceSplitter
|
4 |
+
from llama_index.core.storage.docstore import SimpleDocumentStore
|
5 |
+
|
6 |
+
def bake_docs_nodes(file_stream):
|
7 |
+
documents = SimpleDirectoryReader(file_stream).load_data()
|
8 |
+
node_parser = SentenceSplitter(chunk_size=512)
|
9 |
+
base_nodes = node_parser.get_nodes_from_documents(documents)
|
10 |
+
return base_nodes
|
11 |
+
|
12 |
+
|
13 |
+
#TODO: Responsible for serializing and caching documents into node docs.
|
14 |
+
if __name__ == "__main__":
|
15 |
+
docstore = SimpleDocumentStore()
|
16 |
+
nodes = bake_docs_nodes(BREAKPOINTS)
|
17 |
+
docstore.add_documents(nodes)
|
healthagent/tools/explainer.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import joblib
|
4 |
+
from lime.lime_tabular import LimeTabularExplainer
|
5 |
+
from healthagent.utils.setup import XIA_MODEL
|
6 |
+
from healthagent.utils.data_lib import data_seralize
|
7 |
+
from healthagent.utils.parser import parse_feature_contributions
|
8 |
+
from healthagent.utils.logger import setup_logger
|
9 |
+
|
10 |
+
logger = setup_logger()
|
11 |
+
|
12 |
+
class XiaExplainer:
|
13 |
+
def __init__(self):
|
14 |
+
logger.info("Configuring explainer model.")
|
15 |
+
self.config = joblib.load(XIA_MODEL)
|
16 |
+
self.explainer = self.initialize_explainer()
|
17 |
+
logger.info("Explainer model configured sucessfully.")
|
18 |
+
|
19 |
+
|
20 |
+
def initialize_explainer(self):
|
21 |
+
return LimeTabularExplainer(
|
22 |
+
training_data=self.config["X_train"].values,
|
23 |
+
feature_names=self.config["feature_names"],
|
24 |
+
class_names=self.config["class_names"],
|
25 |
+
mode="classification",
|
26 |
+
verbose=False
|
27 |
+
)
|
28 |
+
|
29 |
+
def explain_prediction(self, model, data, parse=True, top_k=2):
|
30 |
+
feature_names=self.config["feature_names"]
|
31 |
+
class_names=self.config["class_names"]
|
32 |
+
sample_input= data_seralize(data)
|
33 |
+
|
34 |
+
probas = model.predict_proba(sample_input)[0]
|
35 |
+
idx = probas.argmax()
|
36 |
+
exp = self.explainer.explain_instance(
|
37 |
+
data_row= sample_input.values[0].astype(np.float16),
|
38 |
+
predict_fn=lambda x: model.predict_proba(pd.DataFrame(x, columns=feature_names)),
|
39 |
+
num_features=len(feature_names)
|
40 |
+
)
|
41 |
+
contributions = exp.as_list()
|
42 |
+
top = sorted(contributions, key=lambda x: abs(x[1]), reverse=True)[:top_k]
|
43 |
+
|
44 |
+
if parse:
|
45 |
+
return parse_feature_contributions(top)
|
46 |
+
|
47 |
+
return {
|
48 |
+
"predicted_class": class_names[idx],
|
49 |
+
"confidence": probas[idx],
|
50 |
+
"top_contributing_features": top
|
51 |
+
}
|
healthagent/types.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel, Field
|
2 |
+
|
3 |
+
class ThrusterData(BaseModel):
|
4 |
+
"""Schema for prediction input parameters."""
|
5 |
+
driver_temperature: float = Field(description="Driver temperature in Celsius")
|
6 |
+
temperature: float = Field(description="Temperature in Celsius")
|
7 |
+
speed_rpm: float = Field(description="Speed in RPM")
|
8 |
+
current: float = Field(description="Current in Amperes")
|
9 |
+
voltage: float = Field(description="Voltage in Volts")
|
10 |
+
power: float = Field(description="Power in Watts")
|
11 |
+
thruster_id_encoded: float = Field(default=1.0, description="Encoded thruster ID (default=1.0)")
|
healthagent/utils/__init__.py
ADDED
File without changes
|
healthagent/utils/data_lib.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import random
|
3 |
+
|
4 |
+
def data_seralize(thruster_data):
|
5 |
+
return pd.DataFrame(data={
|
6 |
+
"voltage": [thruster_data['voltage']],
|
7 |
+
"current": [thruster_data['current']],
|
8 |
+
"power": [thruster_data['power']],
|
9 |
+
"temperature": [thruster_data['temperature']],
|
10 |
+
"driver_temperature": [thruster_data['driver_temperature']],
|
11 |
+
"speed_rpm": [thruster_data['speed_rpm']],
|
12 |
+
"thruster_id_encoded": [thruster_data['thruster_id_encoded']]
|
13 |
+
})
|
14 |
+
|
15 |
+
|
16 |
+
# For testing purpose
|
17 |
+
def fake_data():
|
18 |
+
data = {
|
19 |
+
"voltage": round(random.uniform(0.05, 0.15), 3),
|
20 |
+
"current": round(random.uniform(5.0, 8.0), 3),
|
21 |
+
"power": round(random.uniform(0.5, 1.0), 3),
|
22 |
+
"temperature": round(random.uniform(25.0, 40.0), 3),
|
23 |
+
"driver_temperature": round(random.uniform(28.0, 45.0), 3),
|
24 |
+
"speed_rpm": round(random.uniform(3500, 5000), 3),
|
25 |
+
"thruster_id_encoded": float(random.randint(1, 8))
|
26 |
+
}
|
27 |
+
return data
|
healthagent/utils/logger.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
def setup_logger(name="Custom Logging System") -> logging.Logger:
|
4 |
+
logger = logging.getLogger(name)
|
5 |
+
logger.setLevel(logging.INFO)
|
6 |
+
|
7 |
+
if not logger.handlers:
|
8 |
+
ch = logging.StreamHandler()
|
9 |
+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
10 |
+
ch.setFormatter(formatter)
|
11 |
+
logger.addHandler(ch)
|
12 |
+
|
13 |
+
return logger
|
healthagent/utils/parser.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
def parse_feature_contributions(contributions, to_text=True):
|
4 |
+
parsed = []
|
5 |
+
for label, value in contributions:
|
6 |
+
match = re.search(r"<\s*(\w+)\s*<=", label)
|
7 |
+
if match:
|
8 |
+
feature = match.group(1)
|
9 |
+
else:
|
10 |
+
parts = re.match(r"(.+?)\s*(<=|>=|<|>|=)\s*(.+)", label)
|
11 |
+
feature = parts.group(1).strip() if parts else label.strip()
|
12 |
+
parsed.append({
|
13 |
+
"feature": feature,
|
14 |
+
"contribution": f"{'+' if value >= 0 else '-'}{abs(value):.3f}"
|
15 |
+
})
|
16 |
+
if to_text:
|
17 |
+
response = [item['feature'] for index, item in enumerate(parsed)]
|
18 |
+
return ", ".join(response)
|
19 |
+
|
20 |
+
return parsed
|
healthagent/utils/setup.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
|
5 |
+
load_dotenv()
|
6 |
+
|
7 |
+
with open("healthagent/config/model_config.yaml", "r") as file:
|
8 |
+
config = yaml.safe_load(file)
|
9 |
+
|
10 |
+
|
11 |
+
CHUNK_SIZE = config["chunk"]["size"]
|
12 |
+
CHUNK_OVERLAP = config["chunk"]["overlap"]
|
13 |
+
MODEL_NAME = config["model"]["language_name"]
|
14 |
+
EMBED_MODEL = config["model"]["embedding"]
|
15 |
+
PREDICTIVE_MODEL = config["model"]["predictive_model"]
|
16 |
+
XIA_MODEL = config["model"]["xia_model"]
|
17 |
+
DATA_DIR = config["documents"]["breakpoints"]
|
18 |
+
STORAGE_DIR = config["storage"]
|
pyproject.toml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["setuptools>=61.0"]
|
3 |
+
build-backend = "setuptools.build_meta"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "healthagent"
|
7 |
+
version = "0.1.0"
|
8 |
+
description = "AI Health Agent for the Recon Robot"
|
9 |
+
authors = [
|
10 |
+
{ name = "Henry Asiedu", email = "henry.asiedu@fortressaisolutions.com" },
|
11 |
+
{ name = "Amanda Ofori", email = "amanda.ofori@fortressaisolutions.com" }
|
12 |
+
]
|
13 |
+
readme = "README.md"
|
14 |
+
requires-python = ">=3.9"
|
15 |
+
maintainers = [
|
16 |
+
{ name = "Henry Asiedu", email = "henry.asiedu@fortressaisolutions.com" },
|
17 |
+
{ name = "Amanda Ofori", email = "amanda.ofori@fortressaisolutions.com" }
|
18 |
+
]
|
19 |
+
|
20 |
+
dependencies = [
|
21 |
+
"numpy",
|
22 |
+
"pandas",
|
23 |
+
"matplotlib",
|
24 |
+
"scikit-learn",
|
25 |
+
"seaborn",
|
26 |
+
"llama-index-core==0.13.3",
|
27 |
+
"llama-index",
|
28 |
+
"llama-index-llms-ollama",
|
29 |
+
"llama-index-embeddings-ollama",
|
30 |
+
"lime",
|
31 |
+
"pyyaml",
|
32 |
+
"joblib",
|
33 |
+
"fastapi",
|
34 |
+
"uvicorn",
|
35 |
+
"pydantic",
|
36 |
+
"llama-index-postprocessor-rankgpt-rerank>=0.2.0",
|
37 |
+
"llama-index-llms-mistralai==0.7.1",
|
38 |
+
"llama-index-embeddings-mistralai"
|
39 |
+
]
|
40 |
+
|
41 |
+
[project.optional-dependencies]
|
42 |
+
dev = [
|
43 |
+
"streamlit"
|
44 |
+
]
|
45 |
+
|
46 |
+
[tool.setuptools]
|
47 |
+
|
48 |
+
packages = ["healthagent"]
|