Spaces:
Sleeping
Sleeping
File size: 10,956 Bytes
9041389 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
import json
from typing import Optional
from proto.entity_pb2 import PredictedMovement
from tabulate import tabulate
from domain.entity_d import (
EntityD,
EntityKnowledgeGraphD,
EntityRelationshipD,
RelationshipD,
)
from llm_handler.openai_handler import (
ChatCompletionMessageParam,
ChatModelVersion,
OpenAIHandler,
)
from utils.dates import parse_date
# System prompt for the LLM entity-matching call. The few-shot example's keys
# must be exactly the items of its List A, since callers look results up by
# the original thesis entity names.
FUZZY_MATCH_ENTITIES_PROMPT = '''
You are an expert in financial analysis. You will be given two lists of entities. Your task is to output a semantic mapping from the entities in list A to the entities in list B. This means an entity in list A that is semantically similar to an entity in list B should be mapped together. If there is no reasonable semantic match for an entity in list A, map it to an empty string. Output should be in the format of a JSON object whose keys are the entities of list A copied exactly as given, and whose values are entities of list B copied exactly as given. Ensure the entity keys are in the same order as the input list A.
Input:
List A: ["BofA", "Bank of America Corp", "GDP", "Inflation", "Yen"]
List B: ["Bank of America", "inflation", "Gross Domestic Product", "oil"]
Output:
{
"BofA": "Bank of America",
"Bank of America Corp": "Bank of America",
"GDP": "Gross Domestic Product",
"Inflation": "inflation",
"Yen": ""
}
'''
class EvaluationEngine:
    """Check thesis claims against a ground-truth entity knowledge graph.

    Thesis entities are mapped to ground-truth entities with an LLM fuzzy
    match (``FUZZY_MATCH_ENTITIES_PROMPT``). Each thesis claim is then compared
    with the ground-truth relationships of the matched entity whose time
    period contains the claim's time period.
    """

    _handler: OpenAIHandler
    _MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
    _TEMPERATURE: float = 0.2

    def __init__(self,
                 ground_truth_kg: EntityKnowledgeGraphD,
                 openai_handler: Optional[OpenAIHandler] = None,
                 model_version: Optional[ChatModelVersion] = None):
        """Index the ground-truth graph by target entity name.

        Args:
            ground_truth_kg: graph whose relationships are treated as ground truth.
            openai_handler: handler used for the entity-matching call; a new
                ``OpenAIHandler()`` is created when not supplied.
            model_version: chat model to use; defaults to ``_MODEL_VERSION``.
        """
        self._handler = openai_handler or OpenAIHandler()
        self._model_version = model_version or self._MODEL_VERSION
        # Adjacency list keyed by to_entity name: every ground-truth
        # relationship that points at that entity, in input order.
        self.kg: dict[str, list[EntityRelationshipD]] = {}
        for entity_relationship in ground_truth_kg.entity_relationships:
            to_entity_name = entity_relationship.to_entity.entity_name
            self.kg.setdefault(to_entity_name, []).append(entity_relationship)

    def _get_thesis_to_gt_entity_map(self, thesis_kg: EntityKnowledgeGraphD) -> dict[str, str]:
        """Ask the LLM to map thesis entity names onto ground-truth entity names.

        Args:
            thesis_kg: graph whose ``to_entity`` names are to be matched.

        Returns:
            The JSON object returned by the model: thesis entity name ->
            ground-truth entity name, or ``""`` for no match. Keys are whatever
            the model produced, so callers must not assume every thesis entity
            is present. Returns ``{}`` without calling the model when the
            thesis has no relationships.
        """
        # Deduplicate while preserving order; repeats would only yield
        # duplicate JSON keys in the answer.
        thesis_entities = list(dict.fromkeys(
            entity_relationship.to_entity.entity_name
            for entity_relationship in thesis_kg.entity_relationships))
        if not thesis_entities:
            # Nothing to match, so skip the LLM round trip.
            return {}
        # json.dumps keeps the lists in the same JSON syntax as the prompt's
        # example (Python's repr would use single quotes).
        list_a = json.dumps(thesis_entities, ensure_ascii=False)
        list_b = json.dumps(list(self.kg), ensure_ascii=False)
        messages: list[ChatCompletionMessageParam] = [
            {
                "role": "system", "content": FUZZY_MATCH_ENTITIES_PROMPT
            }, {
                "role": "user",
                "content": f"List A: {list_a}\nList B: {list_b}"
            }
        ]
        completion_text = self._handler.get_chat_completion(
            messages=messages,
            model=self._model_version,
            temperature=self._TEMPERATURE,
            response_format={"type": "json_object"})
        thesis_to_gt_entity_mapping: dict[str, str] = json.loads(completion_text)
        return thesis_to_gt_entity_mapping

    def _get_relationships_matching_timeperiod(
            self, gt_kg_to_node: str, relationship: RelationshipD) -> list[EntityRelationshipD]:
        """Return ground-truth relationships into ``gt_kg_to_node`` covering ``relationship``'s period.

        A ground-truth relationship matches when both the thesis start date and
        the thesis end date fall inside its inclusive [start, end] range, i.e.
        the thesis period is contained in the ground-truth period. A node name
        that is not in the ground-truth graph yields an empty list.
        """
        thesis_start = parse_date(relationship.start_date)
        thesis_end = parse_date(relationship.end_date)
        matching_relationships = []
        # .get: the LLM may name an entity that is not actually in the KG.
        for gt_relationship in self.kg.get(gt_kg_to_node, []):
            gt_start = parse_date(gt_relationship.relationship.start_date)
            gt_end = parse_date(gt_relationship.relationship.end_date)
            # Containment, not mere overlap: the thesis timeframe must lie
            # entirely within the ground-truth timeframe.
            if gt_start <= thesis_start <= gt_end and gt_start <= thesis_end <= gt_end:
                matching_relationships.append(gt_relationship)
        return matching_relationships

    def evaluate_thesis(
            self, thesis_kg: EntityKnowledgeGraphD
    ) -> list[tuple[EntityRelationshipD, bool, Optional[EntityRelationshipD]]]:
        """Compare every thesis claim with the ground-truth knowledge graph.

        Returns:
            For each thesis relationship, one
            ``(thesis_relationship, supported, gt_relationship)`` triple per
            time-compatible ground-truth relationship, where ``supported`` is
            True when the predicted movements are equal. A claim with no matched
            entity or no time-compatible ground-truth relationship yields a
            single ``(thesis_relationship, False, None)`` triple.
        """
        thesis_to_kg_map = self._get_thesis_to_gt_entity_map(thesis_kg)
        results = []
        for thesis_relationship in thesis_kg.entity_relationships:
            thesis_to_node = thesis_relationship.to_entity.entity_name
            # .get: the LLM may drop or rewrite a key; treat that (and any
            # non-string value) as "no matching entity" instead of crashing.
            kg_node = thesis_to_kg_map.get(thesis_to_node)
            if isinstance(kg_node, str) and kg_node:
                matching_relationships = self._get_relationships_matching_timeperiod(
                    kg_node, thesis_relationship.relationship)
            else:
                matching_relationships = []
            if not matching_relationships:
                results.append((thesis_relationship, False, None))
                continue
            thesis_movement = thesis_relationship.relationship.predicted_movement
            for gt_relationship in matching_relationships:
                supported = gt_relationship.relationship.predicted_movement == thesis_movement
                results.append((thesis_relationship, supported, gt_relationship))
        return results

    def evaluate_and_display_thesis(self, thesis_kg: EntityKnowledgeGraphD):
        """Evaluate ``thesis_kg`` and render the results as an HTML table.

        Returns:
            The HTML string produced by ``tabulate``, one row per triple from
            ``evaluate_thesis``.
        """
        results = self.evaluate_thesis(thesis_kg)
        # Keyed by the proto enum members rather than hard-coded ints so the
        # labels follow the PredictedMovement definition; unmapped values
        # (e.g. an unset movement) render as "Unknown" instead of raising.
        movement_labels = {
            PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL: "Neutral",
            PredictedMovement.PREDICTED_MOVEMENT_INCREASE: 'Increase',
            PredictedMovement.PREDICTED_MOVEMENT_DECREASE: 'Decrease',
        }
        headers = ["Thesis Claim", "Supported by KG", "Related KG Relationship"]
        table_data = []
        for thesis_relationship, supported, gt_relationship in results:
            claim_movement = movement_labels.get(
                thesis_relationship.relationship.predicted_movement, "Unknown")
            claim = f'{thesis_relationship.to_entity.entity_name} {claim_movement}'
            if gt_relationship is not None:
                evidence = movement_labels.get(
                    gt_relationship.relationship.predicted_movement, "Unknown")
                evidence += f' ({gt_relationship.from_entity.entity_name}) '
            else:
                evidence = "No evidence in KG"
            table_data.append([claim, supported, evidence])
        return tabulate(table_data, tablefmt="html", headers=headers)
