import argparse
import json
import os
import sys
import time

import anthropic
import google.generativeai as genai
import jinja2
import joblib
import pandas as pd
import yaml
from dotenv import load_dotenv
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from loguru import logger
from openai import OpenAI
from tqdm import tqdm

from utils import parse_json_garbage, compose_query

# Replace loguru's default handler (id 0) with an INFO-level stderr handler.
# logger.remove(0) raises ValueError when that handler is already gone
# (e.g. the module is imported twice), in which case we keep what is there.
try:
    logger.remove(0)
    logger.add(sys.stderr, level="INFO")
except ValueError:
    pass

load_dotenv()


def llm(
    provider,
    model,
    system_prompt,
    user_content,
    delay: int = 0,
):
    """Invoke LLM service

    Argument
    --------
    provider: str
        openai, anthropic or google
    model: str
        Model name for the API
    system_prompt: str
        System prompt for the API
    user_content: str
        User prompt for the API
    delay: int
        Seconds to sleep before sending the request (0 disables the wait)

    Return
    ------
    response: str
        Text of the model reply. For the google provider, when the reply
        text cannot be read, the error is logged and the raw response
        object is returned instead of a string.

    Raise
    -----
    ValueError
        If `provider` is not one of the supported providers.
    """
    if delay:
        time.sleep(delay)

    if provider == 'openai':
        client = OpenAI(organization=os.getenv('ORGANIZATION_ID'))
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "system",
                    "content": system_prompt,
                },
                {
                    "role": "user",
                    "content": user_content,
                },
            ],
            model=model,
            response_format={"type": "json_object"},
            temperature=0,
            max_tokens=4096,
        )
        return chat_completion.choices[0].message.content

    if provider == 'anthropic':
        client = anthropic.Client(api_key=os.getenv('ANTHROPIC_API_KEY'))
        message = client.messages.create(
            model=model,
            system=system_prompt,
            messages=[
                {"role": "user", "content": user_content},  # <-- user prompt
            ],
            max_tokens=4000,
        )
        return message.content[0].text

    if provider == 'google':
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        # Use a separate local name so the `model` argument (the model name)
        # is not overwritten by the client object.
        gemini = genai.GenerativeModel(
            model_name=model,
            system_instruction=system_prompt,
            generation_config={
                "temperature": 0,
                "max_output_tokens": 8192,
                "response_mime_type": "application/json",
            },
        )
        safety_settings = {
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        }
        # NOTE: the system prompt is supplied through `system_instruction`
        # above, so the conversation only carries the user turn.
        messages = [
            {
                'role': 'user',
                'parts': [user_content],
            },
        ]

        # The request is made outside the try block: if it fails there is no
        # response object to inspect or return, so the error must propagate
        # instead of being masked by a failure inside the handler.
        response = gemini.generate_content(
            messages,
            safety_settings=safety_settings,
        )
        try:
            return response.text
        except Exception as e:
            # Best effort: the text could not be read, so log the candidates
            # and hand the raw response back to the caller.
            logger.error(f"Error (will still return response) -> {e}")
            logger.error(f"response.candidates -> {response.candidates}")
            return response

    raise ValueError("Invalid provider")


