File size: 5,528 Bytes
ba41c0a 00f57d4 8ec9ed5 00f57d4 8ec9ed5 1fe08c6 8ec9ed5 00f57d4 8ec9ed5 00f57d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import functools
import os
import re
from typing import Dict, List, Optional

import inflect
import openai
import pandas as pd
from datasets import load_dataset
from huggingface_hub import login
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.vectorstores.utils import DistanceStrategy
# --- Credentials (read from the environment at import time) ---
openai.api_key = os.environ.get('OPENAI_API_KEY')
openai.organization = os.environ.get('OPENAI_ORG')
# NOTE(review): if HUB_KEY is unset this passes None to login(); what happens then
# is up to huggingface_hub (possibly an interactive prompt) — confirm.
login(os.environ.get('HUB_KEY'))

# --- Prompt / generation constants ---
FS_COLUMNS = ['asin', 'category', 'title', 'tech_process', 'labels']
MAX_TOKENS = 700
# User-turn template; filled with str.format(title=..., tech_data=...).
USER_TXT = (
    'Write feature-bullets for an Amazon product page. '
    'Title: {title}. Technical details: {tech_data}.\n\n### Feature-bullets:'
)

# --- Few-shot example pool: validation split of the Hub dataset ---
FS_DATASET = load_dataset('iarbel/amazon-product-data-filter', split='validation')
# Same examples as a DataFrame restricted to the columns used downstream.
FS_DS = FS_DATASET.to_pandas()[FS_COLUMNS]

# --- Vector store of product embeddings (relative path, so presumably resolved
# against the current working directory) ---
DB = FAISS.load_local(
    'data/vector_stores/amazon-product-embedding',
    OpenAIEmbeddings(),
    distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
)
class Conversation:
    """
    A class to construct conversations with the ChatAPI.

    ``messages`` is a list of ``{'role': ..., 'content': ...}`` dicts that starts
    with one system message. ``add_message`` enforces that user and assistant
    turns alternate after it.
    """

    DEFAULT_SYSTEM_PROMPT = ('You are a helpful assistant. Your task is to write '
                             'feature-bullets for an Amazon product page.')

    def __init__(self, system_prompt: str = DEFAULT_SYSTEM_PROMPT):
        """
        Args:
            system_prompt: content of the initial system message; defaults to the
                feature-bullet writing instruction.
        """
        self.messages = [{'role': 'system', 'content': system_prompt}]

    def add_message(self, role: str, content: str) -> None:
        """
        Append a message to the conversation.

        Args:
            role: ``'user'`` or ``'assistant'`` (case-insensitive; stored lower-cased).
            content: message text.

        Raises:
            ValueError: if the role is not user/assistant, if a user message does
                not follow the system or an assistant message, or if an assistant
                message does not follow a user message.
        """
        role = role.lower()
        if role not in ('user', 'assistant'):
            raise ValueError('Roles can be "user" or "assistant" only')
        last_role = self.messages[-1]['role']
        if role == 'user' and last_role not in ('system', 'assistant'):
            raise ValueError('"user" message can only follow "system" or "assistant" message')
        if role == 'assistant' and last_role != 'user':
            raise ValueError('"assistant" message can only follow "user" message')
        self.messages.append({'role': role, 'content': content})
def api_call(messages: List[Dict[str, str]], temperature: float = 0.7, top_p: float = 1,
             n_responses: int = 1, model: str = 'gpt-3.5-turbo',
             max_tokens: Optional[int] = None) -> dict:
    """
    Call the OpenAI ChatCompletion API and collect the generated texts.

    Args:
        messages: the conversation as a list of ``{'role': ..., 'content': ...}``
            dicts (e.g. ``Conversation().messages``).
        temperature: sampling temperature (controls randomness).
        top_p: nucleus-sampling probability mass.
        n_responses: number of completions to request (sent as ``n``).
        model: chat model name.
        max_tokens: completion token limit; ``None`` uses the module-level ``MAX_TOKENS``.

    Returns:
        dict with ``'object'`` (always ``'chat'``), ``'usage'`` (the response's
        usage block as a plain dict) and ``'text'`` (list of returned message contents).
    """
    params = {'model': model, 'messages': messages, 'temperature': temperature,
              'max_tokens': MAX_TOKENS if max_tokens is None else max_tokens,
              'n': n_responses, 'top_p': top_p}
    response = openai.ChatCompletion.create(**params)
    # Iterate over what the API actually returned instead of assuming n_responses choices.
    text = [choice['message']['content'] for choice in response['choices']]
    # Public dict conversion instead of the private OpenAIObject._previous attribute.
    return {'object': 'chat', 'usage': dict(response['usage']), 'text': text}
