amazon-feature-bullets-demo / src /few_shot_funcs.py
iarbel's picture
update openai model
5f31c25
raw
history blame contribute delete
No virus
5.53 kB
import os
import re
from typing import Dict, List

import inflect
import openai
import pandas as pd
from datasets import load_dataset
from huggingface_hub import login
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.vectorstores.utils import DistanceStrategy
# Credentials for OpenAI and the Hugging Face Hub are read from the environment
openai.api_key = os.getenv('OPENAI_API_KEY')
openai.organization = os.getenv('OPENAI_ORG')
login(os.getenv('HUB_KEY'))

# Constants
FS_COLUMNS = ['asin', 'category', 'title', 'tech_process', 'labels']
MAX_TOKENS = 700  # sent as `max_tokens` on every chat request
# Prompt template; `title` and `tech_data` are filled in with str.format
USER_TXT = (
    'Write feature-bullets for an Amazon product page. '
    'Title: {title}. Technical details: {tech_data}.\n\n### Feature-bullets:'
)

# Few-shot examples come from the validation split of the hosted dataset
FS_DATASET = load_dataset('iarbel/amazon-product-data-filter', split='validation')

# DataFrame copy of the dataset, restricted to the columns in FS_COLUMNS
FS_DS = FS_DATASET.to_pandas()[FS_COLUMNS]

# Vector store loaded from local disk, searched with max-inner-product distance
DB = FAISS.load_local(
    'data/vector_stores/amazon-product-embedding',
    OpenAIEmbeddings(),
    distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
)
class Conversation:
    """
    A class to construct conversations with the ChatAPI.

    A conversation always opens with a fixed system message. Further messages
    are appended with ``add_message``, which enforces that "user" and
    "assistant" turns alternate, starting with "user".
    """

    def __init__(self):
        # Every conversation starts from the same system prompt
        self.messages = [{'role': 'system',
                          'content': 'You are a helpful assistant. Your task is to write feature-bullets for an Amazon product page.'}]

    def add_message(self, role: str, content: str) -> None:
        """
        Append a message to the conversation.

        Args:
            role: "user" or "assistant" (case-insensitive).
            content: the message text.

        Raises:
            ValueError: if ``role`` is not "user"/"assistant", or if the message
                would break the alternating user/assistant order.
        """
        # Validate inputs
        role = role.lower()
        last_role = self.messages[-1]['role']
        if role not in ('user', 'assistant'):
            raise ValueError('Roles can be "user" or "assistant" only')
        # The first user turn follows the system prompt; later ones follow an assistant turn
        if role == 'user' and last_role not in ('system', 'assistant'):
            raise ValueError('"user" message can only follow "system" or "assistant" message')
        if role == 'assistant' and last_role != 'user':
            raise ValueError('"assistant" message can only follow "user" message')

        self.messages.append({'role': role, 'content': content})
def api_call(messages: List[Dict[str, str]], temperature: float = 0.7, top_p: float = 1,
             n_responses: int = 1) -> dict:
    """
    Call the ChatAPI with a conversation and collect the generated texts.

    Args:
        messages: the conversation, as a list of ``{'role': ..., 'content': ...}``
            message dicts (e.g. ``Conversation.messages``).
        temperature: controls randomness; forwarded as the API's ``temperature``.
        top_p: forwarded as the API's ``top_p``.
        n_responses: number of completions to request; forwarded as the API's ``n``.

    Returns:
        A dict with three keys: ``'object'`` (always ``'chat'``), ``'usage'``
        (the usage data attached to the response) and ``'text'`` (a list with
        the message content of every returned choice).
    """
    params = {
        'model': 'gpt-4o-mini',
        'messages': messages,
        'temperature': temperature,
        'max_tokens': MAX_TOKENS,
        'n': n_responses,
        'top_p': top_p,
    }
    response = openai.ChatCompletion.create(**params)
    # Iterate over the choices actually returned instead of indexing by n_responses
    text = [choice['message']['content'] for choice in response['choices']]
    # NOTE(review): `_previous` is an underscore-prefixed (non-public) attribute of the
    # usage object; presumably it holds the raw usage dict -- confirm against the
    # installed openai version before relying on it.
    out = {'object': 'chat', 'usage': response['usage']._previous, 'text': text}
    return out