def main():
    """Parse the command line, load the YAML config and run the chosen task."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, default='config/config.yml',
        help="Path to the configuration file")
    # argparse does not validate a default against `choices`, so the task is
    # required rather than defaulting to a value that can never be accepted.
    parser.add_argument(
        "-t", "--task", type=str, required=True,
        choices=['extract', 'classify'])
    parser.add_argument("-i", "--input_path", type=str, default='')
    parser.add_argument("-o", "--output_path", type=str, default='')
    parser.add_argument("-topn", "--topn", type=int, default=None)
    args = parser.parse_args()

    # classes = ['小吃店', '日式料理(含居酒屋,串燒)', '火(鍋/爐)', '東南亞料理(不含日韓)', '海鮮熱炒', '特色餐廳(含雞、鵝、牛、羊肉)', '傳統餐廳', '燒烤', '韓式料理(含火鍋,烤肉)', '西餐廳(含美式,義式,墨式)', ]
    # backup_classes = [ '中式', '西式']

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(args.config):
        raise FileNotFoundError(f"File not found: {args.config}")
    with open(args.config, "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    if args.task == 'extract':
        jenv = jinja2.Environment()
        template = jenv.from_string(config['extraction_prompt'])
        system_prompt = template.render(
            classes=config['classes'],
            traits=config['traits'])

        # Hard-coded sample query and search results used to try the
        # extraction prompt end to end.
        query = "山の迴饗"
        search_results = str([
            {"title": "山の迴饗", "snippet": "謝謝大家這麼支持山の迴饗 我們會繼續努力用心做出美味的料理 ————————— ⛰️ 山の迴饗地址:台東縣關山鎮中華路56號訂位專線:0975-957-056 · #山的迴饗 · #夢想起飛"},
            {"title": "山的迴饗餐館- 店家介紹", "snippet": "營業登記資料 · 統一編號. 92433454 · 公司狀況. 營業中 · 公司名稱. 山的迴饗餐館 · 公司類型. 獨資 · 資本總額. 30000 · 所在地. 臺東縣關山鎮中福里中華路56號 · 使用發票."},
            {"title": "關山漫遊| 💥山の迴饗x night bar", "snippet": "山の迴饗x night bar 即將在12/1號台東關山開幕! 別再煩惱池上、鹿野找不到宵夜餐酒館 各位敬請期待並關注我們✨ night bar❌山的迴饗 12/1 ..."},
            {"title": "山的迴饗| 中西複合式餐廳|焗烤飯|義大利麵 - 台灣美食網", "snippet": "山的迴饗| 中西複合式餐廳|焗烤飯|義大利麵|台式三杯雞|滷肉飯|便當|CP美食營業時間 ; 星期一, 休息 ; 星期二, 10:00–14:00 16:00–21:00 ; 星期三, 10:00–14:00 16:00– ..."},
            {"title": "便當|CP美食- 山的迴饗| 中西複合式餐廳|焗烤飯|義大利麵", "snippet": "餐廳山的迴饗| 中西複合式餐廳|焗烤飯|義大利麵|台式三杯雞|滷肉飯|便當|CP美食google map 導航. 臺東縣關山鎮中華路56號 +886 975 957 056 ..."},
            {"title": "山的迴饗餐館", "snippet": "山的迴饗餐館,統編:92433454,地址:臺東縣關山鎮中福里中華路56號,負責人姓名:周偉慈,設立日期:112年11月15日."},
            {"title": "山的迴饗餐館", "snippet": "山的迴饗餐館. 資本總額(元), 30,000. 負責人, 周偉慈. 登記地址, 看地圖 臺東縣關山鎮中福里中華路56號 郵遞區號查詢. 設立日期, 2023-11-15. 資料管理 ..."},
            {"title": "山的迴饗餐館, 公司統一編號92433454 - 食品業者登錄資料集", "snippet": "公司或商業登記名稱山的迴饗餐館的公司統一編號是92433454, 登錄項目是餐飲場所, 業者地址是台東縣關山鎮中福里中華路56號, 食品業者登錄字號是V-202257990-00001-5."},
            {"title": "山的迴饗餐館, 公司統一編號92433454 - 食品業者登錄資料集", "snippet": "公司或商業登記名稱山的迴饗餐館的公司統一編號是92433454, 登錄項目是公司/商業登記, 業者地址是台東縣關山鎮中福里中華路56號, 食品業者登錄字號是V-202257990-00000-4 ..."},
            {"title": "山的迴饗餐館", "snippet": "負責人, 周偉慈 ; 登記地址, 台東縣關山鎮中福里中華路56號 ; 公司狀態, 核准設立 「查詢最新營業狀況請至財政部稅務入口網 」 ; 資本額, 30,000元 ; 所在縣市 ..."},
            {"title": "山的迴饗 | 關山美食|焗烤飯|酒吧|義大利麵|台式三杯雞|滷肉飯|便當|CP美食", "顧客評價": "324晚餐餐點豬排簡餐加白醬焗烤等等餐點。\t店家也提供免費的紅茶 綠茶 白開水 多種的調味料自取 總而言之 CP值真的很讚\t空間舒適涼爽,店員服務周到"},
            {"title": "類似的店", "snippet": "['中國菜']\t['客家料理']\t['餐廳']\t['熟食店']\t['餐廳']"},
            {"telephone_number": "0975 957 056"},
        ])

        user_content = f''' `query`: `{query}`, `search_results`: {search_results} '''
        print(f"user_content -> {user_content}")

        resp = llm(config['provider'], config['model'], system_prompt, user_content)
        print(resp)

    elif args.task == 'classify':
        # NOTE(review): only the prompt is read here; no request is sent for
        # the classify task in this file.
        system_prompt = config['classification_prompt']

    else:
        raise Exception("Invalid task")


if __name__ == "__main__":
    main()