File size: 2,696 Bytes
d54fa91
2bb5de3
d54fa91
cbbf201
d54fa91
 
 
 
 
 
 
 
 
 
2bb5de3
d54fa91
 
 
 
bc41f37
d54fa91
 
 
 
 
 
 
 
 
 
 
 
cbbf201
d54fa91
 
 
 
 
 
bc41f37
cbbf201
 
d54fa91
cbbf201
 
 
d54fa91
cbbf201
d54fa91
 
cbbf201
 
 
 
d54fa91
cbbf201
 
 
 
 
 
d54fa91
cbbf201
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
from dotenv import load_dotenv
import pandas as pd
import io
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import ChatPromptTemplate
from langchain.tools import PythonAstREPLTool
from langchain.chat_models import ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
from langchain.chat_models import ChatOpenAI
from src.types import TableMapping
from src.vars import NUM_ROWS_TO_RETURN
from src.prompt import DATA_SCIENTIST_PROMPT_STR, SPEC_WRITER_PROMPT_STR, ENGINEER_PROMPT_STR

# Populate os.environ from a local .env file (if present) before the model
# client below is constructed; presumably this is where the OpenAI API key
# comes from — confirm against deployment setup.
load_dotenv()

# Data layout relative to this module: <module dir>/data/synthetic/
DATA_DIR_PATH = os.path.join(os.path.dirname(__file__), 'data')
SYNTHETIC_DATA_DIR_PATH = os.path.join(DATA_DIR_PATH, 'synthetic')

# Single chat model shared by every chain in this module.  temperature=0 keeps
# sampling randomness to a minimum so repeated runs behave as similarly as the
# API allows.
BASE_MODEL = ChatOpenAI(temperature=0, model_name='gpt-4')

def get_dataframes():
    """Load the bundled synthetic source table and its target template.

    Both files are read from ``SYNTHETIC_DATA_DIR_PATH``.

    Returns:
        A ``(source, template)`` tuple of DataFrames, parsed from
        ``legal_entries_a.csv`` and ``legal_template.csv`` respectively.
    """
    def _load(filename):
        # Read one CSV from the synthetic-data directory.
        return pd.read_csv(os.path.join(SYNTHETIC_DATA_DIR_PATH, filename))

    # Tuple elements are evaluated left to right, so the source file is read first.
    return _load('legal_entries_a.csv'), _load('legal_template.csv')

def get_data_str_from_df_for_prompt(df, num_rows_to_return=NUM_ROWS_TO_RETURN):
    """Render a preview of ``df`` as a Markdown table wrapped in ``<df>`` tags.

    Args:
        df: DataFrame to preview.
        num_rows_to_return: How many leading rows to include (passed to
            ``DataFrame.head``).

    Returns:
        ``'<df>\\n<markdown table>\\n</df>'``, suitable for embedding in a prompt.
        Note that pandas' ``to_markdown`` needs the optional ``tabulate`` package.
    """
    preview = df.head(num_rows_to_return)
    table_md = preview.to_markdown()
    return '\n'.join(('<df>', table_md, '</df>'))

def get_table_mapping(source_df, template_df):
    """Ask the LLM to map the source table's columns onto the template's.

    Previews of both tables are inserted into ``DATA_SCIENTIST_PROMPT_STR``
    together with the ``TableMapping`` format instructions. The model's reply
    is then parsed back into a ``TableMapping``.

    Args:
        source_df: DataFrame holding the data to be transformed.
        template_df: DataFrame whose layout the source should be mapped to.

    Returns:
        A DataFrame built from the parsed model's ``table_mappings`` field.
        The exact columns depend on ``TableMapping``; confirm against src.types.
    """
    parser = PydanticOutputParser(pydantic_object=TableMapping)
    prompt = ChatPromptTemplate.from_template(
        template=DATA_SCIENTIST_PROMPT_STR,
        partial_variables={'format_instructions': parser.get_format_instructions()},
    )
    chain = prompt | BASE_MODEL | parser

    prompt_inputs = {
        "source_1_csv_str": get_data_str_from_df_for_prompt(source_df),
        "target_csv_str": get_data_str_from_df_for_prompt(template_df),
    }
    parsed: TableMapping = chain.invoke(prompt_inputs)
    return pd.DataFrame(parsed.dict()['table_mappings'])

def _sanitize_python_output(text: str):
    _, after = text.split("```python")
    return after.split("```")[0]

def generate_mapping_code(table_mapping_df) -> str:
    """Produce Python transformation code from a column-mapping table.

    This runs a two-stage LLM pipeline:

    1. The spec writer (``SPEC_WRITER_PROMPT_STR``) receives the mapping
       table as ``table_mapping`` and returns a textual spec.
    2. The engineer (``ENGINEER_PROMPT_STR``) receives only that spec, as
       ``spec_str``. Its reply is reduced to the fenced code block by
       ``_sanitize_python_output``.

    Args:
        table_mapping_df: Mapping DataFrame, e.g. as returned by
            ``get_table_mapping``. It is sent to the model as ``str(df.to_dict())``.

    Returns:
        The generated Python source as a string.
    """
    spec_chain = (
        ChatPromptTemplate.from_template(SPEC_WRITER_PROMPT_STR)
        | BASE_MODEL
        | StrOutputParser()
    )
    engineer_prompt = ChatPromptTemplate.from_template(ENGINEER_PROMPT_STR)
    # The dict step runs spec_chain on the incoming input and forwards only
    # {"spec_str": <spec>} to the engineer prompt.
    code_chain = (
        {"spec_str": spec_chain}
        | engineer_prompt
        | BASE_MODEL
        | StrOutputParser()
        | _sanitize_python_output
    )
    mapping_repr = str(table_mapping_df.to_dict())
    return code_chain.invoke({"table_mapping": mapping_repr})

def process_csv_text(temp_file):
    """Load CSV input into a DataFrame.

    Args:
        temp_file: Either the CSV content itself as a ``str``, or an object
            whose ``.name`` attribute ``pd.read_csv`` can read (presumably an
            uploaded temp file's path; confirm against the caller).

    Returns:
        The parsed DataFrame.
    """
    csv_source = io.StringIO(temp_file) if isinstance(temp_file, str) else temp_file.name
    return pd.read_csv(csv_source)

def transform_source(source_df, code_text: str):
    """Run generated Python code against ``source_df``.

    ``source_df`` is exposed to the code as the local name ``source_df``. A
    fresh locals dict is built on every call.

    SECURITY: ``code_text`` is typically LLM output (see
    ``generate_mapping_code``), and ``PythonAstREPLTool`` executes it in this
    process. Only pass code you are willing to run with this process's
    privileges.

    Returns:
        Whatever the tool's ``run`` produces. Depending on the installed
        langchain version, this is presumably the last expression's value,
        captured stdout, or an error string; confirm before relying on it.
    """
    repl_tool = PythonAstREPLTool(locals={'source_df': source_df})
    return repl_tool.run(code_text)