Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
chore: handling user query
Browse files- app.py +27 -14
- utils_demo.py +28 -1
app.py
CHANGED
@@ -7,6 +7,8 @@ from openai import OpenAI
|
|
7 |
import os
|
8 |
import json
|
9 |
import re
|
|
|
|
|
10 |
|
11 |
anonymizer = FHEAnonymizer()
|
12 |
|
@@ -15,6 +17,17 @@ client = OpenAI(
|
|
15 |
)
|
16 |
|
17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def deidentify_text(input_text):
|
19 |
anonymized_text, identified_words_with_prob = anonymizer(input_text)
|
20 |
|
@@ -74,10 +87,6 @@ def query_chatgpt(anonymized_query):
|
|
74 |
return anonymized_response, deanonymized_response
|
75 |
|
76 |
|
77 |
-
# Default demo text from the file
|
78 |
-
with open("demo_text.txt", "r") as file:
|
79 |
-
default_demo_text = file.read()
|
80 |
-
|
81 |
with open("files/original_document.txt", "r") as file:
|
82 |
original_document = file.read()
|
83 |
|
@@ -128,19 +137,23 @@ with demo:
|
|
128 |
# """
|
129 |
# )
|
130 |
|
|
|
131 |
with gr.Row():
|
132 |
-
input_text = gr.Textbox(
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
137 |
)
|
138 |
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
|
145 |
anonymized_text_output = gr.Textbox(label="Anonymized Text with FHE", lines=1, interactive=True)
|
146 |
|
|
|
7 |
import os
|
8 |
import json
|
9 |
import re
|
10 |
+
from utils_demo import *
|
11 |
+
from typing import List, Dict, Tuple
|
12 |
|
13 |
anonymizer = FHEAnonymizer()
|
14 |
|
|
|
17 |
)
|
18 |
|
19 |
|
20 |
+
def check_user_query_fn(user_query: str) -> Dict:
    """
    Validate the text typed in the user query box and normalise it.

    Args:
        user_query (str): Current content of the `input_text` textbox.

    Returns:
        Dict: A mapping from the `input_text` component to a `gr.update(...)` carrying
            either the rejection message or the query with runs of spaces collapsed.
    """
    # NOTE: despite its name, `is_user_query_valid` returns True for queries that
    # must be rejected (not an example query and not within the length limit).
    must_reject = is_user_query_valid(user_query)

    if not must_reject:
        # Collapse every run of consecutive spaces into a single space
        collapsed_query = re.sub(" +", " ", user_query)
        return {input_text: gr.update(value=collapsed_query)}

    # TODO: check if the query is related to our context
    # NOTE(review): the "β" below looks like a mis-encoded symbol — confirm the intended character
    rejection_notice = (
        "Unable to process β: The request exceeds the length limit or falls "
        "outside the scope of this document. Please refine your query."
    )
    print(rejection_notice)
    return {input_text: gr.update(value=rejection_notice)}
|
30 |
+
|
31 |
def deidentify_text(input_text):
|
32 |
anonymized_text, identified_words_with_prob = anonymizer(input_text)
|
33 |
|
|
|
87 |
return anonymized_response, deanonymized_response
|
88 |
|
89 |
|
|
|
|
|
|
|
|
|
90 |
with open("files/original_document.txt", "r") as file:
|
91 |
original_document = file.read()
|
92 |
|
|
|
137 |
# """
|
138 |
# )
|
139 |
|
140 |
+
########################## User Query Part ##########################
|
141 |
with gr.Row():
|
142 |
+
input_text = gr.Textbox(value="Who lives in Maine?", label="User query", interactive=True)
|
143 |
+
|
144 |
+
default_query_box = gr.Radio(choices=list(DEFAULT_QUERIES.keys()), label="Example Queries")
|
145 |
+
|
146 |
+
default_query_box.change(
|
147 |
+
fn=lambda default_query_box: DEFAULT_QUERIES[default_query_box],
|
148 |
+
inputs=[default_query_box],
|
149 |
+
outputs=[input_text]
|
150 |
)
|
151 |
|
152 |
+
input_text.change(
|
153 |
+
check_user_query_fn,
|
154 |
+
inputs=[input_text],
|
155 |
+
outputs=[input_text],
|
156 |
+
)
|
157 |
|
158 |
anonymized_text_output = gr.Textbox(label="Anonymized Text with FHE", lines=1, interactive=True)
|
159 |
|
utils_demo.py
CHANGED
@@ -1,6 +1,15 @@
|
|
1 |
import torch
|
2 |
import numpy as np
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
|
6 |
"""
|
@@ -20,3 +29,21 @@ def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
|
|
20 |
mean_pooled_batch.extend(mean_pooled.cpu().detach().numpy())
|
21 |
return np.array(mean_pooled_batch)
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
# Length limit that `is_user_query_valid` compares against `len(user_query)`
MAX_USER_QUERY_LEN = 35

# List of example queries for easy access.
# Keys are the labels offered as choices in the UI; values are the query texts.
# `is_user_query_valid` matches queries against these values, so an example query
# is accepted even when it is longer than `MAX_USER_QUERY_LEN`.
DEFAULT_QUERIES = {
    "Example Query 1": "Who visited microsoft.com on September 18?",
    "Example Query 2": "Does Kate has drive ?",
    "Example Query 3": "What phone number can be used to contact David Johnson?",
}
|
13 |
|
14 |
def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
|
15 |
"""
|
|
|
29 |
mean_pooled_batch.extend(mean_pooled.cpu().detach().numpy())
|
30 |
return np.array(mean_pooled_batch)
|
31 |
|
32 |
+
|
33 |
+
def is_user_query_valid(user_query: str) -> bool:
    """
    Tell whether `user_query` must be rejected.

    NOTE: despite its name, this function returns True when the query is NOT
    acceptable. The return convention is deliberately left unchanged because
    callers branch on a True result to display an error message.

    A query is acceptable when it is one of the `DEFAULT_QUERIES` values, or when
    it is not None and its length does not exceed `MAX_USER_QUERY_LEN`.

    Args:
        user_query (str): The input text to be checked (None is tolerated).

    Returns:
        bool: True if the query is neither an example query nor within the length
            limit (this includes None), False otherwise.
    """
    # Example queries are always accepted, whatever their length
    is_default_query = user_query in DEFAULT_QUERIES.values()

    # None has no length, so it is never considered to be within the limit
    is_within_max_length = user_query is not None and len(user_query) <= MAX_USER_QUERY_LEN

    return not is_default_query and not is_within_max_length
|
48 |
+
|
49 |
+
|