refactor: optimize the structure
Browse files- autotab.py +32 -34
autotab.py
CHANGED
@@ -32,7 +32,14 @@ class AutoTab:
|
|
32 |
self.save_every = save_every
|
33 |
self.api_keys = api_keys
|
34 |
self.base_url = base_url
|
|
|
35 |
self.request_count = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
# ─── IO ───────────────────────────────────────────────────────────────
|
38 |
|
@@ -67,61 +74,60 @@ class AutoTab:
|
|
67 |
|
68 |
# ─── In-Context Learning ──────────────────────────────────────────────
|
69 |
|
70 |
-
def derive_incontext(
|
71 |
-
self, data: pd.DataFrame, input_columns: list[str], output_columns: list[str]
|
72 |
-
) -> str:
|
73 |
"""Derive the in-context prompt with angle brackets."""
|
74 |
-
|
75 |
in_context = ""
|
76 |
-
for i in range(
|
77 |
in_context += "".join(
|
78 |
-
f"<{col.replace('[Input] ', '')}>{data[col].iloc[i]}</{col.replace('[Input] ', '')}>\n"
|
79 |
-
for col in
|
80 |
)
|
81 |
in_context += "".join(
|
82 |
-
f"<{col.replace('[Output] ', '')}>{data[col].iloc[i]}</{col.replace('[Output] ', '')}>\n"
|
83 |
-
for col in
|
84 |
)
|
85 |
in_context += "\n"
|
86 |
return in_context
|
87 |
|
88 |
-
def predict_output(
|
89 |
-
self, in_context: str, input_data: pd.DataFrame, input_fields: str
|
90 |
-
):
|
91 |
"""Predict the output values for the given input data using the API."""
|
92 |
query = (
|
93 |
self.instruction
|
94 |
+ "\n\n"
|
95 |
-
+ in_context
|
96 |
+ "".join(
|
97 |
f"<{col.replace('[Input] ', '')}>{input_data[col]}</{col.replace('[Input] ', '')}>\n"
|
98 |
-
for col in input_fields
|
99 |
)
|
100 |
)
|
101 |
self.query_example = query
|
102 |
output = self.openai_request(query)
|
103 |
return output
|
104 |
|
105 |
-
def extract_fields(
|
106 |
-
self, response: str, output_columns: list[str]
|
107 |
-
) -> dict[str, str]:
|
108 |
"""Extract fields from the response text based on output columns."""
|
109 |
extracted = {}
|
110 |
-
for col in
|
111 |
field = col.replace("[Output] ", "")
|
112 |
match = re.search(f"<{field}>(.*?)</{field}>", response)
|
113 |
extracted[col] = match.group(1) if match else ""
|
|
|
|
|
114 |
return extracted
|
115 |
|
116 |
# ─── Engine ───────────────────────────────────────────────────────────
|
117 |
|
118 |
-
def _predict_and_extract(self,
|
119 |
"""Helper function to predict and extract fields for a single row."""
|
120 |
-
|
121 |
-
|
122 |
-
)
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
125 |
|
126 |
def batch_prediction(self, start_index: int, end_index: int):
|
127 |
"""Process a batch of predictions asynchronously."""
|
@@ -134,16 +140,8 @@ class AutoTab:
|
|
134 |
self.data.at[i, field_name] = extracted_fields.get(field_name, "")
|
135 |
|
136 |
def run(self):
|
137 |
-
|
138 |
-
self.
|
139 |
-
self.data, self.input_fields, self.output_fields
|
140 |
-
)
|
141 |
-
|
142 |
-
self.num_data = len(self.data)
|
143 |
-
self.num_examples = len(self.data.dropna(subset=self.output_fields))
|
144 |
-
|
145 |
-
tqdm_bar = tqdm(total=self.num_data - self.num_examples, leave=False)
|
146 |
-
for start in range(self.num_examples, self.num_data, self.save_every):
|
147 |
tqdm_bar.update(min(self.save_every, self.num_data - start))
|
148 |
end = min(start + self.save_every, self.num_data)
|
149 |
try:
|
|
|
32 |
self.save_every = save_every
|
33 |
self.api_keys = api_keys
|
34 |
self.base_url = base_url
|
35 |
+
|
36 |
self.request_count = 0
|
37 |
+
self.failed_count = 0
|
38 |
+
self.data, self.input_fields, self.output_fields = self.load_excel()
|
39 |
+
self.in_context = self.derive_incontext()
|
40 |
+
self.num_data = len(self.data)
|
41 |
+
self.num_example = len(self.data.dropna(subset=self.output_fields))
|
42 |
+
self.num_missing = self.num_data - self.num_example
|
43 |
|
44 |
# ─── IO ───────────────────────────────────────────────────────────────
|
45 |
|
|
|
74 |
|
75 |
# ─── In-Context Learning ──────────────────────────────────────────────
|
76 |
|
77 |
+
def derive_incontext(self) -> str:
    """Build the in-context prompt from the rows that already have outputs.

    Rows of ``self.data`` with no missing value in any of
    ``self.output_fields`` are used as examples, capped at
    ``self.max_examples``. Each example is rendered as one
    ``<field>value</field>`` line per input column followed by one per
    output column, and every example is terminated by a blank line.

    Returns:
        The concatenated example text; an empty string when no row is
        complete.
    """
    examples = self.data.dropna(subset=self.output_fields)[: self.max_examples]

    # (column name, tag name) pairs; the tag is the column name with its
    # "[Input] " / "[Output] " marker removed.
    tagged_columns = [
        (col, col.replace("[Input] ", "")) for col in self.input_fields
    ] + [(col, col.replace("[Output] ", "")) for col in self.output_fields]

    parts = []
    for i in range(len(examples)):
        for col, tag in tagged_columns:
            # Read from the filtered frame, not self.data: position i of
            # self.data is a different (possibly incomplete) row whenever
            # an incomplete row precedes a complete one.
            parts.append(f"<{tag}>{examples[col].iloc[i]}</{tag}>\n")
        parts.append("\n")
    return "".join(parts)