class FewShotData:
    """
    Pick few-shot examples for a target product and turn them into a Conversation.
    """

    def __init__(self, few_shot_df: pd.DataFrame, vector_db: FAISS):
        # few_shot_df is indexed by its 'title', 'tech_process' and 'labels' columns below
        self.vector_db = vector_db
        self.few_shot_df = few_shot_df

    def extract_few_shot_data(self, target_title: str, k_shot: int = 2, **db_kwargs) -> pd.DataFrame:
        """
        Return the few-shot rows whose titles the vector store matches to ``target_title``.

        Args:
            target_title: title of the product to find examples for.
            k_shot: passed to the vector-store search as ``k``.
            **db_kwargs: extra keyword arguments forwarded to the vector-store search.

        Returns:
            The rows of ``few_shot_df`` whose 'title' is among the matched titles,
            with the columns 'title', 'tech_process' and 'labels'.
        """
        # Embed the target title, then run an MMR search over the vector store
        query_vector = OpenAIEmbeddings().embed_query(target_title)
        matches = self.vector_db.max_marginal_relevance_search_with_score_by_vector(
            query_vector, k=k_shot, **db_kwargs)
        # Each match is indexed at position 0 to get the document carrying the title text
        matched_titles = [match[0].page_content for match in matches]

        # Keep only the rows with a matched title, then only the prompt columns
        title_mask = self.few_shot_df['title'].isin(matched_titles)
        matched_rows = self.few_shot_df[title_mask]
        return matched_rows[['title', 'tech_process', 'labels']]

    def construct_few_shot_conversation(self, target_title: str, target_tech_data: str,
                                        few_shot_data: pd.DataFrame) -> Conversation:
        """
        Build a Conversation holding the few-shot examples followed by the target prompt.

        Each row of ``few_shot_data`` becomes a "user" prompt (built from USER_TXT)
        and an "assistant" answer (the row's 'labels'). The final message is the
        "user" prompt for the target product.
        """
        # Pull the three example columns out as plain lists
        example_columns = [few_shot_data[column].to_list()
                           for column in ('title', 'tech_process', 'labels')]

        # Start a conversation and add one user/assistant pair per example
        conversation = Conversation()
        for example_title, example_tech, example_labels in zip(*example_columns):
            example_prompt = USER_TXT.format(title=example_title, tech_data=example_tech)
            conversation.add_message('user', example_prompt)
            conversation.add_message('assistant', example_labels)

        # Finish with the prompt for the target product
        target_prompt = USER_TXT.format(title=target_title, tech_data=target_tech_data)
        conversation.add_message('user', target_prompt)
        return conversation
def return_is_are(text: str) -> str:
    """
    Choose the verb form ('is' or 'are') to pair with ``text``.

    Returns 'are' when inflect's ``singular_noun`` yields a truthy value for
    ``text``, and 'is' otherwise.
    """
    inflector = inflect.engine()
    if inflector.singular_noun(text):
        return 'are'
    return 'is'
def format_tech_as_str(tech_data):
    """
    Render technical details as one string of '<key> is/are <value>' statements.

    Every row of ``tech_data.to_numpy()`` is unpacked into a (key, value) pair;
    pairs where either item is falsy are skipped, and the remaining statements
    are joined with '. '.
    """
    # assumes tech_data is a two-column table (detail name, detail value) -- TODO confirm
    statements = []
    for key, value in tech_data.to_numpy():
        if key and value:
            statements.append(f'{key} {return_is_are(key)} {value}')
    return '. '.join(statements)
def generate_data(title: str, tech_process: str, few_shot_df: pd.DataFrame, vector_db: FAISS) -> str:
    """
    Generate feature-bullets for a product from its title and technical details.

    Two few-shot examples are selected for ``title``, a conversation ending with
    the target prompt is built from them, and the ChatAPI is called with
    temperature 0.7. The first generated text is returned under a
    '## Feature-Bullets' heading.
    """
    few_shot = FewShotData(few_shot_df=few_shot_df, vector_db=vector_db)
    examples = few_shot.extract_few_shot_data(target_title=title, k_shot=2)
    conversation = few_shot.construct_few_shot_conversation(
        target_title=title,
        target_tech_data=tech_process,
        few_shot_data=examples,
    )
    completion = api_call(conversation.messages, temperature=0.7)
    # Only the first generated text is used
    first_text = completion['text'][0]
    return "## Feature-Bullets\n" + first_text
def check_url_structure(url: str) -> bool:
    """
    Check that ``url`` looks like an amazon.com product page URL.

    Accepted shape: ``https://www.amazon.com[/<anything>]/dp/<10 alphanumerics>[/]``.

    Args:
        url: the URL to check.

    Returns:
        True if the whole string matches the shape above, False otherwise.
    """
    # Dots in the host are escaped so they match a literal '.' only
    pattern = r"https://www\.amazon\.com(/.+)?/dp/[a-zA-Z0-9]{10}/?"
    # fullmatch anchors both ends; unlike `$`, it does not tolerate a trailing newline
    return bool(re.fullmatch(pattern, url))