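"""Entry point for the lead-classification pipeline.

Crawls search-engine results for each lead, extracts key information with an
LLM, regularizes the predicted categories, and post-processes everything into
a CSV. Run with --task new for a fresh run, or --task continue to re-process
only the leads missing from a previous run's output.
"""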
import os
import sys
import time
import json
import joblib
import math
import itertools
import argparse
import multiprocessing as mp
from typing import List
from pathlib import Path

import yaml
import jinja2
import requests
import pandas as pd
from dotenv import load_dotenv
from serpapi import GoogleSearch
import tiktoken
from openai import OpenAI
from tqdm import tqdm
from loguru import logger

from model import llm
from data import get_leads, format_search_results
from utils import (parse_json_garbage, split_dataframe, merge_results,
                   combine_results, split_dict, format_df,
                   clean_quotes, compose_query, reverse_category2supercategory)
from batch import postprocess_result
from pipeline import (get_serp, get_condensed_result, get_organic_result, get_googlemap_results,
                      crawl_results, crawl_results_mp,
                      compose_extraction, extract_results, extract_results_mp,
                      compose_classification, classify_results, classify_results_mp,
                      compose_regularization, regularize_results, regularize_results_mp,
                      compose_filter, filter_results, filter_results_mp)

load_dotenv()
ORGANIZATION_ID = os.getenv('OPENAI_ORGANIZATION_ID')
SERP_API_KEY = os.getenv('SERP_APIKEY')
SERPER_API_KEY = os.getenv('SERPER_API_KEY')
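# Expected .env entries (key names taken from the os.getenv calls above;
# values are placeholders):
#   OPENAI_ORGANIZATION_ID=org-xxxxxxxx
#   SERP_APIKEY=your-serpapi-key
#   SERPER_API_KEY=your-serper-key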
def continue_missing(args):
    """Re-run the pipeline for leads absent from a previous run's output.

    Compares lead indices in the formatted results CSV against the full lead
    list, writes the missing rows to args.missing_data_path, then calls main()
    again with data/output paths pointed at the missing set. Note that
    formatted_results_path must have been produced by an earlier run (its
    writer is currently commented out in main()).
    """
    data = get_leads(args.data_path)
    n_data = data.shape[0]
    formatted_results = pd.read_csv(os.path.join(args.output_dir, args.formatted_results_path))
    processed_indices = set(formatted_results['index'].unique())
    missing_indices = []
    for i in range(n_data):
        if i not in processed_indices:
            logger.debug(f"{i} is not found")
            missing_indices.append(i)
    if len(missing_indices) == 0:
        logger.debug("No missing data")
        return
    missing_data = data.loc[missing_indices]
    if not os.path.exists(args.output_missing_dir):
        os.makedirs(args.output_missing_dir)
    Path(args.missing_data_path).parent.mkdir(parents=True, exist_ok=True)
    missing_data.to_csv(args.missing_data_path, index=False, header=False)
    args.data_path = args.missing_data_path
    args.output_dir = args.output_missing_dir
    if missing_data.shape[0] < args.n_processes:
        args.n_processes = 1
    main(args)
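# Example continuation run (hypothetical file name and paths); this re-crawls
# only the leads absent from the previous run's formatted results:
#   python main.py --task continue --data_path data/leads.csv \
#       --output_dir output/ --output_missing_dir output/missing/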
def main(args):
    """Run the crawl -> extract -> regularize -> postprocess pipeline.

    Argument
        args: argparse.Namespace
    Note
        Rough timings for 200 records (in seconds):
            crawl: 585.33
            extract: 2791.63 (delay = 10)
            classify: 2374.49 (delay = 10)
    """
    steps = args.steps
    crawled_file_path = os.path.join(args.output_dir, args.crawled_file_path) if args.crawled_file_path is not None else None
    extracted_file_path = os.path.join(args.output_dir, args.extracted_file_path) if args.extracted_file_path is not None else None
    # classified_file_path = os.path.join(args.output_dir, args.classified_file_path)
    # combined_file_path = os.path.join(args.output_dir, args.combined_file_path)
    postprocessed_file_path = os.path.join(args.output_dir, args.postprocessed_file_path) if args.postprocessed_file_path is not None else None
    # formatted_results_path = os.path.join(args.output_dir, args.formatted_results_path)
    filtered_file_path = os.path.join(args.output_dir, args.filtered_file_path) if args.filtered_file_path is not None else None
    regularized_file_path = os.path.join(args.output_dir, args.regularized_file_path) if args.regularized_file_path is not None else None
    ## Load the lead list ##
    data = get_leads(args.data_path)
    ## Crawl and analyze ##
    # Each step below is gated on its own name (the original `else: sys.exit(0)`
    # branches made every step other than 'all' and 'crawl' unreachable).
    if steps in ('all', 'crawl'):
        Path(crawled_file_path).parent.mkdir(parents=True, exist_ok=True)
        crawled_results = crawl_results_mp(
            data,
            crawled_file_path,
            serp_provider=args.serp_provider,
            n_processes=args.n_processes
        )
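    # The crawl step is assumed to persist a joblib dict keyed 'crawled_results'
    # (a DataFrame) at crawled_file_path; each later step reloads its input from
    # disk, so the steps can also be run as separate invocations.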
    # crawled_results = {k: v[-5:] for k, v in crawled_results.items()}
    # crawled_results['crawled_results'].to_csv(formatted_results_path, index=False)
    ## Filter the crawled results (currently disabled) ##
    # filtered_results = filter_results_mp(
    #     data = crawled_results['crawled_results'],
    #     filtered_file_path = filtered_file_path,
    #     provider = args.filter_provider,
    #     model = args.filter_model,
    #     n_processes = args.n_processes
    # )
    # sys.exit(0)
    ## Method 1: extract key information, then classify ##
    if steps in ('all', 'extract'):
        assert os.path.exists(crawled_file_path), f"# CRAWLED file not found: {crawled_file_path}"
        crawled_results = joblib.load(crawled_file_path)
        extracted_results = extract_results_mp(
            crawled_results = crawled_results['crawled_results'],  # or filtered_results['filtered_results'] when the filter step is enabled
            extracted_file_path = extracted_file_path,
            classes = args.classes,
            provider = args.extraction_provider,
            model = args.extraction_model,
            n_processes = args.n_processes
        )
    ## Method 2: classify the crawled results directly (currently disabled) ##
    # classified_results = classify_results_mp(
    #     extracted_results['extracted_results'],
    #     classified_file_path,
    #     classes = args.classes,
    #     backup_classes = args.backup_classes,
    #     provider = args.provider,
    #     model = args.model,
    #     n_processes = args.n_processes
    # )
    ## Merge the analysis results (currently disabled) ##
    # combined_results = combine_results(
    #     classified_results['classified_results'],
    #     combined_file_path,
    #     src_column = 'classified_category',
    #     tgt_column = 'category',
    #     strategy = args.strategy
    # )
    ## Regularize the classification results ##
    if steps in ('all', 'regularize'):
        # Check the joined path, not args.extracted_file_path (a bare file name)
        assert os.path.exists(extracted_file_path), f"# EXTRACTED result file not found: {extracted_file_path}"
        extracted_results = joblib.load(extracted_file_path)
        # Named regularized_results to avoid shadowing the imported regularize_results function
        regularized_results = regularize_results_mp(
            extracted_results['extracted_results'],
            regularized_file_path,
            provider = args.regularization_provider,
            model = args.regularization_model
        )
    ## Post-process the analysis results ##
    if steps in ('all', 'postprocess'):
        assert os.path.exists(regularized_file_path), f"# REGULARIZED result file not found: {regularized_file_path}"
        regularized_results = joblib.load(regularized_file_path)
        postprocessed_results = postprocess_result(
            regularized_results['regularized_results'],  # or combined_results when Method 2 is enabled
            postprocessed_file_path,
            category2supercategory
        )
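# Example step-by-step runs (hypothetical file name and paths):
#   python main.py --steps crawl --data_path data/leads.csv --output_dir output/
#   python main.py --steps extract --output_dir output/
#   python main.py --steps regularize --output_dir output/
#   python main.py --steps postprocess --output_dir output/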
if __name__ == '__main__':
    # SERP search defaults and OpenAI client (not referenced elsewhere in this file)
    base = "https://serpapi.com/search.json"
    engine = 'google'
    google_domain = 'google.com.tw'
    gl = 'tw'
    lr = 'lang_zh-TW'
    n_processes = 4
    client = OpenAI(organization=ORGANIZATION_ID)
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default='config/config.yml', help="Path to the configuration file")
    parser.add_argument("--data_path", type=str, default="data/餐廳類型分類.xlsx - 測試清單.csv")
    parser.add_argument("--missing_data_path", type=str, default="data/missing/missing.csv")
    parser.add_argument("--task", type=str, default="new", choices=["new", "continue"], help="Start a new run, or continue a previous one")
    parser.add_argument("--steps", type=str, default="all", choices=["all", "crawl", "extract", "regularize", "postprocess"], help="Which pipeline step to run (or 'all')")
    parser.add_argument("--output_dir", type=str, help='Output directory')
    parser.add_argument("--output_missing_dir", type=str, help='Output directory for the continuation run')
    parser.add_argument("--classified_file_path", type=str, default="classified_results.joblib")
    parser.add_argument("--extracted_file_path", type=str, default="extracted_results.joblib")
    parser.add_argument("--crawled_file_path", type=str, default="crawled_results.joblib")
    parser.add_argument("--combined_file_path", type=str, default="combined_results.joblib")
    parser.add_argument("--regularized_file_path", type=str, default="regularized_results.joblib")
    parser.add_argument("--postprocessed_file_path", type=str, default="postprocessed_results.csv")
    parser.add_argument("--formatted_results_path", type=str, default="formatted_results.csv")
    parser.add_argument("--filtered_file_path", type=str, default="filtered_results.csv")
    # type=list would split a single argument into characters; nargs='+' accepts space-separated class names instead
    parser.add_argument("--classes", type=str, nargs='+', default=['小吃店','日式料理(含居酒屋,串燒)','火(鍋/爐)','東南亞料理(不含日韓)','海鮮熱炒','特色餐廳(含雞、鵝、牛、羊肉)','釣蝦場','傳統餐廳','燒烤','韓式料理(含火鍋,烤肉)','PUB(Live Band)','PUB(一般,含Lounge)','PUB(電音\\舞場)','五星級飯店','自助KTV(含連鎖,庭園自助)','西餐廳(含美式,義式,墨式)','咖啡廳(泡沫紅茶)','飯店(星級/旅館,不含五星級)','運動休閒館(含球類練習場,飛鏢等)','西餐廳(餐酒館、酒吧、飛鏢吧、pub、lounge bar)','西餐廳(土耳其、漢堡、薯條、法式、歐式、印度)','早餐'])
    parser.add_argument("--backup_classes", type=str, nargs='+', default=['中式', '西式'])
    parser.add_argument("--strategy", type=str, default='patch', choices=['replace', 'patch'])
    parser.add_argument("--filter_provider", type=str, default='google', choices=['google', 'openai', 'anthropic'])
    parser.add_argument("--filter_model", type=str, default='gemini-1.5-flash', choices=['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview', 'gpt-4o', 'gpt-4o-mini', 'gemini-1.5-flash'])
    parser.add_argument("--extraction_provider", type=str, default='openai', choices=['google', 'openai', 'anthropic'])
    parser.add_argument("--extraction_model", type=str, default='gpt-3.5-turbo-0125', choices=['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview', 'gpt-4o', 'gpt-4o-mini', 'gemini-1.5-flash'])
    parser.add_argument("--regularization_provider", type=str, default='google', choices=['google', 'openai', 'anthropic'])
    parser.add_argument("--regularization_model", type=str, default='gemini-1.5-flash', choices=['claude-3-5-sonnet-20240620', 'claude-3-sonnet-20240229', 'claude-3-haiku-20240307', 'gpt-3.5-turbo-0125', 'gpt-4-0125-preview', 'gpt-4o', 'gpt-4o-mini', 'gemini-1.5-flash'])
    parser.add_argument("--serp_provider", type=str, default='serp', choices=['serp', 'serper'])
    parser.add_argument("--n_processes", type=int, default=4)
    args = parser.parse_args()
    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    category2supercategory = config['category2supercategory']
    supercategory2category = reverse_category2supercategory(category2supercategory)
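    # config/config.yml is assumed to map each category to a supercategory,
    # e.g. (illustrative values only):
    #   category2supercategory:
    #     小吃店: 中式
    #     西餐廳(含美式,義式,墨式): 西式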
    if args.task == 'new':
        main(args)
    elif args.task == 'continue':
        continue_missing(args)
    else:
        raise Exception(f"Task {args.task} not implemented")