|
92 |
|
93 |
+
def predict_output(self, input_data: pd.Series):
    """Query the API for the output values of one input row.

    The query is ``self.instruction``, a blank line, ``self.in_context``,
    and then one ``<field>value</field>`` line per column listed in
    ``self.input_fields``. The query text is also kept on
    ``self.query_example``.

    Args:
        input_data: A single row, indexable by column name (for example
            the ``pd.Series`` produced by ``DataFrame.iloc[row]``).

    Returns:
        Whatever ``self.openai_request`` returns for the query
        (presumably the response text -- confirm against
        ``openai_request``).
    """
    input_lines = []
    for col in self.input_fields:
        # Tag name is the column name without its "[Input] " marker.
        tag = col.replace("[Input] ", "")
        input_lines.append(f"<{tag}>{input_data[col]}</{tag}>\n")

    query = self.instruction + "\n\n" + self.in_context + "".join(input_lines)
    self.query_example = query
    return self.openai_request(query)
|
107 |
|
108 |
+
def extract_fields(self, response: str) -> dict[str, str]:
    """Pull each output field's value out of the tagged response text.

    For every column in ``self.output_fields`` the ``[Output] `` marker is
    removed to get the tag name, and the text between the first
    ``<tag>...</tag>`` pair found in ``response`` becomes that column's
    value. Columns whose tag is not found map to ``""``; if any column
    ends up empty, ``self.failed_count`` is incremented once.

    Args:
        response: Response text containing ``<tag>value</tag>`` pairs.

    Returns:
        A dict keyed by the full output column name.
    """
    extracted = {}
    for col in self.output_fields:
        # Escape the name so characters such as "(", "$" or "?" in a
        # column header are matched literally instead of as regex syntax.
        tag = re.escape(col.replace("[Output] ", ""))
        # DOTALL lets a value that spans several lines still be captured.
        match = re.search(f"<{tag}>(.*?)</{tag}>", response, re.DOTALL)
        extracted[col] = match.group(1) if match else ""

    if any(value == "" for value in extracted.values()):
        self.failed_count += 1
    return extracted
|
118 |
|
119 |
# ─── Engine ───────────────────────────────────────────────────────────
|
120 |
|
121 |
+
def _predict_and_extract(self, row: int) -> dict[str, str]:
    """Return the output fields for one row, predicting them only if needed.

    Args:
        row: Index label of the row in ``self.data``.

    Returns:
        A dict keyed by output column name. When every output field of
        the row is already filled, the existing values are returned
        unchanged; otherwise the row is sent to ``self.predict_output``
        and the result of ``self.extract_fields`` on that prediction is
        returned.
    """
    already_filled = not any(
        pd.isnull(self.data.at[row, col]) for col in self.output_fields
    )
    if already_filled:
        return {col: self.data.at[row, col] for col in self.output_fields}

    # Look the row up by label so it is the same row the .at[] checks
    # above inspected; a positional lookup selects a different row (or
    # raises) whenever the index is not 0..n-1.
    prediction = self.predict_output(self.data.loc[row])
    return self.extract_fields(prediction)
|
131 |
|
132 |
def batch_prediction(self, start_index: int, end_index: int):
|
133 |
"""Process a batch of predictions asynchronously."""
|
|
|
140 |
self.data.at[i, field_name] = extracted_fields.get(field_name, "")
|
141 |
|
142 |
def run(self):
|
143 |
+
tqdm_bar = tqdm(total=self.num_data, leave=False)
|
144 |
+
for start in range(0, self.num_data, self.save_every):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
tqdm_bar.update(min(self.save_every, self.num_data - start))
|
146 |
end = min(start + self.save_every, self.num_data)
|
147 |
try:
|