iarbel committed on
Commit
00f57d4
1 Parent(s): 4f7fb1b

add src files

Browse files
Files changed (4) hide show
  1. .gitignore +160 -0
  2. src/__init__.py +0 -0
  3. src/few_shot_funcs.py +143 -0
  4. src/scrape.py +98 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
src/__init__.py ADDED
File without changes
src/few_shot_funcs.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import openai
3
+ import inflect
4
+ import pandas as pd
5
+ from typing import Dict
6
+ from datasets import load_dataset
7
+ from IPython.display import display, HTML
8
+ from langchain.embeddings.openai import OpenAIEmbeddings
9
+ from langchain.vectorstores import FAISS
10
+ from langchain.vectorstores.utils import DistanceStrategy
11
+
12
+ import os
13
+
14
# --- OpenAI credentials --------------------------------------------------
# Read the key from the environment rather than hard-coding it. The previous
# `OPENAI_KEY = ''` literal was then written back into os.environ, clobbering
# any real OPENAI_API_KEY the process already had.
OPENAI_KEY = os.environ.get('OPENAI_API_KEY', '')

openai.api_key = OPENAI_KEY
os.environ['OPENAI_API_KEY'] = OPENAI_KEY


# Constants
# Columns kept from the few-shot dataset; 'labels' holds the target bullets.
FS_COLUMNS = ['asin', 'category', 'title', 'tech_process', 'labels']
MAX_TOKENS = 700
USER_TXT = 'Write feature-bullets for an Amazon product page. ' \
           'Title: {title}. Technical details: {tech_data}.\n\n### Feature-bullets:'

# Load few-shot dataset (validation split is used as the example pool)
FS_DATASET = load_dataset('iarbel/amazon-product-data-filter', split='validation')

# Prepare Pandas DFs with the relevant columns
FS_DS = FS_DATASET.to_pandas()[FS_COLUMNS]

# Load vector store of product-title embeddings (inner-product similarity)
DB = FAISS.load_local('data/vector_stores/amazon-product-embedding', OpenAIEmbeddings(),
                      distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
35
+
36
+
37
class Conversation:
    """Accumulates a ChatAPI message list, enforcing user/assistant turn order."""

    def __init__(self):
        # Seed the history with the fixed system prompt.
        self.messages = [{'role': 'system',
                          'content': 'You are a helpful assistant. Your task is to write feature-bullets for an Amazon product page.'}]

    def add_message(self, role: str, content: str) -> None:
        """Append a message, validating that roles alternate correctly."""
        role = role.lower()
        previous = self.messages[-1]['role']
        if role not in ('user', 'assistant'):
            raise ValueError('Roles can be "user" or "assistant" only')
        if role == 'user':
            if previous not in ('system', 'assistant'):
                raise ValueError('"user" message can only follow "assistant" message')
        elif previous != 'user':
            # role is necessarily 'assistant' here (validated above).
            raise ValueError('"assistant" message can only follow "user" message')
        self.messages.append({"role": role, "content": content})

    def display_conversation(self) -> None:
        """Render the conversation as styled HTML in a notebook (IPython display)."""
        SEP = '\n'
        for message in self.messages:
            role, content = message['role'], message["content"]
            if role == 'system':
                display(HTML(f'<b>{content}</b>'))
            elif role == 'user':
                # Bold the prompt section headers for readability.
                body = (content.replace("Title:", "<br><b>Title:</b>")
                               .replace("Technical details:", "<br><b>Technical details:</b>")
                               .replace("### Feature-bullets:", "<br><b>Feature-bullets:</b>"))
                display(HTML(f'<p style="background-color:White; color:Black; padding:5px;">{body}</p>'))
            else:
                body = content.lstrip(SEP).replace(SEP, "<br><br>")
                display(HTML(f'<p style="background-color:LightGray; color:Black; padding:5px;">{body}</p>'))
71
+
72
def api_call(messages: Dict[str, str], temperature: float = 0.7, top_p: int = 1, n_responses: int = 1) -> dict:
    """
    Call the ChatAPI with a conversation plus optional sampling params
    (temperature controls randomness; n_responses sets how many completions
    to request). Returns {'object', 'usage', 'text'} where 'text' is a list
    of the generated completions.
    """
    response = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        messages=messages,
        temperature=temperature,
        max_tokens=MAX_TOKENS,
        n=n_responses,
        top_p=top_p,
    )

    texts = [response['choices'][i]['message']['content'] for i in range(n_responses)]
    # NOTE(review): `_previous` is a private attribute of the openai usage
    # object — fragile across library versions; confirm before upgrading.
    return {'object': 'chat', 'usage': response['usage']._previous, 'text': texts}
