Spaces · Runtime error
Commit 948e91c · add missing data workflow
Parent: 60274d1
model.py CHANGED

@@ -1,6 +1,6 @@
 import os
-import json
 import argparse
+import time
 
 from dotenv import load_dotenv
 import anthropic
@@ -10,7 +10,7 @@ from utils import parse_json_garbage
 
 load_dotenv()
 
-def llm( provider, model, system_prompt, user_content):
+def llm( provider, model, system_prompt, user_content, delay:int = 10):
     """Invoke LLM service
     Argument
     --------
@@ -26,6 +26,9 @@ def llm( provider, model, system_prompt, user_content):
     ------
     response: str
     """
+    if delay:
+        time.sleep(delay)
+
     if provider=='openai':
         client = OpenAI( organization = os.getenv('ORGANIZATION_ID'))
         chat_completion = client.chat.completions.create(
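Aside: the net effect of this model.py change is a blunt client-side rate limiter. Every llm() call now sleeps delay seconds (default 10) before contacting the provider; the timing figures added to main()'s docstring in sheet.py below were measured under delay = 10. A minimal usage sketch, assuming model.py is importable and using placeholder prompt strings:

    # Sketch only: the prompts and model choice here are illustrative, not from the commit.
    from model import llm

    system_prompt = "You are a restaurant-type classifier."  # placeholder
    user_content = "店名: 測試餐廳"  # placeholder lead record

    # Default after this commit: sleep 10 s, then send the request.
    response = llm("openai", "gpt-4-0125-preview", system_prompt, user_content)

    # A falsy delay (0 or None) skips the sleep, e.g. for a one-off call.
    response = llm("openai", "gpt-4-0125-preview", system_prompt, user_content, delay=0)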
sheet.py CHANGED

@@ -165,7 +165,7 @@ def classify_results(
                 label = parse_json_garbage(pred_cls)['category']
                 labels.append(label)
             except Exception as e:
-                print(f"# CLASSIFICATION error -> evidence: {evidence}")
+                print(f"# CLASSIFICATION error: e -> {e}, user_content -> {user_content}, evidence: {evidence}")
                 labels.append("")
                 empty_indices.append(idx)
 
@@ -488,10 +488,58 @@ def split_dataframe( df: pd.DataFrame, n_processes: int = 4) -> list:
     n_per_process = math.ceil(n / n_processes)
     return [ df.iloc[i:i+n_per_process] for i in range(0, n, n_per_process)]
 
+
+def continue_missing(args):
+    """
+    """
+    data = get_leads(args.data_path)
+    n_data = data.shape[0]
+
+    formatted_results_path = os.path.join( args.output_dir, args.formatted_results_path)
+    formatted_results = pd.read_csv(formatted_results_path)
+    missing_indices = []
+    for i in range(n_data):
+        if i not in formatted_results['index'].unique():
+            print(f"{i} is not found")
+            missing_indices.append(i)
+
+    crawled_results_path = os.path.join( args.output_dir, args.crawled_file_path)
+    crawled_results = joblib.load( open( crawled_results_path, "rb"))
+    crawled_results = crawled_results['crawled_results'].query( f"index in {missing_indices}")
+    print( crawled_results)
+
+    er = extract_results( crawled_results, classes = args.classes, provider = args.provider, model = args.model)
+    er = er['extracted_results']
+    print(er['category'])
+
+    postprossed_results = postprocess_result(
+        er,
+        "/tmp/postprocessed_results.joblib",
+        category2supercategory
+    )
+
+    out_formatted_results = format_output(
+        postprossed_results,
+        input_column = 'evidence',
+        output_column = 'formatted_evidence',
+        format_func = format_evidence
+    )
+
+    out_formatted_results.to_csv( "/tmp/formatted_results.missing.csv", index=False)
+    formatted_results = pd.concat([formatted_results, out_formatted_results], ignore_index=True)
+    formatted_results.sort_values(by='index', ascending=True, inplace=True)
+    formatted_results.to_csv( "/tmp/formatted_results.csv", index=False)
+
+
 def main(args):
     """
     Argument
     args: argparse
+    Note
+    200 records
+    crawl: 585.3285548686981
+    extract: 2791.631685256958 (delay = 10)
+    classify: 2374.4915606975555 (delay = 10)
     """
     crawled_file_path = os.path.join( args.output_dir, args.crawled_file_path)
     extracted_file_path = os.path.join( args.output_dir, args.extracted_file_path)
@@ -501,11 +549,11 @@ def main(args):
     formatted_results_path = os.path.join( args.output_dir, args.formatted_results_path)
 
     ## Read the data list ##
-    data = get_leads(args.data_path)
+    data = get_leads(args.data_path)
 
     ## Run crawling and analysis ##
     crawled_results = crawl_results_mp( data, crawled_file_path, n_processes=args.n_processes)
-    crawled_results = { k:v[-5:] for k,v in crawled_results.items()}
+    # crawled_results = { k:v[-5:] for k,v in crawled_results.items()}
 
     ## Method 1: extract key information and classify ##
     extracted_results = extract_results_mp(
@@ -596,6 +644,7 @@ if __name__=='__main__':
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--data_path", type=str, default="data/餐廳類型分類.xlsx - 測試清單.csv")
+    parser.add_argument("--task", type=str, default="new", choices = ["new", "continue"], help="new or continue")
     parser.add_argument("--output_dir", type=str, help='output directory')
     parser.add_argument("--classified_file_path", type=str, default="classified_results.joblib")
     parser.add_argument("--extracted_file_path", type=str, default="extracted_results.joblib")
@@ -606,9 +655,16 @@ if __name__=='__main__':
     parser.add_argument("--classes", type=list, default=['小吃店', '日式料理(含居酒屋,串燒)', '火(鍋/爐)', '東南亞料理(不含日韓)', '海鮮熱炒', '特色餐廳(含雞、鵝、牛、羊肉)', '傳統餐廳', '燒烤', '韓式料理(含火鍋,烤肉)', '西餐廳(含美式,義式,墨式)', '西餐廳(餐酒館、酒吧、飛鏢吧、pub、lounge bar)', '西餐廳(土耳其、漢堡、薯條、法式、歐式、印度)', '早餐'])
     parser.add_argument("--backup_classes", type=list, default=['中式', '西式'])
     parser.add_argument("--strategy", type=str, default='patch', choices=['replace', 'patch'])
-    parser.add_argument("--provider", type=str, default='
-    parser.add_argument("--model", type=str, default='
+    parser.add_argument("--provider", type=str, default='openai', choices=['openai', 'anthropic'])
+    parser.add_argument("--model", type=str, default='gpt-4-0125-preview', choices=['claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview'])
     parser.add_argument("--n_processes", type=int, default=4)
     args = parser.parse_args()
 
-    main(args)
+    if args.task == 'new':
+        main(args)
+    elif args.task == 'continue':
+        continue_missing(args)
+    else:
+        raise Exception(f"Task {args.task} not implemented")
+
+
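Aside: sheet.py now supports a resume pass. Running with --task continue reloads the cached crawl results, finds lead indices missing from the formatted CSV, re-runs extraction, post-processing, and formatting for only those rows, then merges them back in index order. A sketch of the gap detection it performs (the path and record count are illustrative; building a set up front avoids the quadratic membership scan the committed loop does against formatted_results['index'].unique() on every iteration):

    import pandas as pd

    formatted_results = pd.read_csv("output/formatted_results.csv")  # assumed path
    n_data = 200  # lead count, per the timing note in main()'s docstring

    done = set(formatted_results["index"].unique())  # indices already formatted
    missing_indices = [i for i in range(n_data) if i not in done]
    print(f"{len(missing_indices)} leads to re-run: {missing_indices}")

    # Corresponding invocation of the new flag:
    #   python sheet.py --task continue --output_dir output/

Note that continue_missing() writes its merged outputs to hard-coded /tmp/ paths (/tmp/formatted_results.csv and /tmp/formatted_results.missing.csv) rather than back into --output_dir.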
utils.py CHANGED

@@ -1,9 +1,13 @@
+import re
 import json
 
 def parse_json_garbage(s):
     s = s[next(idx for idx, c in enumerate(s) if c in "{["):]
+    print(s)
+    s = s[:next(idx for idx, c in enumerate(s) if c in "}]")+1]
+    print(s)
     try:
-        return json.loads(s)
+        return json.loads(re.sub("[//#].*","",s,flags=re.MULTILINE))
     except json.JSONDecodeError as e:
-        return json.loads(s[:e.pos])
+        return json.loads(re.sub("[//#].*","",s,flags=re.MULTILINE))
 
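Aside: parse_json_garbage() previously only trimmed the prefix before the first { or [; it now also truncates at the first } or ] and strips everything from a / or # to end of line before parsing, presumably to salvage JSON that a model wraps in prose or annotates with comments. A usage sketch with made-up input:

    from utils import parse_json_garbage

    raw = 'Sure! Here is the result: {"category": "早餐"} Hope that helps.'
    print(parse_json_garbage(raw))  # returns {'category': '早餐'}

Two caveats: "[//#].*" is a regex character class, so it fires on a single / (the doubled slash is redundant), and a slash inside a string value, e.g. a URL, clips the rest of that line and breaks the parse; truncating at the first } or ] likewise breaks nested objects. Also, the except branch now retries the exact same json.loads call, so a decode error simply propagates instead of using the error-position fallback the old code appears to have had.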