Santiago Hincapie Potes commited on
Commit
619fe5f
1 Parent(s): 348396b

augmented generation

Browse files
Files changed (2) hide show
  1. app.py +25 -13
  2. src/deploy_utils.py +12 -1
app.py CHANGED
@@ -13,9 +13,16 @@ from src.modelling.topics.topic_extractor import (
13
  from src.modelling.topics.class_tf_idf import ClassTfidfTransformer
14
  from src import deploy_utils
15
 
16
- semantic_search_header = "What kind of product are you trying to sell?"
17
- semantic_search_placeholder = "Your magic idea goes here ✨"
18
- search_label = "Generate"
 
 
 
 
 
 
 
19
 
20
  def setup_palm():
21
  palm.configure(api_key=os.environ.get('PALM_TOKEN'))
@@ -65,7 +72,7 @@ def render_cta_link(url, label, font_awesome_icon):
65
  return st.markdown(button_code, unsafe_allow_html=True)
66
 
67
 
68
- def handler_search():
69
  relevant_products = deploy_utils.query_relevant_documents(
70
  product_model=product_model,
71
  indexer=product_indexer,
@@ -92,8 +99,10 @@ def handler_search():
92
  extracted_topics,
93
  )
94
 
95
- st.session_state.key_reviews = key_reviews
96
- print('search done')
 
 
97
 
98
 
99
  def palm_handler():
@@ -106,12 +115,15 @@ def render_search():
106
  Render the search form in the sidebar.
107
  """
108
  with st.sidebar:
109
- st.text_input(
110
- label=semantic_search_header,
111
- placeholder=semantic_search_placeholder,
112
  key="user_search_query",
113
  )
114
 
 
 
 
115
  st.text_area(
116
  label="test env",
117
  placeholder="prompt here",
@@ -119,7 +131,7 @@ def render_search():
119
  )
120
 
121
  st.button(
122
- label=search_label,
123
  key="location_search",
124
  on_click=palm_handler)
125
 
@@ -131,9 +143,9 @@ def render_search():
131
  )
132
 
133
 
134
- def render_results():
135
  # TODO: temporal
136
- st.write("# PaLM outputs")
137
  st.write(st.session_state.palm_output.result)
138
 
139
  # Execution start here!
@@ -152,4 +164,4 @@ topic_extractor, clusterer = load_uncached_models()
152
 
153
  render_search()
154
  if "palm_output" in st.session_state:
155
- render_results()
 
13
  from src.modelling.topics.class_tf_idf import ClassTfidfTransformer
14
  from src import deploy_utils
15
 
16
+
17
+ def get_prompt(title, reviews):
18
+ return f"""We are doing a marketing research analysis, in particular we are trying to understand what users thing about a particular market in order to generate tips for future sellers.
19
+ In particular, we are interesting to analyze the market for "{title}"
20
+
21
+ This is what amazon customers are saying about similar products:
22
+ {reviews}
23
+
24
+ Can you write some recomendations about how can we disrupt this market? Try to propose the necesary methodology to create a breaking product."""
25
+
26
 
27
  def setup_palm():
28
  palm.configure(api_key=os.environ.get('PALM_TOKEN'))
 
72
  return st.markdown(button_code, unsafe_allow_html=True)
73
 
74
 
75
+ def handler_review_query():
76
  relevant_products = deploy_utils.query_relevant_documents(
77
  product_model=product_model,
78
  indexer=product_indexer,
 
99
  extracted_topics,
100
  )
101
 
102
+ reviews_prompt = deploy_utils.key_reviews_to_prompt(key_reviews)
103
+ prompt = get_prompt(st.session_state.user_search_query, reviews_prompt)
104
+ st.session_state.user_prompt = prompt
105
+
106
 
107
 
108
  def palm_handler():
 
115
  Render the search form in the sidebar.
116
  """
117
  with st.sidebar:
118
+ query = st.text_input(
119
+ label="What kind of product are you trying to sell?",
120
+ placeholder="Your magic idea goes here ✨",
121
  key="user_search_query",
122
  )
123
 
124
+ if query:
125
+ handler_review_query()
126
+
127
  st.text_area(
128
  label="test env",
129
  placeholder="prompt here",
 
131
  )
132
 
133
  st.button(
134
+ label="Generate",
135
  key="location_search",
136
  on_click=palm_handler)
137
 
 
143
  )
144
 
145
 
146
+ def render_palm_results():
147
  # TODO: temporal
148
+ st.write("# ALMond recommendations")
149
  st.write(st.session_state.palm_output.result)
150
 
151
  # Execution start here!
 
164
 
165
  render_search()
166
  if "palm_output" in st.session_state:
167
+ render_palm_results()
src/deploy_utils.py CHANGED
@@ -95,4 +95,15 @@ def get_key_reviews(
95
  for idx in indices
96
  }
97
 
98
- return list(top_rated_reviews | representative_reviews)
 
 
 
 
 
 
 
 
 
 
 
 
95
  for idx in indices
96
  }
97
 
98
+ return list(top_rated_reviews | representative_reviews)
99
+
100
+
101
+ def _format_review(x):
102
+ single_line = x.split("\n")[0]
103
+ return f' - {single_line.strip()}'
104
+
105
+
106
+ def key_reviews_to_prompt(reviews):
107
+ return '\n'.join([
108
+ _format_review(i) for i in reviews
109
+ ])