82
+
83
+
84
class FewShotData:
    """Retrieves similar products from a vector store and builds few-shot prompts."""

    def __init__(self, few_shot_df: pd.DataFrame, vector_db: FAISS):
        self.few_shot_df = few_shot_df
        self.vector_db = vector_db

    def extract_few_shot_data(self, target_title: str, k_shot: int = 2, **db_kwargs) -> pd.DataFrame:
        """Return title/tech/label rows for the k_shot products most similar to target_title."""
        # Embed the target title, then run an MMR search against the FAISS store.
        query_vector = OpenAIEmbeddings().embed_query(target_title)
        matches = self.vector_db.max_marginal_relevance_search_with_score_by_vector(
            query_vector, k=k_shot, **db_kwargs)
        matched_titles = [pair[0].page_content for pair in matches]

        # Pull the matching rows out of the few-shot dataframe.
        return self.few_shot_df[self.few_shot_df['title'].isin(matched_titles)][['title', 'tech_process', 'labels']]

    def construct_few_shot_conversation(self, target_title: str, target_tech_data: str, few_shot_data: pd.DataFrame) -> Conversation:
        """Build a Conversation pre-loaded with few-shot examples, ending with the target prompt."""
        conv = Conversation()
        # Each example contributes a user prompt and its known-good answer.
        for _, row in few_shot_data.iterrows():
            conv.add_message('user', USER_TXT.format(title=row['title'], tech_data=row['tech_process']))
            conv.add_message('assistant', row['labels'])

        # Close with the actual request the model must answer.
        conv.add_message('user', USER_TXT.format(title=target_title, tech_data=target_tech_data))
        return conv
114
+
115
+
116
def return_is_are(text: str) -> str:
    """Return 'are' when inflect judges *text* to be a plural noun, else 'is'."""
    # singular_noun() is falsy for singular inputs, truthy (the singular form)
    # for plurals.
    if inflect.engine().singular_noun(text):
        return 'are'
    return 'is'
120
+
121
def format_tech_as_str(tech_data):
    """Join two-column tech rows into '<key> is/are <value>' sentences, '. '-separated.

    Rows with an empty key or value are dropped. `tech_data` is expected to be
    a two-column DataFrame-like object (it is iterated via .to_numpy()).
    """
    sentences = (f'{key} {return_is_are(key)} {value}'
                 for key, value in tech_data.to_numpy() if key and value)
    return '. '.join(sentences)
125
+
126
+
127
def generate_data(title: str, tech_process: str, few_shot_df: pd.DataFrame, vector_db: FAISS) -> str:
    """End-to-end generation: retrieve few-shot examples, build the prompt, call the API.

    Returns the first completion prefixed with a markdown header.
    """
    helper = FewShotData(few_shot_df=few_shot_df, vector_db=vector_db)
    examples = helper.extract_few_shot_data(target_title=title, k_shot=2)

    conversation = helper.construct_few_shot_conversation(target_title=title,
                                                          target_tech_data=tech_process,
                                                          few_shot_data=examples)

    api_res = api_call(conversation.messages, temperature=0.7)
    return "## Feature-Bullets\n" + api_res['text'][0]
138
+
139
+
140
def check_url_structure(url: str) -> bool:
    """Return True iff *url* looks like an Amazon product-detail (/dp/ASIN) URL.

    The previous pattern left the dots in "www.amazon.com" unescaped, so `.`
    matched any character there (e.g. "wwwXamazonXcom" was accepted). The dots
    are now escaped so only the literal hostname passes. An ASIN is exactly 10
    alphanumeric characters; an optional path segment and trailing slash are
    allowed.
    """
    pattern = r"https://www\.amazon\.com(/.+)?/dp/[a-zA-Z0-9]{10}/?$"
    return bool(re.match(pattern, url))
