linpershey committed
Commit 948e91c
1 Parent(s): 60274d1

add missing data workflow

Files changed (3):
  1. model.py +5 -2
  2. sheet.py +62 -6
  3. utils.py +6 -2
model.py CHANGED
@@ -1,6 +1,6 @@
 import os
-import json
 import argparse
+import time
 
 from dotenv import load_dotenv
 import anthropic
@@ -10,7 +10,7 @@ from utils import parse_json_garbage
 
 load_dotenv()
 
-def llm( provider, model, system_prompt, user_content):
+def llm( provider, model, system_prompt, user_content, delay:int = 10):
     """Invoke LLM service
     Argument
     --------
@@ -26,6 +26,9 @@ def llm( provider, model, system_prompt, user_content):
     ------
     response: str
     """
+    if delay:
+        time.sleep(delay)
+
     if provider=='openai':
         client = OpenAI( organization = os.getenv('ORGANIZATION_ID'))
         chat_completion = client.chat.completions.create(
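
Note: the new `delay` argument sleeps before every request, throttling calls against provider rate limits; any falsy value skips the sleep. A minimal usage sketch, assuming the API credentials in `.env` are loaded (the prompt strings below are hypothetical, not from the repo):

from model import llm

system_prompt = "You are a restaurant-type classifier."  # hypothetical prompt
user_content = "Store name: 老王牛肉麵"  # hypothetical input

# Default delay=10 sleeps ten seconds before the request fires.
answer = llm("openai", "gpt-4-0125-preview", system_prompt, user_content)

# delay=0 skips the sleep entirely, e.g. for a one-off test call.
answer = llm("openai", "gpt-4-0125-preview", system_prompt, user_content, delay=0)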
sheet.py CHANGED
@@ -165,7 +165,7 @@ def classify_results(
             label = parse_json_garbage(pred_cls)['category']
             labels.append(label)
         except Exception as e:
-            print(f"# CLASSIFICATION error -> evidence: {e}")
+            print(f"# CLASSIFICATION error: e -> {e}, user_content -> {user_content}, evidence: {evidence}")
             labels.append("")
             empty_indices.append(idx)
 
@@ -488,10 +488,58 @@ def split_dataframe( df: pd.DataFrame, n_processes: int = 4) -> list:
     n_per_process = math.ceil(n / n_processes)
     return [ df.iloc[i:i+n_per_process] for i in range(0, n, n_per_process)]
 
+
+def continue_missing(args):
+    """
+    """
+    data = get_leads(args.data_path)
+    n_data = data.shape[0]
+
+    formatted_results_path = os.path.join( args.output_dir, args.formatted_results_path)
+    formatted_results = pd.read_csv(formatted_results_path)
+    missing_indices = []
+    for i in range(n_data):
+        if i not in formatted_results['index'].unique():
+            print(f"{i} is not found")
+            missing_indices.append(i)
+
+    crawled_results_path = os.path.join( args.output_dir, args.crawled_file_path)
+    crawled_results = joblib.load( open( crawled_results_path, "rb"))
+    crawled_results = crawled_results['crawled_results'].query( f"index in {missing_indices}")
+    print( crawled_results)
+
+    er = extract_results( crawled_results, classes = args.classes, provider = args.provider, model = args.model)
+    er = er['extracted_results']
+    print(er['category'])
+
+    postprossed_results = postprocess_result(
+        er,
+        "/tmp/postprocessed_results.joblib",
+        category2supercategory
+    )
+
+    out_formatted_results = format_output(
+        postprossed_results,
+        input_column = 'evidence',
+        output_column = 'formatted_evidence',
+        format_func = format_evidence
+    )
+
+    out_formatted_results.to_csv( "/tmp/formatted_results.missing.csv", index=False)
+    formatted_results = pd.concat([formatted_results, out_formatted_results], ignore_index=True)
+    formatted_results.sort_values(by='index', ascending=True, inplace=True)
+    formatted_results.to_csv( "/tmp/formatted_results.csv", index=False)
+
+
 def main(args):
     """
     Argument
         args: argparse
+    Note
+        200 records
+        crawl: 585.3285548686981
+        extract: 2791.631685256958(delay = 10)
+        classify: 2374.4915606975555(delay = 10)
     """
     crawled_file_path = os.path.join( args.output_dir, args.crawled_file_path)
     extracted_file_path = os.path.join( args.output_dir, args.extracted_file_path)
@@ -501,11 +549,11 @@ def main(args):
     formatted_results_path = os.path.join( args.output_dir, args.formatted_results_path)
 
     ## 讀取資料名單 ##
-    data = get_leads(args.data_path).tail(5)
+    data = get_leads(args.data_path)
 
     ## 進行爬蟲與分析 ##
     crawled_results = crawl_results_mp( data, crawled_file_path, n_processes=args.n_processes)
-    crawled_results = { k:v[-5:] for k,v in crawled_results.items()}
+    # crawled_results = { k:v[-5:] for k,v in crawled_results.items()}
 
     ## 方法 1: 擷取關鍵資訊與分類 ##
     extracted_results = extract_results_mp(
@@ -596,6 +644,7 @@ if __name__=='__main__':
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--data_path", type=str, default="data/餐廳類型分類.xlsx - 測試清單.csv")
+    parser.add_argument("--task", type=str, default="new", choices = ["new", "continue"], help="new or continue")
     parser.add_argument("--output_dir", type=str, help='output directory')
     parser.add_argument("--classified_file_path", type=str, default="classified_results.joblib")
     parser.add_argument("--extracted_file_path", type=str, default="extracted_results.joblib")
@@ -606,9 +655,16 @@ if __name__=='__main__':
     parser.add_argument("--classes", type=list, default=['小吃店', '日式料理(含居酒屋,串燒)', '火(鍋/爐)', '東南亞料理(不含日韓)', '海鮮熱炒', '特色餐廳(含雞、鵝、牛、羊肉)', '傳統餐廳', '燒烤', '韓式料理(含火鍋,烤肉)', '西餐廳(含美式,義式,墨式)', '西餐廳(餐酒館、酒吧、飛鏢吧、pub、lounge bar)', '西餐廳(土耳其、漢堡、薯條、法式、歐式、印度)', '早餐'])
     parser.add_argument("--backup_classes", type=list, default=['中式', '西式'])
     parser.add_argument("--strategy", type=str, default='patch', choices=['replace', 'patch'])
-    parser.add_argument("--provider", type=str, default='anthropic', choices=['openai', 'anthropic'])
-    parser.add_argument("--model", type=str, default='claude-3-sonnet-20240229', choices=['claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview'])
+    parser.add_argument("--provider", type=str, default='openai', choices=['openai', 'anthropic'])
+    parser.add_argument("--model", type=str, default='gpt-4-0125-preview', choices=['claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview'])
     parser.add_argument("--n_processes", type=int, default=4)
     args = parser.parse_args()
 
-    main(args)
+    if args.task == 'new':
+        main(args)
+    elif args.task == 'continue':
+        continue_missing(args)
+    else:
+        raise Exception(f"Task {args.task} not implemented")
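
Note: the new `continue_missing` task backfills rows that never reached `formatted_results.csv`. It compares the lead list against the saved `index` column, re-crawls nothing, re-runs extraction for just the missing rows, and merges the patch back into a sorted CSV. A minimal sketch of the index-diff step, using hypothetical stand-in data in place of the real CSV:

import pandas as pd

n_data = 6  # pretend the lead list has six rows
formatted_results = pd.DataFrame({"index": [0, 1, 3, 5]})  # rows that finished

# Same membership test as in the diff; building a set first makes each lookup O(1)
# instead of scanning the unique() array once per row.
seen = set(formatted_results["index"].unique())
missing_indices = [i for i in range(n_data) if i not in seen]
print(missing_indices)  # -> [2, 4]

With the new flag, the backfill would be invoked as something like `python sheet.py --task continue --output_dir <dir>` (flag names taken from the diff; the directory is whatever the original run wrote to).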
utils.py CHANGED
@@ -1,9 +1,13 @@
+import re
 import json
 
 def parse_json_garbage(s):
     s = s[next(idx for idx, c in enumerate(s) if c in "{["):]
+    print(s)
+    s = s[:next(idx for idx, c in enumerate(s) if c in "}]")+1]
+    print(s)
     try:
-        return json.loads(s)
+        return json.loads(re.sub("[//#].*","",s,flags=re.MULTILINE))
     except json.JSONDecodeError as e:
-        return json.loads(s[:e.pos])
+        return json.loads(re.sub("[//#].*","",s,flags=re.MULTILINE))
 
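
Note: `parse_json_garbage` now also trims the string at the first closing `}` or `]` and strips `//` and `#` comment tails before parsing, which copes with LLM replies that wrap JSON in prose or trailing remarks. Two caveats visible in the diff: the `except` branch retries the exact same `json.loads` call, so a truly unparseable string still raises; and the first-closing-brace trim plus the `[//#]` character class assume a flat single-object payload with no `/` or `#` inside values (a value like 火(鍋/爐) would be truncated). A minimal sketch of the happy path, on a hypothetical wrapped response:

from utils import parse_json_garbage

raw = 'Sure! Here is the result: {"category": "小吃店"}  # my pick'
result = parse_json_garbage(raw)
print(result)  # -> {'category': '小吃店'} (the function also prints its trimming steps)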