subashpoudel's picture
updated analytics
a6a0614
import pandas as pd
from src.genai.utils.data_loader import caption_df
from src.genai.utils.models_loader import llm_gpt
from .prompts import details_extract_prompt
from langchain_core.messages import SystemMessage, HumanMessage
from .state import DetailsFormatter
from langsmith import traceable
class DetailsExtractorNode:
    """Run the details-extraction prompt over a set of interactions.

    The node sends the system prompt returned by ``details_extract_prompt()``
    together with the stringified interactions to the LLM and asks for output
    structured as ``DetailsFormatter``.
    """

    def __init__(self, interactions):
        """Bind the shared ``llm_gpt`` model and keep the interactions.

        Args:
            interactions: Object describing the interactions to analyse. It is
                passed through ``str()`` before being sent to the model, so
                any type with a meaningful string form works.
        """
        self.llm = llm_gpt
        self.interactions = interactions

    @traceable(name="details extraction")
    def run(self):
        """Invoke the LLM and return the extracted details.

        Returns:
            The result of ``model_dump()`` on the structured response
            (presumably a plain dict of the ``DetailsFormatter`` fields --
            confirm against ``.state``).
        """
        template = details_extract_prompt()
        messages = [
            SystemMessage(content=template),
            HumanMessage(content=str(self.interactions)),
        ]
        # Use the model stored on the instance rather than the module global,
        # so replacing ``self.llm`` (e.g. in tests) actually takes effect.
        structured_llm = self.llm.with_structured_output(DetailsFormatter)
        response = structured_llm.invoke(messages)
        return response.model_dump()
class SaveToDB:
    """Filter a dataframe by extracted business details and export to CSV.

    A row is kept when any of its cells contains (case-insensitively, as a
    substring) at least one of the search terms derived from the details.
    """

    def __init__(self, caption_df):
        """Keep a copy of ``caption_df`` without its ``embeddings`` column.

        Args:
            caption_df: Source dataframe. An ``embeddings`` column, if
                present, is dropped; its absence is not an error.
        """
        self.df = caption_df.drop(columns=['embeddings'], errors='ignore')

    def _prepare_values(self, business_details):
        """Collect lowercase search terms from a ``business_details`` dict.

        String values contribute themselves; list values contribute the
        string form of each item. Values of any other type are ignored.
        Blank terms and ``None`` list items are skipped: an empty string is a
        substring of every cell and would make every row match.

        Args:
            business_details: Mapping whose values are inspected.

        Returns:
            A set of non-blank, lowercased search terms.
        """
        all_values = set()
        for v in business_details.values():
            if isinstance(v, str):
                candidates = [v]
            elif isinstance(v, list):
                candidates = [str(item) for item in v if item is not None]
            else:
                continue
            all_values.update(
                text.lower() for text in candidates if text.strip()
            )
        return all_values

    def _row_matches(self, row, all_values):
        """Return True if any cell of ``row`` contains any search term.

        Args:
            row: Iterable of cell values (a dataframe row).
            all_values: Lowercased search terms.

        Returns:
            True when at least one term is a substring of the lowercased
            string form of at least one cell; False otherwise, including when
            ``all_values`` is empty.
        """
        for cell in row:
            # Lowercase each cell once instead of once per search term.
            cell_text = str(cell).lower()
            if any(val in cell_text for val in all_values):
                return True
        return False

    def save_to_csv(self, business_details, output_file='extracted_data.csv'):
        """Write the rows matching ``business_details`` to a CSV file.

        Args:
            business_details: Mapping used to derive the search terms.
            output_file: Destination path of the CSV (written without the
                index column).
        """
        all_values = self._prepare_values(business_details)
        # Build the mask explicitly rather than via ``DataFrame.apply``: on a
        # frame with no rows ``apply(axis=1)`` may not return a boolean
        # Series, which would drop the columns from the output.
        flags = [
            self._row_matches(row, all_values)
            for _, row in self.df.iterrows()
        ]
        mask = pd.Series(flags, dtype=bool).to_numpy()
        matched_df = self.df[mask]
        matched_df.to_csv(output_file, index=False)