143
+
src/scrape.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import requests
3
+ from base64 import b64decode
4
+ from bs4 import BeautifulSoup
5
+ from typing import Dict
6
+
7
+ Z_KEY = ''
8
+
9
+
10
def zyte_call(url: str) -> bytes:
    """Fetch *url* through the Zyte extraction API and return the raw page bytes.

    Authenticates with Z_KEY (HTTP basic auth, empty password). The response
    body arrives base64-encoded and is decoded before returning.
    """
    payload = {
        "url": url,
        "httpResponseBody": True,
    }
    api_response = requests.post(
        "https://api.zyte.com/v1/extract",
        auth=(Z_KEY, ""),
        json=payload,
    )
    return b64decode(api_response.json()["httpResponseBody"])
22
+
23
+
24
def get_asin_pdp(soup: BeautifulSoup) -> Dict[str, str]:
    """Extract ASIN, title, feature bullets and merged tech data from a product page.

    Returns a dict with keys 'asin', 'title', 'feature_bullets', 'tech_data'.
    Sections missing from the page come back as None ('tech_data' falls back
    to an empty dict).
    """
    # ASIN: last path segment of the canonical link. soup.find returns None
    # when the link is absent, making the subscription raise TypeError.
    try:
        asin = soup.find('link', rel='canonical')['href'].split('/')[-1]
    except TypeError:
        asin = None

    # Get title
    search = soup.find('span', id="productTitle")
    title = search.text.strip() if search else None

    # Get feature-bullets. The original comprehension filtered each item with
    # `if len(bullet_search)` — a loop-invariant test of the whole list, i.e.
    # a no-op (an empty list yields nothing anyway) — so it has been removed.
    search = soup.find('div', id="feature-bullets")
    if search:
        bullet_search = search.find_all('span', class_='a-list-item')
        feature_bullets = [h.text.strip() for h in bullet_search]
        # Remove Amazon's boilerplate fitment bullet.
        feature_bullets = [b for b in feature_bullets if b != 'Make sure this fits by entering your model number.']
    else:
        feature_bullets = None

    # Get KV, tech, A+ tables. Merge with override key hierarchy: A+ > tech > KV
    kv_res = parse_kv_table(soup)
    tech_res = parse_tech_table(soup)
    ap_data = parse_ap_table(soup)
    tech_data = {**kv_res, **tech_res, **ap_data}

    res = {'asin': asin, 'title': title, 'feature_bullets': feature_bullets, 'tech_data': tech_data}
    return res
53
+
54
+
55
def parse_kv_table(soup: BeautifulSoup) -> Dict[str, str]:
    """Parse the 'productOverview' key/value table into a dict; {} when absent."""
    result = {}
    try:
        overview = soup.find('div', id='productOverview_feature_div')
        rows = overview.find('table').find_all('tr')
        for row in rows:
            cells = row.find_all('td')
            key = cells[0].text.strip()
            value = cells[1].text.strip()
            result[key] = value
    except AttributeError:
        # Section missing (a find() returned None) -> treat as "no data".
        pass
    return result
70
+
71
+
72
def parse_tech_table(soup: BeautifulSoup) -> Dict[str, str]:
    """Parse the productDetails_techSpec* tables into a flat {name: value} dict.

    Changes vs. the original: the redundant `if tables:` guard is gone
    (iterating an empty find_all() result is already a no-op), and rows
    lacking a <th>/<td> pair are skipped instead of raising AttributeError,
    matching the tolerant behaviour of the sibling table parsers.
    """
    tech_res = {}
    for tab in soup.find_all('table', id=re.compile('productDetails_techSpec.*')):
        for row in tab.find_all('tr'):
            th, td = row.find('th'), row.find('td')
            if th is None or td is None:
                continue  # defensive: malformed or header-only row
            key = th.text.strip()
            # Values are littered with \u200e (left-to-right marks); drop them.
            value = td.text.strip('\n').replace('\u200e', '').strip()
            tech_res[key] = value
    return tech_res
83
+
84
+
85
def parse_ap_table(soup: BeautifulSoup) -> Dict[str, str]:
    """Parse the A+ content tables inside div#tech into {name: value}; {} when absent."""
    ap_res = {}
    for section in soup.find_all('div', id='tech'):
        for table in section.find_all('table'):
            for row in table.find_all('tr'):
                cells = row.find_all('td')
                if not cells:
                    continue  # rows without <td> cells carry no data
                # Drop newlines and \u200e (left-to-right marks) from both cells.
                key = cells[0].text.strip('\n').replace('\u200e', '').strip()
                value = cells[1].text.strip('\n').replace('\u200e', '').strip()
                ap_res[key] = value
    return ap_res