rerank model
- RAG/colpali.py +5 -4
- RAG/rag_DocumentSearcher.py +33 -33
- utilities/invoke_models.py +6 -6
- utilities/re_ranker.py +1 -1
RAG/colpali.py
CHANGED
@@ -66,7 +66,7 @@ runtime = boto3.client("sagemaker-runtime",aws_access_key_id=st.secrets['user_ac
 # Prepare your payload (e.g., text-only input)
 
 
-
+@st.cache_resource
 def call_nova(
     model,
     messages,
@@ -110,13 +110,14 @@ def call_nova(
         modelId=model, body=json.dumps(request_body)
     )
     return response["body"]
+@st.cache_resource
 def get_base64_encoded_value(media_path):
     with open(media_path, "rb") as media_file:
         binary_data = media_file.read()
         base_64_encoded_data = base64.b64encode(binary_data)
         base64_string = base_64_encoded_data.decode("utf-8")
         return base64_string
-
+@st.cache_resource
 def generate_ans(top_result,query):
     print(query)
     system_message = "given an image of a PDF page, answer the question. Be accurate to the question. If you don't find the answer in the page, please say, I don't know"
@@ -146,7 +147,7 @@ def generate_ans(top_result,query):
     print(content_text)
     return content_text
 
-
+@st.cache_resource
 def colpali_search_rerank(query):
     # Convert to JSON string
     payload = {
@@ -228,7 +229,7 @@ def colpali_search_rerank(query):
     return {'text':ans,'source':img,'image':images_highlighted,'table':[]}#[{'file':img}]
 
 
-
+@st.cache_resource
 def img_highlight(img,batch_queries,query_tokens):
     # Reference from : https://github.com/tonywu71/colpali-cookbooks/blob/main/examples/gen_colpali_similarity_maps.ipynb
     with open(img, "rb") as f:
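A note on the decorator choice: st.cache_resource is designed for shared, unserializable resources (clients, loaded models), while functions like colpali_search_rerank(query) return per-query data, for which st.cache_data is usually the better fit. A minimal sketch of the distinction, with illustrative names rather than this repo's code, assuming Streamlit >= 1.18:

import streamlit as st

@st.cache_resource
def get_search_client() -> dict:
    # Built once per process and shared across sessions and reruns; meant for
    # unserializable handles such as boto3 clients or loaded models.
    return {"endpoint": "https://example.invalid"}  # hypothetical placeholder

@st.cache_data(ttl=3600)
def cached_search(query: str) -> list[str]:
    # Memoized per distinct query; entries expire after an hour, and the
    # return value must be serializable (a copy is handed to each caller).
    client = get_search_client()
    return [f"hit for {query!r} via {client['endpoint']}"]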
RAG/rag_DocumentSearcher.py
CHANGED
@@ -12,7 +12,7 @@ headers = {"Content-Type": "application/json"}
 host = "https://search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com/"
 
 parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
-
+@st.cache_resource
 def query_(awsauth,inputs, session_id,search_types):
 
     print("using index: "+st.session_state.input_index)
@@ -219,49 +219,49 @@ def query_(awsauth,inputs, session_id,search_types):
     hits = response_['hits']['hits']
 
     ##### GET reference tables separately like *_mm index search for images ######
-    def lazy_get_table():
-        table_ref = []
-        any_table_exists = False
-        for fname in os.listdir(parent_dirname+"/split_pdf_csv"):
-            if fname.startswith(st.session_state.input_index):
-                any_table_exists = True
-                break
-        if(any_table_exists):
-            #################### Basic Match query #################
-            # payload_tables = {
-            #     "query": {
-            #         "bool":{
+    # def lazy_get_table():
+    #     table_ref = []
+    #     any_table_exists = False
+    #     for fname in os.listdir(parent_dirname+"/split_pdf_csv"):
+    #         if fname.startswith(st.session_state.input_index):
+    #             any_table_exists = True
+    #             break
+    #     if(any_table_exists):
+    #         #################### Basic Match query #################
+    #         # payload_tables = {
+    #         #     "query": {
+    #         #         "bool":{
 
-            #     "must":{"match": {
-            #         "processed_element": question
+    #         #     "must":{"match": {
+    #         #         "processed_element": question
 
-            #     }},
+    #         #     }},
 
-            #     "filter":{"term":{"raw_element_type": "table"}}
+    #         #     "filter":{"term":{"raw_element_type": "table"}}
 
 
-            # }}}
+    #         # }}}
 
-            #################### Neural Sparse query #################
-            payload_tables = {"query":{"neural_sparse": {
-                "processed_element_embedding_sparse": {
-                    "query_text": question,
-                    "model_id": "fkol-ZMBTp0efWqBcO2P"
-                }
-            } } }
+    #         #################### Neural Sparse query #################
+    #         payload_tables = {"query":{"neural_sparse": {
+    #             "processed_element_embedding_sparse": {
+    #                 "query_text": question,
+    #                 "model_id": "fkol-ZMBTp0efWqBcO2P"
+    #             }
+    #         } } }
 
 
-            r_ = requests.get(url, auth=awsauth, json=payload_tables, headers=headers)
-            r_tables = json.loads(r_.text)
+    #         r_ = requests.get(url, auth=awsauth, json=payload_tables, headers=headers)
+    #         r_tables = json.loads(r_.text)
 
-            for res_ in r_tables['hits']['hits']:
-                if(res_["_source"]['raw_element_type'] == 'table'):
-                    table_ref.append({'name':res_["_source"]['table'],'text':res_["_source"]['processed_element']})
-                if(len(table_ref) == 2):
-                    break
+    #         for res_ in r_tables['hits']['hits']:
+    #             if(res_["_source"]['raw_element_type'] == 'table'):
+    #                 table_ref.append({'name':res_["_source"]['table'],'text':res_["_source"]['processed_element']})
+    #             if(len(table_ref) == 2):
+    #                 break
 
 
-            return table_ref
+    #         return table_ref
 
 
 ########################### LLM Generation ########################
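A caveat with caching query_: Streamlit builds the cache key from the function's arguments alone, so the st.session_state.input_index read inside the body never invalidates the cache, and switching indexes in the UI can serve results computed for the previous index. Arguments Streamlit cannot hash (the awsauth object, for instance) also need a leading underscore to be excluded from the key. A hedged sketch of one fix, with illustrative names, that passes the index explicitly:

import streamlit as st

@st.cache_data
def cached_query(_awsauth, question: str, session_id: str, input_index: str):
    # The leading underscore tells Streamlit not to hash _awsauth; input_index
    # is part of the cache key, so changing the active index forces a fresh
    # search instead of returning a stale cached hit.
    return run_search(_awsauth, question, session_id, input_index)

def run_search(awsauth, question, session_id, input_index):
    # Hypothetical stand-in for the real OpenSearch request.
    return {"index": input_index, "question": question, "hits": []}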
utilities/invoke_models.py
CHANGED
@@ -11,7 +11,7 @@ import streamlit as st
 #import torch
 
 region = 'us-east-1'
-
+@st.cache_resource
 bedrock_runtime_client = boto3.client(
     'bedrock-runtime',
     aws_access_key_id=st.secrets['user_access_key'],
@@ -30,7 +30,7 @@ bedrock_runtime_client = boto3.client(
 # max_length = 16
 # num_beams = 4
 # gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
+@st.cache_resource
 def invoke_model(input):
     response = bedrock_runtime_client.invoke_model(
         body=json.dumps({
@@ -43,7 +43,7 @@ def invoke_model(input):
 
     response_body = json.loads(response.get("body").read())
     return response_body.get("embedding")
-
+@st.cache_resource
 def invoke_model_mm(text,img):
     body_ = {
         "inputText": text,
@@ -64,7 +64,7 @@ def invoke_model_mm(text,img):
     response_body = json.loads(response.get("body").read())
     #print(response_body)
     return response_body.get("embedding")
-
+@st.cache_resource
 def invoke_llm_model(input,is_stream):
     if(is_stream == False):
         response = bedrock_runtime_client.invoke_model(
@@ -145,7 +145,7 @@ def invoke_llm_model(input,is_stream):
 # stream = response.get('body')
 
 # return stream
-
+@st.cache_resource
 def read_from_table(file,question):
     print("started table analysis:")
     print("-----------------------")
@@ -181,7 +181,7 @@ def read_from_table(file,question):
     )
     agent_res = agent.invoke(question)['output']
     return agent_res
-
+@st.cache_resource
 def generate_image_captions_llm(base64_string,question):
 
     # ant_client = Anthropic()
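As committed, the @st.cache_resource placed above bedrock_runtime_client = boto3.client(...) is a SyntaxError: a decorator may only precede a def or class statement, never an assignment. A sketch of the usual repair, wrapping the client in a cached factory (client kwargs beyond those visible in the hunk are elided):

import boto3
import streamlit as st

region = 'us-east-1'

@st.cache_resource
def get_bedrock_runtime_client():
    # Created once per process and reused on every rerun; a boto3 client is
    # exactly the kind of unserializable resource st.cache_resource targets.
    return boto3.client(
        'bedrock-runtime',
        region_name=region,
        aws_access_key_id=st.secrets['user_access_key'],
        # ...remaining credential kwargs as in the original file...
    )

bedrock_runtime_client = get_bedrock_runtime_client()

The decorators added to invoke_model, invoke_model_mm, and invoke_llm_model will at least parse, but they cache Bedrock responses per argument as shared objects; st.cache_data is the conventional choice for such data-returning calls.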
utilities/re_ranker.py
CHANGED
@@ -46,7 +46,7 @@ from sentence_transformers import CrossEncoder
 # print("Program ends.")
 #########################
 
-
+@st.cache_resource
 def re_rank(self_, rerank_type, search_type, question, answers):
 
     ans = []
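Two things worth verifying here: the visible context shows no import streamlit as st in re_ranker.py, which the new decorator requires, and re_rank's self_ argument may not be hashable by Streamlit's cache. An alternative split, sketched under assumed names (the model id below is a common cross-encoder default, not necessarily the one this repo loads), caches the expensive CrossEncoder load and leaves scoring uncached:

import streamlit as st
from sentence_transformers import CrossEncoder

@st.cache_resource
def get_cross_encoder(model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
    # Loaded once per process; later reruns and other sessions reuse the model.
    return CrossEncoder(model_name)

def rerank_passages(question: str, passages: list[str]) -> list[tuple[str, float]]:
    # Score (question, passage) pairs with the cached model and sort best-first.
    model = get_cross_encoder()
    scores = model.predict([(question, p) for p in passages])
    return sorted(zip(passages, scores), key=lambda pair: pair[1], reverse=True)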