Spaces:
Runtime error
Runtime error
ChenyuRabbitLove
committed on
Commit
•
69bac50
1
Parent(s):
a779e10
feat:add transcript db and api
Browse files
- app.py +9 -9
- utils/utils.py +35 -0
app.py
CHANGED
@@ -37,9 +37,7 @@ with gr.Blocks() as demo:
|
|
37 |
)
|
38 |
upload_to_db = gr.CheckboxGroup(
|
39 |
["Upload to Database"],
|
40 |
-
label="是否上傳至資料庫",
|
41 |
-
info="將資料上傳至資料庫時,資料庫會自動建立索引,下次使用時可以直接檢索,預設為僅作這次使用",
|
42 |
-
scale=1,
|
43 |
)
|
44 |
|
45 |
with gr.Row():
|
@@ -62,6 +60,10 @@ with gr.Blocks() as demo:
|
|
62 |
video_text_input = gr.Textbox("", visible=False)
|
63 |
video_text_output = gr.Textbox("", visible=False)
|
64 |
|
|
|
|
|
|
|
|
|
65 |
# end of gradio interface
|
66 |
|
67 |
# start of workflow controller
|
@@ -90,6 +92,7 @@ with gr.Blocks() as demo:
|
|
90 |
**bot_args
|
91 |
).then(lambda: gr.update(interactive=True), None, [user_input], queue=False)
|
92 |
|
|
|
93 |
# defining workflow of clear state
|
94 |
clear_state_args = dict(
|
95 |
fn=clear_state,
|
@@ -126,12 +129,9 @@ with gr.Blocks() as demo:
|
|
126 |
**change_md_args
|
127 |
)
|
128 |
|
129 |
-
video_text_input.submit(
|
130 |
-
|
131 |
-
|
132 |
-
video_text_output,
|
133 |
-
api_name="video_bot",
|
134 |
-
)
|
135 |
|
136 |
if __name__ == "__main__":
|
137 |
demo.launch()
|
|
|
37 |
)
|
38 |
upload_to_db = gr.CheckboxGroup(
|
39 |
["Upload to Database"],
|
40 |
+
label="是否上傳至資料庫", info="將資料上傳至資料庫時,資料庫會自動建立索引,下次使用時可以直接檢索,預設為僅作這次使用", scale=1
|
|
|
|
|
41 |
)
|
42 |
|
43 |
with gr.Row():
|
|
|
60 |
video_text_input = gr.Textbox("", visible=False)
|
61 |
video_text_output = gr.Textbox("", visible=False)
|
62 |
|
63 |
+
transcript_id = gr.Textbox("", visible=False)
|
64 |
+
user_question = gr.Textbox("", visible=False)
|
65 |
+
content_output = gr.Textbox("", visible=False)
|
66 |
+
|
67 |
# end of gradio interface
|
68 |
|
69 |
# start of workflow controller
|
|
|
92 |
**bot_args
|
93 |
).then(lambda: gr.update(interactive=True), None, [user_input], queue=False)
|
94 |
|
95 |
+
|
96 |
# defining workflow of clear state
|
97 |
clear_state_args = dict(
|
98 |
fn=clear_state,
|
|
|
129 |
**change_md_args
|
130 |
)
|
131 |
|
132 |
+
video_text_input.submit(video_bot, [test_video_chabot, video_text_input], video_text_output, api_name="video_bot")
|
133 |
+
transcript_id.submit(search_transcript_content, [transcript_id, user_question], content_output, api_name="search_transcript_content")
|
134 |
+
|
|
|
|
|
|
|
135 |
|
136 |
if __name__ == "__main__":
|
137 |
demo.launch()
|
utils/utils.py
CHANGED
@@ -1,3 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
def clear_state(chatbot, *args):
    """Reset a chatbot by delegating to its own ``clear_state`` method.

    Extra positional arguments are forwarded unchanged, and whatever the
    chatbot's method returns is handed back to the caller.
    """
    reset_result = chatbot.clear_state(*args)
    return reset_result
|
3 |
|
@@ -28,3 +38,28 @@ def bot(chatbot, *args):
|
|
28 |
|
29 |
def video_bot(video_chatbot, *args):
    """Answer a question by delegating to ``video_chatbot.answer_question``.

    Extra positional arguments are forwarded unchanged, and the chatbot's
    return value is passed straight back to the caller.
    """
    answer = video_chatbot.answer_question(*args)
    return answer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import openai
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
+
from openai.embeddings_utils import distances_from_embeddings
|
7 |
+
|
8 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
9 |
+
openai.api_key = OPENAI_API_KEY
|
10 |
+
|
11 |
def clear_state(chatbot, *args):
    """Reset a chatbot by delegating to its own ``clear_state`` method.

    Extra positional arguments are forwarded unchanged, and whatever the
    chatbot's method returns is handed back to the caller.
    """
    reset_result = chatbot.clear_state(*args)
    return reset_result
|
13 |
|
|
|
38 |
|
39 |
def video_bot(video_chatbot, *args):
    """Answer a question by delegating to ``video_chatbot.answer_question``.

    Extra positional arguments are forwarded unchanged, and the chatbot's
    return value is passed straight back to the caller.
    """
    answer = video_chatbot.answer_question(*args)
    return answer
|
41 |
+
|
42 |
+
def search_transcript_content(transcript_id, user_question):
    """Return the stored transcript passage closest to a user question.

    Rows of ``transcript.csv`` whose ``uid`` equals ``transcript_id`` are
    compared with the embedding of ``user_question`` by cosine distance, and
    the ``text`` of the closest row is returned.

    Args:
        transcript_id: Value matched against the ``uid`` column.
        user_question: Question text embedded with ``text-embedding-ada-002``.

    Returns:
        The ``text`` of the best-matching row, or the fallback message
        ``"Sorry, I can't find the content."`` when no row has that ``uid``
        or when the smallest cosine distance is greater than 0.2.
    """
    import ast

    not_found_message = "Sorry, I can't find the content."
    max_distance = 0.2

    transcript_db = pd.read_csv("transcript.csv")
    # NOTE(review): transcript_id presumably arrives as a string; if pandas
    # infers a numeric dtype for "uid" this comparison matches nothing.
    # Confirm against how transcript.csv is written.
    matches = transcript_db[transcript_db["uid"] == transcript_id]

    # Bail out before the embedding request: with no candidate rows there is
    # nothing to rank, and indexing the first distance would raise IndexError.
    if matches.empty:
        return not_found_message

    user_q_emb = openai.Embedding.create(
        input=user_question, engine="text-embedding-ada-002"
    )["data"][0]["embedding"]

    # The embeddings are stored as Python-literal strings. literal_eval parses
    # them without executing arbitrary code, which eval() would do if the CSV
    # were ever tampered with.
    embeddings = [
        np.array(ast.literal_eval(raw)) for raw in matches["embedding"]
    ]

    distances = distances_from_embeddings(
        user_q_emb,
        embeddings,
        distance_metric="cosine",
    )

    # Only the closest row is needed, so take the minimum directly instead of
    # sorting the whole frame.
    best = int(np.argmin(distances))
    if distances[best] > max_distance:
        return not_found_message

    return matches.iloc[best]["text"]
|