if __name__ == '__main__':
    # TODO: extract the cases into pytest tests
    # Ground truth: three GDP predictions with different periods/movements and
    # two USD predictions, one of which (2024) lies outside the thesis period.
    kg = EntityKnowledgeGraphD(entity_relationships=[
        EntityRelationshipD(from_entity=EntityD(entity_id='3', entity_name="analyst A"),
                            relationship=RelationshipD(
                                relationship_id='2',
                                start_date='2021-01-01',
                                end_date='2024-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='5', entity_name="analyst B"),
                            relationship=RelationshipD(
                                relationship_id='3',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_DECREASE),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='7', entity_name="analyst C"),
                            relationship=RelationshipD(
                                relationship_id='4',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='9', entity_name="analyst D"),
                            relationship=RelationshipD(
                                relationship_id='5',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
                            to_entity=EntityD(entity_id='10', entity_name="USD")),
        EntityRelationshipD(  # out of time range for thesis
            # distinct ids: this is a different analyst and relationship from
            # "analyst D" above
            from_entity=EntityD(entity_id='11', entity_name="analyst E"),
            relationship=RelationshipD(
                relationship_id='6',
                start_date='2024-01-01',
                end_date='2024-12-31',
                source_text='',
                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
            to_entity=EntityD(entity_id='10', entity_name="USD")),
    ])
    # Thesis: three 2021 "increase" claims. "Gross Domestic Product" and "US$"
    # should fuzzy-match KG entities; "Yen" has no counterpart.
    thesis_claims = [
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="Gross Domestic Product")),
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="US$")),
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="Yen")),
    ]
    thesis = EntityKnowledgeGraphD(entity_relationships=thesis_claims)
    eval_engine = EvaluationEngine(kg)
    # evaluate_and_display_thesis returns the HTML table rather than printing
    # it, so print the result for the script to show anything.
    print(eval_engine.evaluate_and_display_thesis(thesis))
|