Ki-Seki commited on
Commit
f5cc66a
β€’
1 Parent(s): 24fa1b6

refactor: optimize the structure

Browse files
Files changed (1) hide show
  1. autotab.py +32 -34
autotab.py CHANGED
@@ -32,7 +32,14 @@ class AutoTab:
32
  self.save_every = save_every
33
  self.api_keys = api_keys
34
  self.base_url = base_url
 
35
  self.request_count = 0
 
 
 
 
 
 
36
 
37
  # ─── IO ───────────────────────────────────────────────────────────────
38
 
@@ -67,61 +74,60 @@ class AutoTab:
67
 
68
  # ─── In-Context Learning ──────────────────────────────────────────────
69
 
70
- def derive_incontext(
71
- self, data: pd.DataFrame, input_columns: list[str], output_columns: list[str]
72
- ) -> str:
73
  """Derive the in-context prompt with angle brackets."""
74
- n = min(self.max_examples, len(data.dropna(subset=output_columns)))
75
  in_context = ""
76
- for i in range(n):
77
  in_context += "".join(
78
- f"<{col.replace('[Input] ', '')}>{data[col].iloc[i]}</{col.replace('[Input] ', '')}>\n"
79
- for col in input_columns
80
  )
81
  in_context += "".join(
82
- f"<{col.replace('[Output] ', '')}>{data[col].iloc[i]}</{col.replace('[Output] ', '')}>\n"
83
- for col in output_columns
84
  )
85
  in_context += "\n"
86
  return in_context
87
 
88
- def predict_output(
89
- self, in_context: str, input_data: pd.DataFrame, input_fields: str
90
- ):
91
  """Predict the output values for the given input data using the API."""
92
  query = (
93
  self.instruction
94
  + "\n\n"
95
- + in_context
96
  + "".join(
97
  f"<{col.replace('[Input] ', '')}>{input_data[col]}</{col.replace('[Input] ', '')}>\n"
98
- for col in input_fields
99
  )
100
  )
101
  self.query_example = query
102
  output = self.openai_request(query)
103
  return output
104
 
105
- def extract_fields(
106
- self, response: str, output_columns: list[str]
107
- ) -> dict[str, str]:
108
  """Extract fields from the response text based on output columns."""
109
  extracted = {}
110
- for col in output_columns:
111
  field = col.replace("[Output] ", "")
112
  match = re.search(f"<{field}>(.*?)</{field}>", response)
113
  extracted[col] = match.group(1) if match else ""
 
 
114
  return extracted
115
 
116
  # ─── Engine ───────────────────────────────────────────────────────────
117
 
118
- def _predict_and_extract(self, i: int) -> dict[str, str]:
119
  """Helper function to predict and extract fields for a single row."""
120
- prediction = self.predict_output(
121
- self.in_context, self.data.iloc[i], self.input_fields
122
- )
123
- extracted_fields = self.extract_fields(prediction, self.output_fields)
124
- return extracted_fields
 
 
 
125
 
126
  def batch_prediction(self, start_index: int, end_index: int):
127
  """Process a batch of predictions asynchronously."""
@@ -134,16 +140,8 @@ class AutoTab:
134
  self.data.at[i, field_name] = extracted_fields.get(field_name, "")
135
 
136
  def run(self):
137
- self.data, self.input_fields, self.output_fields = self.load_excel()
138
- self.in_context = self.derive_incontext(
139
- self.data, self.input_fields, self.output_fields
140
- )
141
-
142
- self.num_data = len(self.data)
143
- self.num_examples = len(self.data.dropna(subset=self.output_fields))
144
-
145
- tqdm_bar = tqdm(total=self.num_data - self.num_examples, leave=False)
146
- for start in range(self.num_examples, self.num_data, self.save_every):
147
  tqdm_bar.update(min(self.save_every, self.num_data - start))
148
  end = min(start + self.save_every, self.num_data)
149
  try:
 
32
  self.save_every = save_every
33
  self.api_keys = api_keys
34
  self.base_url = base_url
35
+
36
  self.request_count = 0
37
+ self.failed_count = 0
38
+ self.data, self.input_fields, self.output_fields = self.load_excel()
39
+ self.in_context = self.derive_incontext()
40
+ self.num_data = len(self.data)
41
+ self.num_example = len(self.data.dropna(subset=self.output_fields))
42
+ self.num_missing = self.num_data - self.num_example
43
 
44
  # ─── IO ───────────────────────────────────────────────────────────────
45
 
 
74
 
75
  # ─── In-Context Learning ──────────────────────────────────────────────
76
 
77
+ def derive_incontext(self) -> str:
 
 
78
  """Derive the in-context prompt with angle brackets."""
79
+ examples = self.data.dropna(subset=self.output_fields)[: self.max_examples]
80
  in_context = ""
81
+ for i in range(len(examples)):
82
  in_context += "".join(
83
+ f"<{col.replace('[Input] ', '')}>{self.data[col].iloc[i]}</{col.replace('[Input] ', '')}>\n"
84
+ for col in self.input_fields
85
  )
86
  in_context += "".join(
87
+ f"<{col.replace('[Output] ', '')}>{self.data[col].iloc[i]}</{col.replace('[Output] ', '')}>\n"
88
+ for col in self.output_fields
89
  )
90
  in_context += "\n"
91
  return in_context
92
 
93
+ def predict_output(self, input_data: pd.DataFrame):
 
 
94
  """Predict the output values for the given input data using the API."""
95
  query = (
96
  self.instruction
97
  + "\n\n"
98
+ + self.in_context
99
  + "".join(
100
  f"<{col.replace('[Input] ', '')}>{input_data[col]}</{col.replace('[Input] ', '')}>\n"
101
+ for col in self.input_fields
102
  )
103
  )
104
  self.query_example = query
105
  output = self.openai_request(query)
106
  return output
107
 
108
+ def extract_fields(self, response: str) -> dict[str, str]:
 
 
109
  """Extract fields from the response text based on output columns."""
110
  extracted = {}
111
+ for col in self.output_fields:
112
  field = col.replace("[Output] ", "")
113
  match = re.search(f"<{field}>(.*?)</{field}>", response)
114
  extracted[col] = match.group(1) if match else ""
115
+ if any(extracted[col] == "" for col in self.output_fields):
116
+ self.failed_count += 1
117
  return extracted
118
 
119
  # ─── Engine ───────────────────────────────────────────────────────────
120
 
121
+ def _predict_and_extract(self, row: int) -> dict[str, str]:
122
  """Helper function to predict and extract fields for a single row."""
123
+
124
+ # If any output field is empty, predict the output
125
+ if any(pd.isnull(self.data.at[row, col]) for col in self.output_fields):
126
+ prediction = self.predict_output(self.data.iloc[row])
127
+ extracted_fields = self.extract_fields(prediction)
128
+ return extracted_fields
129
+ else:
130
+ return {col: self.data.at[row, col] for col in self.output_fields}
131
 
132
  def batch_prediction(self, start_index: int, end_index: int):
133
  """Process a batch of predictions asynchronously."""
 
140
  self.data.at[i, field_name] = extracted_fields.get(field_name, "")
141
 
142
  def run(self):
143
+ tqdm_bar = tqdm(total=self.num_data, leave=False)
144
+ for start in range(0, self.num_data, self.save_every):
 
 
 
 
 
 
 
 
145
  tqdm_bar.update(min(self.save_every, self.num_data - start))
146
  end = min(start + self.save_every, self.num_data)
147
  try: