yiyii committed on
Commit 2e99c12
1 Parent(s): e242b04
Files changed (1)
  1. app.py +60 -7
app.py CHANGED
@@ -88,14 +88,40 @@ load_vector_store = Chroma(persist_directory="stores/story_cosine", embedding_fu
 # persist_directory="stores/story_cosine": load the existing vector store from "stores/story_cosine"
 # embedding_function=embeddings: use the bge embedding model when adding new data to the vector store
 
-# Only get the 3 most similar documents from the dataset
-retriever = load_vector_store.as_retriever(search_kwargs={"k": 3})
-
 client = InferenceClient(
     "mistralai/Mistral-7B-Instruct-v0.1"
 )
 
-def generate(image, temperature=0.9, max_new_tokens=1500, top_p=0.95, repetition_penalty=1.0):
+def generate(image, temperature=0.9, max_new_tokens=1500, top_p=0.95, repetition_penalty=1.0, chunk_size=200, chunk_overlap=20, top_k=3):
+    # load the txt file
+    with open("story.txt", "r") as f:
+        # "r": read-only mode
+        state_of_the_union = f.read()
+        # read the file into a single string
+    # split the content into chunks
+    text_splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+    # TokenTextSplitter() ensures the integrity of words
+    # each chunk overlaps the previous chunk by chunk_overlap tokens
+    texts = text_splitter.split_text(state_of_the_union)
+    print("...........................................")
+    # print the first chunk
+    print("text[0]: ", texts[0])
+    # create embeddings for the chunks with the bge model, then save the vectors in the Chroma vector database
+    # use an hnsw (hierarchical navigable small world) index to make searching efficient
+    # use cosine similarity to measure similarity (crucial when performing similarity search)
+    # hnsw builds a graph-based index for approximate nearest-neighbor searches
+    # hnsw organizes the data into an efficient structure that supports rapid retrieval (speeds up the search)
+    # cosine similarity tells the hnsw algorithm how to measure the distance between vectors
+    # by setting the space to cosine, the index measures the vectors' similarity with cosine similarity
+    vector_store = Chroma.from_texts(texts, embeddings, collection_metadata={"hnsw:space": "cosine"}, persist_directory="stores/story_cosine")
+    print("vector store created........................")
+
+    load_vector_store = Chroma(persist_directory="stores/story_cosine", embedding_function=embeddings)
+    # persist_directory="stores/story_cosine": load the existing vector store from "stores/story_cosine"
+    # embedding_function=embeddings: use the bge embedding model when adding new data to the vector store
+    # only get the top_k most similar documents from the dataset
+    retriever = load_vector_store.as_retriever(search_kwargs={"k": top_k})
+
     image_caption, gender, age, emotion = get_image_info(image)
     print("............................................")
     print("image_caption:", image_caption)
@@ -180,7 +206,7 @@ demo = gr.Interface(fn=generate,
         gr.Image(sources=["upload", "webcam"], label="Upload Image", type="pil"),
 
         gr.Slider(
-            label="Temperature",
+            label="temperature",
             value=0.9,
             minimum=0.0,
             maximum=1.0,
@@ -190,7 +216,7 @@ demo = gr.Interface(fn=generate,
         ),
 
         gr.Slider(
-            label="Max new tokens",
+            label="max new tokens",
             value=1500,
             minimum=0,
             maximum=3000,
@@ -199,7 +225,7 @@ demo = gr.Interface(fn=generate,
         info="The maximum number of new tokens"),
 
         gr.Slider(
-            label="Top-p (nucleus sampling)",
+            label="top-p (nucleus sampling)",
             value=0.90,
             minimum=0.0,
             maximum=1,
@@ -208,13 +234,40 @@ demo = gr.Interface(fn=generate,
         info="Higher values sample more low-probability tokens",
         ),
         gr.Slider(
-            label="Repetition penalty",
+            label="repetition penalty",
             value=1.2,
             minimum=1.0,
             maximum=2.0,
             step=0.05,
             interactive=True,
             info="Penalize repeated tokens",
+        ),
+        gr.Slider(
+            label="chunk_size",
+            value=200,
+            minimum=50,
+            maximum=500,
+            step=1.0,
+            interactive=True,
+            info="Length of text chunks in tokens",
+        ),
+        gr.Slider(
+            label="chunk_overlap",
+            value=20,
+            minimum=0,
+            maximum=50,
+            step=1.0,
+            interactive=True,
+            info="Number of overlapping tokens between chunks",
+        ),
+        gr.Slider(
+            label="top-k",
+            value=3,
+            minimum=1,
+            maximum=10,
+            step=1.0,
+            interactive=True,
+            info="Number of top relevant documents to retrieve",
         )
     ],
     outputs=[gr.Textbox(label="Generated Story")],
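A wiring note on the three new sliders: gr.Interface passes the components in `inputs` to fn positionally, so the sliders only take effect if generate() accepts matching parameters in the same order (the signature change in the first hunk assumes this). A minimal sketch with a stub fn standing in for app.py's generate():

import gradio as gr

# stub in place of app.py's generate(); the parameter order must match
# the order of the components in `inputs`
def generate(image, temperature, max_new_tokens, top_p, repetition_penalty,
             chunk_size, chunk_overlap, top_k):
    return f"chunk_size={chunk_size}, chunk_overlap={chunk_overlap}, top_k={top_k}"

demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(value=0.9, minimum=0.0, maximum=1.0, label="temperature"),
        gr.Slider(value=1500, minimum=0, maximum=3000, step=1, label="max new tokens"),
        gr.Slider(value=0.90, minimum=0.0, maximum=1.0, label="top-p (nucleus sampling)"),
        gr.Slider(value=1.2, minimum=1.0, maximum=2.0, step=0.05, label="repetition penalty"),
        gr.Slider(value=200, minimum=50, maximum=500, step=1, label="chunk_size"),
        gr.Slider(value=20, minimum=0, maximum=50, step=1, label="chunk_overlap"),
        gr.Slider(value=3, minimum=1, maximum=10, step=1, label="top-k"),
    ],
    outputs=[gr.Textbox(label="Generated Story")],
)

if __name__ == "__main__":
    demo.launch()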
 