class FewShotData:
    """
    Retrieve few-shot examples for a target product and build a chat prompt from them.

    Args:
        few_shot_df: candidate examples; must contain at least the ``'title'``,
            ``'tech_process'`` and ``'labels'`` columns.
        vector_db: FAISS store whose documents' ``page_content`` is matched
            against ``few_shot_df['title']``.
    """

    def __init__(self, few_shot_df: pd.DataFrame, vector_db: 'FAISS'):
        self.few_shot_df = few_shot_df
        self.vector_db = vector_db

    @staticmethod
    def _select_examples(few_shot_df: pd.DataFrame, titles: List[str]) -> pd.DataFrame:
        """
        Return one row per title in *titles* that occurs in *few_shot_df* (first
        occurrence wins), ordered as in *titles*. Unknown titles are skipped.
        Only the ``'title'``, ``'tech_process'`` and ``'labels'`` columns are kept.
        """
        # Position of each title's first appearance in the retrieval results.
        rank = {}
        for position, title in enumerate(titles):
            rank.setdefault(title, position)
        matches = few_shot_df[few_shot_df['title'].isin(list(rank))]
        # Duplicate titles in the dataset would otherwise yield several examples per hit.
        matches = matches.drop_duplicates(subset='title', keep='first')
        # Reorder positionally (robust to duplicate index labels) into retrieval order.
        order = matches['title'].map(rank).to_numpy().argsort(kind='stable')
        return matches.iloc[order][['title', 'tech_process', 'labels']]

    def extract_few_shot_data(self, target_title: str, k_shot: int = 2, **db_kwargs) -> pd.DataFrame:
        """
        Find up to *k_shot* examples similar to *target_title*.

        The title is embedded with ``OpenAIEmbeddings`` and searched in the vector
        store by max-marginal-relevance; extra keyword arguments are forwarded to
        that search. The hits' ``page_content`` values are looked up in the
        few-shot DataFrame.

        Returns:
            DataFrame with ``'title'``, ``'tech_process'`` and ``'labels'``, one row
            per retrieved title found in the DataFrame, in retrieval order.
        """
        target_title_vector = OpenAIEmbeddings().embed_query(target_title)
        similar = self.vector_db.max_marginal_relevance_search_with_score_by_vector(
            target_title_vector, k=k_shot, **db_kwargs)
        # Each hit is a (document, score) pair; only the document text is used.
        few_shot_titles = [hit[0].page_content for hit in similar]
        return self._select_examples(self.few_shot_df, few_shot_titles)

    def construct_few_shot_conversation(self, target_title: str, target_tech_data: str,
                                        few_shot_data: pd.DataFrame) -> 'Conversation':
        """
        Build a ``Conversation`` that replays each example as a user prompt (from
        ``USER_TXT``) followed by its labels as the assistant answer. It ends with
        the user prompt for the target product.
        """
        fs_titles = few_shot_data['title'].to_list()
        fs_tech_data = few_shot_data['tech_process'].to_list()
        fs_labels = few_shot_data['labels'].to_list()

        conv = Conversation()
        for title, tech_data, labels in zip(fs_titles, fs_tech_data, fs_labels):
            conv.add_message('user', USER_TXT.format(title=title, tech_data=tech_data))
            conv.add_message('assistant', labels)
        # Final prompt the model is expected to answer.
        conv.add_message('user', USER_TXT.format(title=target_title, tech_data=target_tech_data))
        return conv
@functools.lru_cache(maxsize=None)
def _inflect_engine():
    """Return one shared inflect engine, created on first use."""
    return inflect.engine()


def return_is_are(text: str) -> str:
    """
    Return the verb that agrees with *text*.

    Returns ``'are'`` when inflect can singularize *text* (i.e. treats it as a
    plural noun) and ``'is'`` otherwise.
    """
    # singular_noun returns False for words it considers already singular.
    return 'are' if _inflect_engine().singular_noun(text) else 'is'
def format_tech_as_str(tech_data):
    """
    Render a two-column (key, value) table as '. '-joined sentences.

    Each row whose key and value are both truthy becomes
    ``'<key> is|are <value>'``, with the verb chosen by ``return_is_are``.
    Other rows are dropped.
    """
    sentences = []
    for key, value in tech_data.to_numpy():
        if not (key and value):
            continue
        sentences.append(f'{key} {return_is_are(key)} {value}')
    return '. '.join(sentences)
def generate_data(title: str, tech_process: str, few_shot_df: pd.DataFrame, vector_db: FAISS,
                  k_shot: int = 2, temperature: float = 0.7) -> str:
    """
    Generate feature-bullets for a product using retrieved few-shot examples.

    Args:
        title: product title; also used as the retrieval query.
        tech_process: technical-details text inserted into the final prompt.
        few_shot_df: example pool passed to ``FewShotData``.
        vector_db: vector store passed to ``FewShotData``.
        k_shot: number of examples to retrieve.
        temperature: sampling temperature for the chat call.

    Returns:
        The first generated completion, prefixed with a ``## Feature-Bullets`` header line.
    """
    fs_example = FewShotData(few_shot_df=few_shot_df, vector_db=vector_db)
    fs_data = fs_example.extract_few_shot_data(target_title=title, k_shot=k_shot)
    fs_conv = fs_example.construct_few_shot_conversation(target_title=title,
                                                         target_tech_data=tech_process,
                                                         few_shot_data=fs_data)
    api_res = api_call(fs_conv.messages, temperature=temperature)
    return "## Feature-Bullets\n" + api_res['text'][0]
# https://www.amazon.com[/<any path>]/dp/<10 alphanumerics>[/]. The dots are escaped
# so only the literal host matches. The 10-character ID is presumably the ASIN.
_AMAZON_DP_URL_RE = re.compile(r"https://www\.amazon\.com(?:/.+)?/dp/[a-zA-Z0-9]{10}/?")


def check_url_structure(url: str) -> bool:
    """
    Return True if *url* is an Amazon.com product ("/dp/") page URL.

    The whole string must match; query strings, other hosts or schemes, and
    trailing newlines are rejected.
    """
    # fullmatch instead of match + '$': '$' would also accept a trailing '\n'.
    return _AMAZON_DP_URL_RE.fullmatch(url) is not None
|