fourth
Browse files
app.py
CHANGED
@@ -88,14 +88,40 @@ load_vector_store = Chroma(persist_directory="stores/story_cosine", embedding_fu
|
|
88 |
# persist_directory="stores/story_cosine": load the existing vector store from "stores/story_cosine"
|
89 |
# embedding_function=embeddings: use the bge embedding model when adding new data to the vector store
|
90 |
|
91 |
-
# Only get the 3 most similar document from the dataset
|
92 |
-
retriever = load_vector_store.as_retriever(search_kwargs={"k":3})
|
93 |
-
|
94 |
client = InferenceClient(
|
95 |
"mistralai/Mistral-7B-Instruct-v0.1"
|
96 |
)
|
97 |
|
98 |
def generate(image, temperature=0.9, max_new_tokens=1500, top_p=0.95, repetition_penalty=1.0):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
image_caption, gender, age, emotion = get_image_info(image)
|
100 |
print("............................................")
|
101 |
print("image_caption:", image_caption)
|
@@ -180,7 +206,7 @@ demo = gr.Interface(fn=generate,
|
|
180 |
gr.Image(sources=["upload", "webcam"], label="Upload Image", type="pil"),
|
181 |
|
182 |
gr.Slider(
|
183 |
-
label="
|
184 |
value=0.9,
|
185 |
minimum=0.0,
|
186 |
maximum=1.0,
|
@@ -190,7 +216,7 @@ demo = gr.Interface(fn=generate,
|
|
190 |
),
|
191 |
|
192 |
gr.Slider(
|
193 |
-
label="
|
194 |
value=1500,
|
195 |
minimum=0,
|
196 |
maximum=3000,
|
@@ -199,7 +225,7 @@ demo = gr.Interface(fn=generate,
|
|
199 |
info="The maximum numbers of new tokens"),
|
200 |
|
201 |
gr.Slider(
|
202 |
-
label="
|
203 |
value=0.90,
|
204 |
minimum=0.0,
|
205 |
maximum=1,
|
@@ -208,13 +234,40 @@ demo = gr.Interface(fn=generate,
|
|
208 |
info="Higher values sample more low-probability tokens",
|
209 |
),
|
210 |
gr.Slider(
|
211 |
-
label="
|
212 |
value=1.2,
|
213 |
minimum=1.0,
|
214 |
maximum=2.0,
|
215 |
step=0.05,
|
216 |
interactive=True,
|
217 |
info="Penalize repeated tokens",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
)
|
219 |
],
|
220 |
outputs=[gr.Textbox(label="Generated Story")],
|
|
|
88 |
# persist_directory="stores/story_cosine": load the existing vector store from "stores/story_cosine"
|
89 |
# embedding_function=embeddings: use the bge embedding model when adding new data to the vector store
|
90 |
|
|
|
|
|
|
|
91 |
client = InferenceClient(
|
92 |
"mistralai/Mistral-7B-Instruct-v0.1"
|
93 |
)
|
94 |
|
95 |
def generate(image, temperature=0.9, max_new_tokens=1500, top_p=0.95, repetition_penalty=1.0, chunk_size=200, chunk_overlap=20, top_k=3):
|
96 |
+
# load the txt file
|
97 |
+
with open("story.txt", "r") as f:
|
98 |
+
# r: read mode, reading only
|
99 |
+
state_of_the_union = f.read()
|
100 |
+
# read the file into a single string
|
101 |
+
# split the content into chunks
|
102 |
+
text_splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
103 |
+
# TokenTextSplitter() can ensure the integrity of words
|
104 |
+
# each chunk to overlap with the previous chunk by 20 tokens
|
105 |
+
texts = text_splitter.split_text(state_of_the_union)
|
106 |
+
print("...........................................")
|
107 |
+
# print the first chunk
|
108 |
+
print("text[0]: ", texts[0])
|
109 |
+
# create embeddings for chunks by using bge model, and then save these vectors into chroma vector database
|
110 |
+
# use hnsw(hierarchical navigable small world) index to facilitate efficient searching
|
111 |
+
# use cosine similarity to measure similarity. (similarity is crucial in performing similarity search.)
|
112 |
+
# hnsw: builds a graph-based index for approximate nearest neighbor searches.
|
113 |
+
# hnsw is used for organizing the data into an efficient structure that supports rapid retrieval operations(speed up the search).
|
114 |
+
# cosine similarity is used for telling the hnsw algorithm how to measure the distance between vectors.
|
115 |
+
# by setting space to cosine space, the index will operate using cosine similarity to measure the vectors' similarity.
|
116 |
+
vector_store = Chroma.from_texts(texts, embeddings, collection_metadata = {"hnsw:space":"cosine"}, persist_directory="stores/story_cosine" )
|
117 |
+
print("vector store created........................")
|
118 |
+
|
119 |
+
load_vector_store = Chroma(persist_directory="stores/story_cosine", embedding_function=embeddings)
|
120 |
+
# persist_directory="stores/story_cosine": load the existing vector store from "stores/story_cosine"
|
121 |
+
# embedding_function=embeddings: use the bge embedding model when adding new data to the vector store
|
122 |
+
# Only get the 3 most similar document from the dataset
|
123 |
+
retriever = load_vector_store.as_retriever(search_kwargs={"k":top_k})
|
124 |
+
|
125 |
image_caption, gender, age, emotion = get_image_info(image)
|
126 |
print("............................................")
|
127 |
print("image_caption:", image_caption)
|
|
|
206 |
gr.Image(sources=["upload", "webcam"], label="Upload Image", type="pil"),
|
207 |
|
208 |
gr.Slider(
|
209 |
+
label="temperature",
|
210 |
value=0.9,
|
211 |
minimum=0.0,
|
212 |
maximum=1.0,
|
|
|
216 |
),
|
217 |
|
218 |
gr.Slider(
|
219 |
+
label="max new tokens",
|
220 |
value=1500,
|
221 |
minimum=0,
|
222 |
maximum=3000,
|
|
|
225 |
info="The maximum numbers of new tokens"),
|
226 |
|
227 |
gr.Slider(
|
228 |
+
label="top-p (nucleus sampling)",
|
229 |
value=0.90,
|
230 |
minimum=0.0,
|
231 |
maximum=1,
|
|
|
234 |
info="Higher values sample more low-probability tokens",
|
235 |
),
|
236 |
gr.Slider(
|
237 |
+
label="repetition penalty",
|
238 |
value=1.2,
|
239 |
minimum=1.0,
|
240 |
maximum=2.0,
|
241 |
step=0.05,
|
242 |
interactive=True,
|
243 |
info="Penalize repeated tokens",
|
244 |
+
),
|
245 |
+
gr.Slider(
|
246 |
+
label="chunk_size",
|
247 |
+
value=200,
|
248 |
+
minimum=50,
|
249 |
+
maximum=500,
|
250 |
+
step=1.0,
|
251 |
+
interactive=True,
|
252 |
+
info="Token length of each text chunk",
|
253 |
+
),
|
254 |
+
gr.Slider(
|
255 |
+
label="chunk_overlap",
|
256 |
+
value=20,
|
257 |
+
minimum=0,
|
258 |
+
maximum=50,
|
259 |
+
step=1.0,
|
260 |
+
interactive=True,
|
261 |
+
info="Number of overlapping tokens between chunks",
|
262 |
+
),
|
263 |
+
gr.Slider(
|
264 |
+
label="top-k",
|
265 |
+
value=3,
|
266 |
+
minimum=1,
|
267 |
+
maximum=10,
|
268 |
+
step=1.0,
|
269 |
+
interactive=True,
|
270 |
+
info="Number of top relevant documents to retrieve",
|
271 |
)
|
272 |
],
|
273 |
outputs=[gr.Textbox(label="Generated Story")],
|