Spaces:
				
			
			
	
			
			
					
		Running
		
			on 
			
			CPU Upgrade
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
			on 
			
			CPU Upgrade
	rerank model
Browse files- RAG/bedrock_agent.py +1 -44
- RAG/rag_DocumentSearcher.py +0 -17
- app.py +0 -24
- pages/AI_Shopping_Assistant.py +7 -404
- pages/Semantic_Search.py +19 -342
- semantic_search/amazon_rekognition.py +2 -47
- utilities/invoke_models.py +2 -84
- utilities/re_ranker.py +0 -127
    	
        RAG/bedrock_agent.py
    CHANGED
    
    | @@ -23,8 +23,6 @@ if "inputs_" not in st.session_state: | |
| 23 |  | 
| 24 | 
             
            parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
         | 
| 25 | 
             
            region = 'us-east-1'
         | 
| 26 | 
            -
            print(region)
         | 
| 27 | 
            -
            account_id = '445083327804'
         | 
| 28 | 
             
            # setting logger
         | 
| 29 | 
             
            logging.basicConfig(format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', level=logging.INFO)
         | 
| 30 | 
             
            logger = logging.getLogger(__name__)
         | 
| @@ -46,9 +44,6 @@ def delete_memory(): | |
| 46 | 
             
                )
         | 
| 47 |  | 
| 48 | 
             
            def query_(inputs):
         | 
| 49 | 
            -
                ## create a random id for session initiator id
         | 
| 50 | 
            -
                
         | 
| 51 | 
            -
                
         | 
| 52 | 
             
                # invoke the agent API
         | 
| 53 | 
             
                agentResponse = bedrock_agent_runtime_client.invoke_agent(
         | 
| 54 | 
             
                    inputText=inputs['shopping_query'],
         | 
| @@ -71,13 +66,6 @@ def query_(inputs): | |
| 71 | 
             
                    for event in event_stream:
         | 
| 72 | 
             
                        print("***event*********")
         | 
| 73 | 
             
                        print(event)
         | 
| 74 | 
            -
                        # if 'chunk' in event:
         | 
| 75 | 
            -
                        #     data = event['chunk']['bytes']
         | 
| 76 | 
            -
                        #     print("***chunk*********")
         | 
| 77 | 
            -
                        #     print(data)
         | 
| 78 | 
            -
                        #     logger.info(f"Final answer ->\n{data.decode('utf8')}")
         | 
| 79 | 
            -
                        #     agent_answer_ = data.decode('utf8')
         | 
| 80 | 
            -
                        #     print(agent_answer_)
         | 
| 81 | 
             
                        if 'trace' in event: 
         | 
| 82 | 
             
                            print("trace*****total*********")
         | 
| 83 | 
             
                            print(event['trace'])
         | 
| @@ -109,38 +97,7 @@ def query_(inputs): | |
| 109 | 
             
                    print(total_context)    
         | 
| 110 | 
             
                except botocore.exceptions.EventStreamError as error:
         | 
| 111 | 
             
                    raise error
         | 
| 112 | 
            -
                     | 
| 113 | 
            -
                    # query_(st.session_state.inputs_)     
         | 
| 114 | 
            -
                            
         | 
| 115 | 
            -
                        # if 'chunk' in event:
         | 
| 116 | 
            -
                        #     data = event['chunk']['bytes']
         | 
| 117 | 
            -
                        #     final_ans = data.decode('utf8')
         | 
| 118 | 
            -
                        #     print(f"Final answer ->\n{final_ans}")
         | 
| 119 | 
            -
                        #     logger.info(f"Final answer ->\n{final_ans}")
         | 
| 120 | 
            -
                        #     agent_answer = final_ans
         | 
| 121 | 
            -
                        #     end_event_received = True
         | 
| 122 | 
            -
                        #     # End event indicates that the request finished successfully
         | 
| 123 | 
            -
                        # elif 'trace' in event:
         | 
| 124 | 
            -
                        #     logger.info(json.dumps(event['trace'], indent=2))
         | 
| 125 | 
            -
                        # else:
         | 
| 126 | 
            -
                        #     raise Exception("unexpected event.", event)
         | 
| 127 | 
            -
                # except Exception as e:
         | 
| 128 | 
            -
                #     raise Exception("unexpected event.", e)
         | 
| 129 | 
             
                return {'text':agent_answer,'source':total_context,'last_tool':{'name':last_tool_name,'response':last_tool}}
         | 
| 130 |  | 
| 131 | 
            -
                    ####### Re-Rank ########
         | 
| 132 | 
            -
                
         | 
| 133 | 
            -
                #print("re-rank")
         | 
| 134 | 
            -
                
         | 
| 135 | 
            -
                # if(st.session_state.input_is_rerank == True and len(total_context)):
         | 
| 136 | 
            -
                #     ques = [{"question":question}]
         | 
| 137 | 
            -
                #     ans = [{"answer":total_context}]
         | 
| 138 | 
            -
                    
         | 
| 139 | 
            -
                #     total_context = re_ranker.re_rank('rag','Cross Encoder',"",ques, ans)
         | 
| 140 |  | 
| 141 | 
            -
                # llm_prompt = prompt_template.format(context=total_context[0],question=question)
         | 
| 142 | 
            -
                # output = invoke_models.invoke_llm_model( "\n\nHuman: {input}\n\nAssistant:".format(input=llm_prompt) ,False)
         | 
| 143 | 
            -
                # #print(output)
         | 
| 144 | 
            -
                # if(len(images_2)==0):
         | 
| 145 | 
            -
                #     images_2 = images
         | 
| 146 | 
            -
                # return {'text':output,'source':total_context,'image':images_2,'table':df}
         | 
|  | |
| 23 |  | 
| 24 | 
             
            parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
         | 
| 25 | 
             
            region = 'us-east-1'
         | 
|  | |
|  | |
| 26 | 
             
            # setting logger
         | 
| 27 | 
             
            logging.basicConfig(format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', level=logging.INFO)
         | 
| 28 | 
             
            logger = logging.getLogger(__name__)
         | 
|  | |
| 44 | 
             
                )
         | 
| 45 |  | 
| 46 | 
             
            def query_(inputs):
         | 
|  | |
|  | |
|  | |
| 47 | 
             
                # invoke the agent API
         | 
| 48 | 
             
                agentResponse = bedrock_agent_runtime_client.invoke_agent(
         | 
| 49 | 
             
                    inputText=inputs['shopping_query'],
         | 
|  | |
| 66 | 
             
                    for event in event_stream:
         | 
| 67 | 
             
                        print("***event*********")
         | 
| 68 | 
             
                        print(event)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 69 | 
             
                        if 'trace' in event: 
         | 
| 70 | 
             
                            print("trace*****total*********")
         | 
| 71 | 
             
                            print(event['trace'])
         | 
|  | |
| 97 | 
             
                    print(total_context)    
         | 
| 98 | 
             
                except botocore.exceptions.EventStreamError as error:
         | 
| 99 | 
             
                    raise error
         | 
| 100 | 
            +
                    
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 101 | 
             
                return {'text':agent_answer,'source':total_context,'last_tool':{'name':last_tool_name,'response':last_tool}}
         | 
| 102 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 103 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        RAG/rag_DocumentSearcher.py
    CHANGED
    
    | @@ -49,7 +49,6 @@ def query_(awsauth,inputs, session_id,search_types): | |
| 49 | 
             
                images = []
         | 
| 50 |  | 
| 51 | 
             
                for hit in hits:
         | 
| 52 | 
            -
                    #context.append(hit['_source']['caption'])
         | 
| 53 | 
             
                    images.append({'file':hit['_source']['image'],'caption':hit['_source']['processed_element']})
         | 
| 54 |  | 
| 55 | 
             
                ####### SEARCH ########
         | 
| @@ -102,10 +101,6 @@ def query_(awsauth,inputs, session_id,search_types): | |
| 102 | 
             
                                }
         | 
| 103 | 
             
                                ]
         | 
| 104 |  | 
| 105 | 
            -
                    
         | 
| 106 | 
            -
                    
         | 
| 107 | 
            -
                    
         | 
| 108 | 
            -
                
         | 
| 109 | 
             
                SIZE = 5
         | 
| 110 |  | 
| 111 | 
             
                hybrid_payload = {
         | 
| @@ -159,7 +154,6 @@ def query_(awsauth,inputs, session_id,search_types): | |
| 159 |  | 
| 160 | 
             
                if('Sparse Search' in search_types):
         | 
| 161 |  | 
| 162 | 
            -
                    #print("text expansion is enabled")
         | 
| 163 | 
             
                    sparse_payload =  {  "neural_sparse": {
         | 
| 164 | 
             
                            "processed_element_embedding_sparse": {
         | 
| 165 | 
             
                                "query_text": question,
         | 
| @@ -301,7 +295,6 @@ def query_(awsauth,inputs, session_id,search_types): | |
| 301 | 
             
                        images_2.append({'file':hit["_source"]["image"],'caption':hit["_source"]["processed_element"]})
         | 
| 302 |  | 
| 303 | 
             
                    idx = idx +1
         | 
| 304 | 
            -
                    #images.append(hit['_source']['image'])
         | 
| 305 |  | 
| 306 | 
             
                # if(is_table_in_result == False):
         | 
| 307 | 
             
                #     df = lazy_get_table()
         | 
| @@ -315,19 +308,9 @@ def query_(awsauth,inputs, session_id,search_types): | |
| 315 |  | 
| 316 | 
             
                total_context = context_tables + context
         | 
| 317 |  | 
| 318 | 
            -
                ####### Re-Rank ########
         | 
| 319 | 
            -
                
         | 
| 320 | 
            -
                #print("re-rank")
         | 
| 321 | 
            -
                
         | 
| 322 | 
            -
                # if(st.session_state.input_is_rerank == True and len(total_context)):
         | 
| 323 | 
            -
                #     ques = [{"question":question}]
         | 
| 324 | 
            -
                #     ans = [{"answer":total_context}]
         | 
| 325 | 
            -
                    
         | 
| 326 | 
            -
                #     total_context = re_ranker.re_rank('rag','Cross Encoder',"",ques, ans)
         | 
| 327 |  | 
| 328 | 
             
                llm_prompt = prompt_template.format(context=total_context[0],question=question)
         | 
| 329 | 
             
                output = invoke_models.invoke_llm_model( "\n\nHuman: {input}\n\nAssistant:".format(input=llm_prompt) ,False)
         | 
| 330 | 
            -
                #print(output)
         | 
| 331 | 
             
                if(len(images_2)==0):
         | 
| 332 | 
             
                    images_2 = images
         | 
| 333 | 
             
                return {'text':output,'source':total_context,'image':images_2,'table':df}
         | 
|  | |
| 49 | 
             
                images = []
         | 
| 50 |  | 
| 51 | 
             
                for hit in hits:
         | 
|  | |
| 52 | 
             
                    images.append({'file':hit['_source']['image'],'caption':hit['_source']['processed_element']})
         | 
| 53 |  | 
| 54 | 
             
                ####### SEARCH ########
         | 
|  | |
| 101 | 
             
                                }
         | 
| 102 | 
             
                                ]
         | 
| 103 |  | 
|  | |
|  | |
|  | |
|  | |
| 104 | 
             
                SIZE = 5
         | 
| 105 |  | 
| 106 | 
             
                hybrid_payload = {
         | 
|  | |
| 154 |  | 
| 155 | 
             
                if('Sparse Search' in search_types):
         | 
| 156 |  | 
|  | |
| 157 | 
             
                    sparse_payload =  {  "neural_sparse": {
         | 
| 158 | 
             
                            "processed_element_embedding_sparse": {
         | 
| 159 | 
             
                                "query_text": question,
         | 
|  | |
| 295 | 
             
                        images_2.append({'file':hit["_source"]["image"],'caption':hit["_source"]["processed_element"]})
         | 
| 296 |  | 
| 297 | 
             
                    idx = idx +1
         | 
|  | |
| 298 |  | 
| 299 | 
             
                # if(is_table_in_result == False):
         | 
| 300 | 
             
                #     df = lazy_get_table()
         | 
|  | |
| 308 |  | 
| 309 | 
             
                total_context = context_tables + context
         | 
| 310 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 311 |  | 
| 312 | 
             
                llm_prompt = prompt_template.format(context=total_context[0],question=question)
         | 
| 313 | 
             
                output = invoke_models.invoke_llm_model( "\n\nHuman: {input}\n\nAssistant:".format(input=llm_prompt) ,False)
         | 
|  | |
| 314 | 
             
                if(len(images_2)==0):
         | 
| 315 | 
             
                    images_2 = images
         | 
| 316 | 
             
                return {'text':output,'source':total_context,'image':images_2,'table':df}
         | 
    	
        app.py
    CHANGED
    
    | @@ -152,28 +152,6 @@ spacer_col = st.columns(1)[0] | |
| 152 | 
             
            with spacer_col:
         | 
| 153 | 
             
                st.markdown("<div style='height: 120px;'></div>", unsafe_allow_html=True)
         | 
| 154 |  | 
| 155 | 
            -
                #st.image("/home/ubuntu/images/OS_AI_1.png", use_column_width=True)
         | 
| 156 | 
            -
            # with col_title:
         | 
| 157 | 
            -
            #     st.write("")
         | 
| 158 | 
            -
            #     st.markdown('<div class="title">OpenSearch AI demos</div>', unsafe_allow_html=True)
         | 
| 159 | 
            -
             | 
| 160 | 
            -
            # def demo_link_block(icon, title, target_page):
         | 
| 161 | 
            -
            #     st.markdown(f"""
         | 
| 162 | 
            -
            #         <a href="/{target_page}" target="_self" style="text-decoration: none;">
         | 
| 163 | 
            -
            #             <div class="demo-card">
         | 
| 164 | 
            -
            #                 <div class="demo-text">
         | 
| 165 | 
            -
            #                     <span>{icon} {title}</span>
         | 
| 166 | 
            -
            #                     <span class="demo-arrow">→</span>
         | 
| 167 | 
            -
            #                 </div>
         | 
| 168 | 
            -
            #             </div>
         | 
| 169 | 
            -
            #         </a>
         | 
| 170 | 
            -
            #     """, unsafe_allow_html=True)
         | 
| 171 | 
            -
             | 
| 172 | 
            -
             | 
| 173 | 
            -
            # st.write("")
         | 
| 174 | 
            -
            # demo_link_block("🔍", "AI Search", "Semantic_Search")
         | 
| 175 | 
            -
            # demo_link_block("💬","Multimodal Conversational Search", "Multimodal_Conversational_Search")
         | 
| 176 | 
            -
            # demo_link_block("🛍️","Agentic Shopping Assistant", "AI_Shopping_Assistant")
         | 
| 177 |  | 
| 178 |  | 
| 179 | 
             
            col1, col2, col3 = st.columns(3)
         | 
| @@ -225,5 +203,3 @@ st.markdown(""" | |
| 225 | 
             
                </style>
         | 
| 226 | 
             
            """, unsafe_allow_html=True)
         | 
| 227 |  | 
| 228 | 
            -
            #    <div class="card-arrow"></div>
         | 
| 229 | 
            -
                        
         | 
|  | |
| 152 | 
             
            with spacer_col:
         | 
| 153 | 
             
                st.markdown("<div style='height: 120px;'></div>", unsafe_allow_html=True)
         | 
| 154 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 155 |  | 
| 156 |  | 
| 157 | 
             
            col1, col2, col3 = st.columns(3)
         | 
|  | |
| 203 | 
             
                </style>
         | 
| 204 | 
             
            """, unsafe_allow_html=True)
         | 
| 205 |  | 
|  | |
|  | 
    	
        pages/AI_Shopping_Assistant.py
    CHANGED
    
    | @@ -33,12 +33,7 @@ import bedrock_agent | |
| 33 | 
             
            import warnings
         | 
| 34 |  | 
| 35 | 
             
            warnings.filterwarnings("ignore", category=DeprecationWarning)
         | 
| 36 | 
            -
             | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
             
            st.set_page_config(
         | 
| 41 | 
            -
                #page_title="Semantic Search using OpenSearch",
         | 
| 42 | 
             
                layout="wide",
         | 
| 43 | 
             
                page_icon="images/opensearch_mark_default.png"
         | 
| 44 | 
             
            )
         | 
| @@ -47,15 +42,14 @@ USER_ICON = "images/user.png" | |
| 47 | 
             
            AI_ICON = "images/opensearch-twitter-card.png"
         | 
| 48 | 
             
            REGENERATE_ICON = "images/regenerate.png"
         | 
| 49 | 
             
            s3_bucket_ = "pdf-repo-uploads"
         | 
| 50 | 
            -
                         | 
| 51 | 
             
            polly_client = boto3.Session(
         | 
| 52 | 
             
                        region_name='us-east-1').client('polly')
         | 
| 53 |  | 
| 54 | 
             
            # Check if the user ID is already stored in the session state
         | 
| 55 | 
             
            if 'user_id' in st.session_state:
         | 
| 56 | 
             
                user_id = st.session_state['user_id']
         | 
| 57 | 
            -
                 | 
| 58 | 
            -
             | 
| 59 | 
             
            # If the user ID is not yet stored in the session state, generate a random UUID
         | 
| 60 | 
             
            else:
         | 
| 61 | 
             
                user_id = str(uuid.uuid4())
         | 
| @@ -79,9 +73,6 @@ if "questions__" not in st.session_state: | |
| 79 |  | 
| 80 | 
             
            if "answers__" not in st.session_state:
         | 
| 81 | 
             
                st.session_state.answers__ = []
         | 
| 82 | 
            -
             | 
| 83 | 
            -
            if "input_index" not in st.session_state:
         | 
| 84 | 
            -
                st.session_state.input_index = "hpijan2024hometrack"#"globalwarmingnew"#"hpijan2024hometrack_no_img_no_table"
         | 
| 85 |  | 
| 86 | 
             
            if "input_is_rerank" not in st.session_state:
         | 
| 87 | 
             
                st.session_state.input_is_rerank = True
         | 
| @@ -92,22 +83,17 @@ if "input_copali_rerank" not in st.session_state: | |
| 92 | 
             
            if "input_table_with_sql" not in st.session_state:
         | 
| 93 | 
             
                st.session_state.input_table_with_sql = False
         | 
| 94 |  | 
| 95 | 
            -
                
         | 
| 96 | 
             
            if "inputs_" not in st.session_state:
         | 
| 97 | 
             
                st.session_state.inputs_ = {}
         | 
| 98 |  | 
| 99 | 
             
            if "input_shopping_query" not in st.session_state:
         | 
| 100 | 
            -
                st.session_state.input_shopping_query="get me shoes suitable for trekking" | 
| 101 |  | 
| 102 |  | 
| 103 | 
             
            if "input_rag_searchType" not in st.session_state:
         | 
| 104 | 
             
                st.session_state.input_rag_searchType = ["Sparse Search"]
         | 
| 105 |  | 
| 106 | 
            -
             | 
| 107 | 
            -
             | 
| 108 | 
            -
                    
         | 
| 109 | 
             
            region = 'us-east-1'
         | 
| 110 | 
            -
            #bedrock_runtime_client = boto3.client('bedrock-runtime',region_name=region)
         | 
| 111 | 
             
            output = []
         | 
| 112 | 
             
            service = 'es'
         | 
| 113 |  | 
| @@ -122,48 +108,6 @@ st.markdown(""" | |
| 122 | 
             
                </style>
         | 
| 123 | 
             
                """,unsafe_allow_html=True)
         | 
| 124 |  | 
| 125 | 
            -
            ################ OpenSearch Py client #####################
         | 
| 126 | 
            -
                
         | 
| 127 | 
            -
            # credentials = boto3.Session().get_credentials()
         | 
| 128 | 
            -
            # awsauth = AWSV4SignerAuth(credentials, region, service)
         | 
| 129 | 
            -
             | 
| 130 | 
            -
            # ospy_client = OpenSearch(
         | 
| 131 | 
            -
            #     hosts = [{'host': 'search-opensearchservi-75ucark0bqob-bzk6r6h2t33dlnpgx2pdeg22gi.us-east-1.es.amazonaws.com', 'port': 443}],
         | 
| 132 | 
            -
            #     http_auth = awsauth,
         | 
| 133 | 
            -
            #     use_ssl = True,
         | 
| 134 | 
            -
            #     verify_certs = True,
         | 
| 135 | 
            -
            #     connection_class = RequestsHttpConnection,
         | 
| 136 | 
            -
            #     pool_maxsize = 20
         | 
| 137 | 
            -
            # )
         | 
| 138 | 
            -
             | 
| 139 | 
            -
            ################# using boto3 credentials ###################
         | 
| 140 | 
            -
             | 
| 141 | 
            -
             | 
| 142 | 
            -
            # credentials = boto3.Session().get_credentials()
         | 
| 143 | 
            -
            # awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, region, service, session_token=credentials.token)
         | 
| 144 | 
            -
            # service = 'es'
         | 
| 145 | 
            -
             | 
| 146 | 
            -
             | 
| 147 | 
            -
            ################# using boto3 credentials ####################
         | 
| 148 | 
            -
             | 
| 149 | 
            -
             | 
| 150 | 
            -
             | 
| 151 | 
            -
            # if "input_searchType" not in st.session_state:
         | 
| 152 | 
            -
            #     st.session_state.input_searchType = "Conversational Search (RAG)"
         | 
| 153 | 
            -
             | 
| 154 | 
            -
            # if "input_temperature" not in st.session_state:
         | 
| 155 | 
            -
            #     st.session_state.input_temperature = "0.001"
         | 
| 156 | 
            -
             | 
| 157 | 
            -
            # if "input_topK" not in st.session_state:
         | 
| 158 | 
            -
            #     st.session_state.input_topK = 200
         | 
| 159 | 
            -
             | 
| 160 | 
            -
            # if "input_topP" not in st.session_state:
         | 
| 161 | 
            -
            #     st.session_state.input_topP = 0.95
         | 
| 162 | 
            -
             | 
| 163 | 
            -
            # if "input_maxTokens" not in st.session_state:
         | 
| 164 | 
            -
            #     st.session_state.input_maxTokens = 1024
         | 
| 165 | 
            -
             | 
| 166 | 
            -
             | 
| 167 | 
             
            def write_logo():
         | 
| 168 | 
             
                col1, col2, col3 = st.columns([5, 1, 5])
         | 
| 169 | 
             
                with col2:
         | 
| @@ -175,8 +119,6 @@ def write_top_bar(): | |
| 175 | 
             
                    st.page_link("app.py", label=":orange[Home]", icon="🏠")
         | 
| 176 | 
             
                    st.header("AI Shopping assistant",divider='rainbow')
         | 
| 177 |  | 
| 178 | 
            -
                    #st.image(AI_ICON, use_column_width='always')
         | 
| 179 | 
            -
                
         | 
| 180 | 
             
                with col2:
         | 
| 181 | 
             
                    st.write("")
         | 
| 182 | 
             
                    st.write("")
         | 
| @@ -193,17 +135,10 @@ if clear: | |
| 193 | 
             
                st.session_state.input_shopping_query=""
         | 
| 194 | 
             
                st.session_state.session_id_ = str(uuid.uuid1())
         | 
| 195 | 
             
                bedrock_agent.delete_memory()
         | 
| 196 | 
            -
             | 
| 197 | 
            -
                # st.session_state.input_temperature = "0.001"
         | 
| 198 | 
            -
                # st.session_state.input_topK = 200
         | 
| 199 | 
            -
                # st.session_state.input_topP = 0.95
         | 
| 200 | 
            -
                # st.session_state.input_maxTokens = 1024
         | 
| 201 |  | 
| 202 |  | 
| 203 | 
             
            def handle_input():
         | 
| 204 | 
            -
                print("Question: "+st.session_state.input_shopping_query)
         | 
| 205 | 
            -
                print("-----------")
         | 
| 206 | 
            -
                print("\n\n")
         | 
| 207 | 
             
                if(st.session_state.input_shopping_query==''):
         | 
| 208 | 
             
                    return ""
         | 
| 209 | 
             
                inputs = {}
         | 
| @@ -212,10 +147,6 @@ def handle_input(): | |
| 212 | 
             
                        inputs[key.removeprefix('input_')] = st.session_state[key]
         | 
| 213 | 
             
                st.session_state.inputs_ = inputs
         | 
| 214 |  | 
| 215 | 
            -
                #######
         | 
| 216 | 
            -
                
         | 
| 217 | 
            -
                
         | 
| 218 | 
            -
                #st.write(inputs) 
         | 
| 219 | 
             
                question_with_id = {
         | 
| 220 | 
             
                    'question': inputs["shopping_query"],
         | 
| 221 | 
             
                    'id': len(st.session_state.questions__)
         | 
| @@ -234,30 +165,6 @@ def handle_input(): | |
| 234 | 
             
                st.session_state.input_shopping_query=""
         | 
| 235 |  | 
| 236 |  | 
| 237 | 
            -
                
         | 
| 238 | 
            -
            # search_type = st.selectbox('Select the Search type',
         | 
| 239 | 
            -
            #     ('Conversational Search (RAG)',
         | 
| 240 | 
            -
            #     'OpenSearch vector search', 
         | 
| 241 | 
            -
            #     'LLM Text Generation'
         | 
| 242 | 
            -
            #     ),
         | 
| 243 | 
            -
               
         | 
| 244 | 
            -
            #     key = 'input_searchType',
         | 
| 245 | 
            -
            #     help = "Select the type of retriever\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers_)"
         | 
| 246 | 
            -
            #     )
         | 
| 247 | 
            -
             | 
| 248 | 
            -
            # col1, col2, col3, col4 = st.columns(4)
         | 
| 249 | 
            -
                
         | 
| 250 | 
            -
            # with col1:
         | 
| 251 | 
            -
            #     st.text_input('Temperature', value = "0.001", placeholder='LLM Temperature', key = 'input_temperature',help = "Set the temperature of the Large Language model. \n Note: 1. Set this to values lower to 1 in the order of 0.001, 0.0001, such low values reduces hallucination and creativity in the LLM response; 2. This applies only when LLM is a part of the retriever pipeline")
         | 
| 252 | 
            -
            # with col2:
         | 
| 253 | 
            -
            #     st.number_input('Top K', value = 200, placeholder='Top K', key = 'input_topK', step = 50, help = "This limits the LLM's predictions to the top k most probable tokens at each step of generation, this applies only when LLM is a prt of the retriever pipeline")
         | 
| 254 | 
            -
            # with col3:
         | 
| 255 | 
            -
            #     st.number_input('Top P', value = 0.95, placeholder='Top P', key = 'input_topP', step = 0.05, help = "This sets a threshold probability and selects the top tokens whose cumulative probability exceeds the threshold while the tokens are generated by the LLM")
         | 
| 256 | 
            -
            # with col4:
         | 
| 257 | 
            -
            #     st.number_input('Max Output Tokens', value = 500, placeholder='Max Output Tokens', key = 'input_maxTokens', step = 100, help = "This decides the total number of tokens generated as the final response. Note: Values greater than 1000 takes longer response time")
         | 
| 258 | 
            -
             | 
| 259 | 
            -
            # st.markdown('---')
         | 
| 260 | 
            -
             | 
| 261 |  | 
| 262 | 
             
            def write_user_message(md):
         | 
| 263 | 
             
                col1, col2 = st.columns([3,97])
         | 
| @@ -265,8 +172,6 @@ def write_user_message(md): | |
| 265 | 
             
                with col1:
         | 
| 266 | 
             
                    st.image(USER_ICON, use_column_width='always')
         | 
| 267 | 
             
                with col2:
         | 
| 268 | 
            -
                    #st.warning(md['question'])
         | 
| 269 | 
            -
             | 
| 270 | 
             
                    st.markdown("<div style='color:#e28743';font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;'>"+md['question']+"</div>", unsafe_allow_html = True)
         | 
| 271 |  | 
| 272 |  | 
| @@ -283,18 +188,9 @@ def render_answer(question,answer,index): | |
| 283 | 
             
                    ans_ = answer['answer']
         | 
| 284 | 
             
                    span_ans = ans_.replace('<question>',"<span style='fontSize:18px;color:#f37709;fontStyle:italic;'>").replace("</question>","</span>")
         | 
| 285 | 
             
                    st.markdown("<p>"+span_ans+"</p>",unsafe_allow_html = True)
         | 
| 286 | 
            -
                    print("answer['source']")
         | 
| 287 | 
            -
                    print("-------------")
         | 
| 288 | 
            -
                    print(answer['source'])
         | 
| 289 | 
            -
                    print("-------------")
         | 
| 290 | 
            -
                    print(answer['last_tool'])
         | 
| 291 | 
             
                    if(answer['last_tool']['name'] in ["generate_images","get_relevant_items_for_image","get_relevant_items_for_text","retrieve_with_hybrid_search","retrieve_with_keyword_search","get_any_general_recommendation"]):
         | 
| 292 | 
             
                        use_interim_results = True
         | 
| 293 | 
             
                        src_dict =json.loads(answer['last_tool']['response'].replace("'",'"'))
         | 
| 294 | 
            -
                    print("src_dict")
         | 
| 295 | 
            -
                    print("-------------")
         | 
| 296 | 
            -
                    print(src_dict)
         | 
| 297 | 
            -
                    #if("get_relevant_items_for_text" in src_dict):
         | 
| 298 | 
             
                    if(use_interim_results and answer['last_tool']['name']!= 'generate_images' and answer['last_tool']['name']!= 'get_any_general_recommendation'):
         | 
| 299 | 
             
                        key_ = answer['last_tool']['name']
         | 
| 300 |  | 
| @@ -310,9 +206,7 @@ def render_answer(question,answer,index): | |
| 310 | 
             
                            if(index ==1):
         | 
| 311 | 
             
                                with img_col2:
         | 
| 312 | 
             
                                    st.image(resizedImg,use_column_width = True,caption = item['title'])
         | 
| 313 | 
            -
                                     | 
| 314 | 
            -
                            
         | 
| 315 | 
            -
                            
         | 
| 316 | 
             
                    if(answer['last_tool']['name'] == "generate_images" or answer['last_tool']['name'] == "get_any_general_recommendation"):   
         | 
| 317 | 
             
                        st.write("<br>",unsafe_allow_html = True)
         | 
| 318 | 
             
                        gen_img_col1, gen_img_col2,gen_img_col2 = st.columns([30,30,30])
         | 
| @@ -328,143 +222,17 @@ def render_answer(question,answer,index): | |
| 328 | 
             
                        with gen_img_col1:
         | 
| 329 | 
             
                            st.image(resizedImg,caption = "Generated image for "+key.split(".")[0],use_column_width = True)
         | 
| 330 | 
             
                        st.write("<br>",unsafe_allow_html = True)
         | 
| 331 | 
            -
             | 
| 332 | 
            -
             | 
| 333 | 
            -
                        
         | 
| 334 | 
            -
                         
         | 
| 335 | 
            -
                   
         | 
| 336 | 
            -
                    
         | 
| 337 | 
            -
                    
         | 
| 338 | 
            -
                    # def stream_():
         | 
| 339 | 
            -
                    #     #use for streaming response on the client side
         | 
| 340 | 
            -
                    #     for word in ans_.split(" "):
         | 
| 341 | 
            -
                    #         yield word + " "
         | 
| 342 | 
            -
                    #         time.sleep(0.04)
         | 
| 343 | 
            -
                    #     #use for streaming response from Llm directly
         | 
| 344 | 
            -
                    #     if(isinstance(ans_,botocore.eventstream.EventStream)):
         | 
| 345 | 
            -
                    #         for event in ans_:
         | 
| 346 | 
            -
                    #             chunk = event.get('chunk')
         | 
| 347 | 
            -
                                
         | 
| 348 | 
            -
                    #             if chunk:
         | 
| 349 | 
            -
                                    
         | 
| 350 | 
            -
                    #                 chunk_obj = json.loads(chunk.get('bytes').decode())
         | 
| 351 | 
            -
                                    
         | 
| 352 | 
            -
                    #                 if('content_block' in chunk_obj or ('delta' in chunk_obj and 'text' in chunk_obj['delta'])):
         | 
| 353 | 
            -
                    #                     key_ = list(chunk_obj.keys())[2]
         | 
| 354 | 
            -
                    #                     text = chunk_obj[key_]['text']
         | 
| 355 | 
            -
                                        
         | 
| 356 | 
            -
                    #                     clear_output(wait=True)
         | 
| 357 | 
            -
                    #                     output.append(text)
         | 
| 358 | 
            -
                    #                     yield text
         | 
| 359 | 
            -
                    #                     time.sleep(0.04)
         | 
| 360 | 
            -
                        
         | 
| 361 | 
            -
                            
         | 
| 362 | 
            -
                    
         | 
| 363 | 
            -
                    # if(index == len(st.session_state.questions_)):
         | 
| 364 | 
            -
                    #     st.write_stream(stream_)
         | 
| 365 | 
            -
                    #     if(isinstance(st.session_state.answers_[index-1]['answer'],botocore.eventstream.EventStream)):
         | 
| 366 | 
            -
                    #         st.session_state.answers_[index-1]['answer'] = "".join(output)
         | 
| 367 | 
            -
                    # else:
         | 
| 368 | 
            -
                    #     st.write(ans_)
         | 
| 369 | 
            -
                    
         | 
| 370 | 
            -
             | 
| 371 | 
            -
                    # polly_response = polly_client.synthesize_speech(VoiceId='Joanna',
         | 
| 372 | 
            -
                    #                 OutputFormat='ogg_vorbis', 
         | 
| 373 | 
            -
                    #                 Text = ans_,
         | 
| 374 | 
            -
                    #                 Engine = 'neural')
         | 
| 375 | 
            -
             | 
| 376 | 
            -
                    # audio_col1, audio_col2 = st.columns([50,50])
         | 
| 377 | 
            -
                    # with audio_col1:
         | 
| 378 | 
            -
                    #     st.audio(polly_response['AudioStream'].read(), format="audio/ogg")
         | 
| 379 | 
            -
                            
         | 
| 380 | 
            -
                    
         | 
| 381 | 
            -
                    
         | 
| 382 | 
            -
                    #st.markdown("<div style='font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;border-radius: 10px;'>"+ans_+"</div>", unsafe_allow_html = True)
         | 
| 383 | 
            -
                #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Relevant images from the document :</b></div>", unsafe_allow_html = True)
         | 
| 384 | 
            -
                #st.write("")
         | 
| 385 | 
             
                colu1,colu2,colu3 = st.columns([4,82,20])
         | 
| 386 | 
             
                if(answer['source']!={}):
         | 
| 387 | 
             
                    with colu2:
         | 
| 388 | 
             
                        with st.expander("Agent Traces:"):
         | 
| 389 | 
             
                            st.write(answer['source'])
         | 
| 390 | 
            -
             | 
| 391 | 
            -
                    #             if(len(res_img)>0):
         | 
| 392 | 
            -
                    #                 with st.expander("Images:"):
         | 
| 393 | 
            -
                    #                     col3,col4,col5 = st.columns([33,33,33])
         | 
| 394 | 
            -
                    #                     cols = [col3,col4]
         | 
| 395 | 
            -
                    #                     idx = 0
         | 
| 396 | 
            -
                    #                     #print(res_img)
         | 
| 397 | 
            -
                    #                     for img_ in res_img:
         | 
| 398 | 
            -
                    #                         if(img_['file'].lower()!='none' and idx < 2):
         | 
| 399 | 
            -
                    #                             img = img_['file'].split(".")[0]
         | 
| 400 | 
            -
                    #                             caption = img_['caption']
         | 
| 401 | 
            -
                                                
         | 
| 402 | 
            -
                    #                             with cols[idx]:
         | 
| 403 | 
            -
                                                    
         | 
| 404 | 
            -
                    #                                 st.image(parent_dirname+"/figures/"+st.session_state.input_index+"/"+img+".jpg")
         | 
| 405 | 
            -
                    #                                 #st.write(caption)
         | 
| 406 | 
            -
                    #                             idx = idx+1
         | 
| 407 | 
            -
                    #             #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Sources from the document:</b></div>", unsafe_allow_html = True)
         | 
| 408 | 
            -
                    #             if(len(answer["table"] )>0):
         | 
| 409 | 
            -
                    #                 with st.expander("Table:"):
         | 
| 410 | 
            -
                    #                     df = pd.read_csv(answer["table"][0]['name'],skipinitialspace = True, on_bad_lines='skip',delimiter='`')
         | 
| 411 | 
            -
                    #                     df.fillna(method='pad', inplace=True)
         | 
| 412 | 
            -
                    #                     st.table(df)
         | 
| 413 | 
            -
                    #             with st.expander("Raw sources:"):
         | 
| 414 | 
            -
                    #                 st.write(answer["source"])
         | 
| 415 | 
            -
                                
         | 
| 416 | 
            -
                        
         | 
| 417 | 
            -
                        
         | 
| 418 | 
            -
                    # with col_3:
         | 
| 419 | 
            -
                        
         | 
| 420 | 
            -
                    #     #st.markdown("<div style='color:#e28743;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 5px;'><b>"+",".join(st.session_state.input_rag_searchType)+"</b></div>", unsafe_allow_html = True)
         | 
| 421 | 
            -
                        
         | 
| 422 | 
            -
                    
         | 
| 423 | 
            -
                        
         | 
| 424 | 
            -
                    #     if(index == len(st.session_state.questions_)):
         | 
| 425 | 
            -
             | 
| 426 | 
            -
                    #         rdn_key = ''.join([random.choice(string.ascii_letters)
         | 
| 427 | 
            -
                    #                           for _ in range(10)])
         | 
| 428 | 
            -
                    #         currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
         | 
| 429 | 
            -
                    #         oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
         | 
| 430 | 
            -
                    #         #print("changing values-----------------")
         | 
| 431 | 
            -
                    #         def on_button_click():
         | 
| 432 | 
            -
                    #             # print("button clicked---------------")
         | 
| 433 | 
            -
                    #             # print(currentValue)
         | 
| 434 | 
            -
                    #             # print(oldValue)
         | 
| 435 | 
            -
                    #             if(currentValue!=oldValue or 1==1): 
         | 
| 436 | 
            -
                    #                 #print("----------regenerate----------------")
         | 
| 437 | 
            -
                    #                 st.session_state.input_query = st.session_state.questions_[-1]["question"]
         | 
| 438 | 
            -
                    #                 st.session_state.answers_.pop()
         | 
| 439 | 
            -
                    #                 st.session_state.questions_.pop()
         | 
| 440 | 
            -
                                    
         | 
| 441 | 
            -
                    #                 handle_input()
         | 
| 442 | 
            -
                    #                 with placeholder.container():
         | 
| 443 | 
            -
                    #                     render_all()
         | 
| 444 | 
            -
             | 
| 445 | 
            -
                    #         if("currentValue"  in st.session_state):
         | 
| 446 | 
            -
                    #             del st.session_state["currentValue"]
         | 
| 447 | 
            -
             | 
| 448 | 
            -
                    #         try:
         | 
| 449 | 
            -
                    #             del regenerate
         | 
| 450 | 
            -
                    #         except:
         | 
| 451 | 
            -
                    #             pass  
         | 
| 452 | 
            -
             | 
| 453 | 
            -
                    #         #print("------------------------")
         | 
| 454 | 
            -
                    #         #print(st.session_state)
         | 
| 455 | 
            -
             | 
| 456 | 
            -
                    #         placeholder__ = st.empty()
         | 
| 457 | 
            -
                            
         | 
| 458 | 
            -
                    #         placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
         | 
| 459 |  | 
| 460 | 
             
            #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
         | 
| 461 | 
             
            def write_chat_message(md, q,index):
         | 
| 462 | 
            -
                #res_img = md['image']
         | 
| 463 | 
            -
                #st.session_state['session_id'] = res['session_id']   to be added in memory
         | 
| 464 | 
             
                chat = st.container()
         | 
| 465 | 
             
                with chat:
         | 
| 466 | 
            -
                    #print("st.session_state.input_index------------------")
         | 
| 467 | 
            -
                    #print(st.session_state.input_index)
         | 
| 468 | 
             
                    render_answer(q,md,index)
         | 
| 469 |  | 
| 470 | 
             
            def render_all():  
         | 
| @@ -480,173 +248,8 @@ with placeholder.container(): | |
| 480 |  | 
| 481 | 
             
            st.markdown("")
         | 
| 482 | 
             
            col_2, col_3 = st.columns([75,20])
         | 
| 483 | 
            -
             | 
| 484 | 
            -
            # with col_1:
         | 
| 485 | 
            -
            #     st.markdown("<p style='padding:0px 0px 0px 0px; color:#FF9900;font-size:120%'><b>Ask:</b></p>",unsafe_allow_html=True, help = 'Enter the questions and click on "GO"')
         | 
| 486 | 
            -
                
         | 
| 487 | 
             
            with col_2:
         | 
| 488 | 
            -
                #st.markdown("")
         | 
| 489 | 
             
                input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_shopping_query")
         | 
| 490 | 
             
            with col_3:
         | 
| 491 | 
            -
                #hidden = st.button("RUN",disabled=True,key = "hidden")
         | 
| 492 | 
            -
                # audio_value = st.audio_input("Record a voice message")
         | 
| 493 | 
            -
                # print(audio_value)
         | 
| 494 | 
             
                play = st.button("Go",on_click=handle_input,key = "play")
         | 
| 495 | 
            -
            #with st.sidebar:
         | 
| 496 | 
            -
                # st.page_link("/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp/app.py", label=":orange[Home]", icon="🏠")
         | 
| 497 | 
            -
                # st.subheader(":blue[Sample Data]")
         | 
| 498 | 
            -
                # coln_1,coln_2 = st.columns([70,30])
         | 
| 499 | 
            -
                # # index_select = st.radio("Choose one index",["UK Housing","Covid19 impacts on Ireland","Environmental Global Warming","BEIR Research"],
         | 
| 500 | 
            -
                # #                         captions = ['[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)',
         | 
| 501 | 
            -
                # #                                     '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)',
         | 
| 502 | 
            -
                # #                                     '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)',
         | 
| 503 | 
            -
                # #                                     '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/BEIR.pdf)'],
         | 
| 504 | 
            -
                # #                         key="input_rad_index")
         | 
| 505 | 
            -
                # with coln_1:
         | 
| 506 | 
            -
                #     index_select = st.radio("Choose one index",["UK Housing","Global Warming stats","Covid19 impacts on Ireland"],key="input_rad_index")
         | 
| 507 | 
            -
                # with coln_2:
         | 
| 508 | 
            -
                #     st.markdown("<p style='font-size:15px'>Preview file</p>",unsafe_allow_html=True)
         | 
| 509 | 
            -
                #     st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)")
         | 
| 510 | 
            -
                #     st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)")
         | 
| 511 | 
            -
                #     st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)")
         | 
| 512 | 
            -
                #     #st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/BEIR.pdf)")
         | 
| 513 | 
            -
                # st.markdown("""
         | 
| 514 | 
            -
                # <style>
         | 
| 515 | 
            -
                # [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
         | 
| 516 | 
            -
                #     gap: 0rem;
         | 
| 517 | 
            -
                # }
         | 
| 518 | 
            -
                # [data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
         | 
| 519 | 
            -
                #     gap: 0rem;
         | 
| 520 | 
            -
                # }
         | 
| 521 | 
            -
                # </style>
         | 
| 522 | 
            -
                # """,unsafe_allow_html=True)   
         | 
| 523 | 
            -
                # # Initialize boto3 to use the S3 client.
         | 
| 524 | 
            -
                # s3_client = boto3.resource('s3')
         | 
| 525 | 
            -
                # bucket=s3_client.Bucket(s3_bucket_)
         | 
| 526 | 
            -
             | 
| 527 | 
            -
                # objects = bucket.objects.filter(Prefix="sample_pdfs/")
         | 
| 528 | 
            -
                # urls = []
         | 
| 529 | 
            -
             | 
| 530 | 
            -
                # client = boto3.client('s3')
         | 
| 531 | 
            -
             | 
| 532 | 
            -
                # for obj in objects:
         | 
| 533 | 
            -
                #     if obj.key.endswith('.pdf'): 
         | 
| 534 | 
            -
             | 
| 535 | 
            -
                #         # Generate the S3 presigned URL
         | 
| 536 | 
            -
                #         s3_presigned_url = client.generate_presigned_url(
         | 
| 537 | 
            -
                #             ClientMethod='get_object',
         | 
| 538 | 
            -
                #             Params={
         | 
| 539 | 
            -
                #                 'Bucket': s3_bucket_,
         | 
| 540 | 
            -
                #                 'Key': obj.key
         | 
| 541 | 
            -
                #             },
         | 
| 542 | 
            -
                #             ExpiresIn=3600
         | 
| 543 | 
            -
                #         )
         | 
| 544 | 
            -
             | 
| 545 | 
            -
                #         # Print the created S3 presigned URL
         | 
| 546 | 
            -
                #         print(s3_presigned_url)
         | 
| 547 | 
            -
                #         urls.append(s3_presigned_url)
         | 
| 548 | 
            -
                #         #st.write("["+obj.key.split('/')[1]+"]("+s3_presigned_url+")")
         | 
| 549 | 
            -
                #         st.link_button(obj.key.split('/')[1], s3_presigned_url)
         | 
| 550 | 
            -
                
         | 
| 551 | 
            -
                
         | 
| 552 | 
            -
                # st.subheader(":blue[Your multi-modal documents]")
         | 
| 553 | 
            -
                # pdf_doc_ = st.file_uploader(
         | 
| 554 | 
            -
                #     "Upload your PDFs here and click on 'Process'", accept_multiple_files=False)
         | 
| 555 | 
            -
                                
         | 
| 556 | 
            -
                            
         | 
| 557 | 
            -
                # pdf_docs = [pdf_doc_]
         | 
| 558 | 
            -
                # if st.button("Process"):
         | 
| 559 | 
            -
                #     with st.spinner("Processing"):
         | 
| 560 | 
            -
                #         if os.path.isdir(parent_dirname+"/pdfs") == False:
         | 
| 561 | 
            -
                #             os.mkdir(parent_dirname+"/pdfs")
         | 
| 562 | 
            -
                        
         | 
| 563 | 
            -
                #         for pdf_doc in pdf_docs:
         | 
| 564 | 
            -
                #             print(type(pdf_doc))
         | 
| 565 | 
            -
                #             pdf_doc_name = (pdf_doc.name).replace(" ","_")
         | 
| 566 | 
            -
                #             with open(os.path.join(parent_dirname+"/pdfs",pdf_doc_name),"wb") as f: 
         | 
| 567 | 
            -
                #                 f.write(pdf_doc.getbuffer())  
         | 
| 568 | 
            -
                                
         | 
| 569 | 
            -
                #             request_ = { "bucket": s3_bucket_,"key": pdf_doc_name}
         | 
| 570 | 
            -
                #             # if(st.session_state.input_copali_rerank):
         | 
| 571 | 
            -
                #             #     copali.process_doc(request_)
         | 
| 572 | 
            -
                #             # else:
         | 
| 573 | 
            -
                #             rag_DocumentLoader.load_docs(request_)
         | 
| 574 | 
            -
                #             print('lambda done')
         | 
| 575 | 
            -
                #     st.success('you can start searching on your PDF')
         | 
| 576 | 
            -
                    
         | 
| 577 | 
            -
                # ############## haystach demo temporary addition ############    
         | 
| 578 | 
            -
                # # st.subheader(":blue[Multimodality]")
         | 
| 579 | 
            -
                # # colu1,colu2 = st.columns([50,50])
         | 
| 580 | 
            -
                # # with colu1:
         | 
| 581 | 
            -
                # #     in_images = st.toggle('Images', key = 'in_images', disabled = False)
         | 
| 582 | 
            -
                # # with colu2:
         | 
| 583 | 
            -
                # #     in_tables = st.toggle('Tables', key = 'in_tables', disabled = False)   
         | 
| 584 | 
            -
                # # if(in_tables):
         | 
| 585 | 
            -
                # #     st.session_state.input_table_with_sql = True
         | 
| 586 | 
            -
                # # else:
         | 
| 587 | 
            -
                # #     st.session_state.input_table_with_sql = False
         | 
| 588 | 
            -
                    
         | 
| 589 | 
            -
                #  ############## haystach demo temporary addition ############       
         | 
| 590 | 
            -
                # if(pdf_doc_ is None or pdf_doc_ == ""):
         | 
| 591 | 
            -
                #     if(index_select == "Global Warming stats"):
         | 
| 592 | 
            -
                #         st.session_state.input_index = "globalwarmingnew"
         | 
| 593 | 
            -
                #     if(index_select == "Covid19 impacts on Ireland"):
         | 
| 594 | 
            -
                #         st.session_state.input_index = "covid19ie"#"choosetheknnalgorithmforyourbillionscaleusecasewithopensearchawsbigdatablog"
         | 
| 595 | 
            -
                #     if(index_select == "BEIR"):
         | 
| 596 | 
            -
                #         st.session_state.input_index = "2104"
         | 
| 597 | 
            -
                #     if(index_select == "UK Housing"):
         | 
| 598 | 
            -
                #         st.session_state.input_index = "hpijan2024hometrack"
         | 
| 599 | 
            -
                #         # if(in_images == True and in_tables == True):
         | 
| 600 | 
            -
                #         #     st.session_state.input_index = "hpijan2024hometrack"
         | 
| 601 | 
            -
                #         # else:
         | 
| 602 | 
            -
                #         #     if(in_images == True and in_tables == False):
         | 
| 603 | 
            -
                #         #         st.session_state.input_index = "hpijan2024hometrackno_table"
         | 
| 604 | 
            -
                #         #     else:
         | 
| 605 | 
            -
                #         #         if(in_images == False and in_tables == True):
         | 
| 606 | 
            -
                #         #             st.session_state.input_index = "hpijan2024hometrackno_images"
         | 
| 607 | 
            -
                #         #         else:   
         | 
| 608 | 
            -
                #         #             st.session_state.input_index = "hpijan2024hometrack_no_img_no_table"
         | 
| 609 | 
            -
                            
         | 
| 610 | 
            -
                                
         | 
| 611 | 
            -
                # # if(in_images):
         | 
| 612 | 
            -
                # #     st.session_state.input_include_images = True
         | 
| 613 | 
            -
                # # else:
         | 
| 614 | 
            -
                # #     st.session_state.input_include_images = False
         | 
| 615 | 
            -
                # # if(in_tables):
         | 
| 616 | 
            -
                # #     st.session_state.input_include_tables = True
         | 
| 617 | 
            -
                # # else:
         | 
| 618 | 
            -
                # #     st.session_state.input_include_tables = False
         | 
| 619 | 
            -
                
         | 
| 620 | 
            -
                # custom_index = st.text_input("If uploaded the file already, enter the original file name", value = "")
         | 
| 621 | 
            -
                # if(custom_index!=""):
         | 
| 622 | 
            -
                #     st.session_state.input_index = re.sub('[^A-Za-z0-9]+', '', (custom_index.lower().replace(".pdf","").split("/")[-1].split(".")[0]).lower())
         | 
| 623 | 
            -
                
         | 
| 624 | 
            -
                
         | 
| 625 | 
            -
                
         | 
| 626 | 
            -
                # st.subheader(":blue[Retriever]")
         | 
| 627 | 
            -
                # search_type = st.multiselect('Select the Retriever(s)',
         | 
| 628 | 
            -
                # ['Keyword Search',
         | 
| 629 | 
            -
                # 'Vector Search', 
         | 
| 630 | 
            -
                # 'Sparse Search',
         | 
| 631 | 
            -
                # ],
         | 
| 632 | 
            -
                # ['Sparse Search'],
         | 
| 633 | 
            -
             | 
| 634 | 
            -
                # key = 'input_rag_searchType',
         | 
| 635 | 
            -
                # help = "Select the type of Search, adding more than one search type will activate hybrid search"#\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers)"
         | 
| 636 | 
            -
                # )
         | 
| 637 | 
            -
                
         | 
| 638 | 
            -
                # re_rank = st.checkbox('Re-rank results', key = 'input_re_rank', disabled = False, value = True, help = "Checking this box will re-rank the results using a cross-encoder model")
         | 
| 639 | 
            -
                    
         | 
| 640 | 
            -
                # if(re_rank):
         | 
| 641 | 
            -
                #     st.session_state.input_is_rerank = True
         | 
| 642 | 
            -
                # else:
         | 
| 643 | 
            -
                #     st.session_state.input_is_rerank = False
         | 
| 644 | 
            -
                    
         | 
| 645 | 
            -
                # # copali_rerank = st.checkbox("Search and Re-rank with Token level vectors",key = 'copali_rerank',help = "Enabling this option uses 'Copali' model's page level image embeddings to retrieve documents and MaxSim to re-rank the pages.\n\n Hugging Face Model: https://huggingface.co/vidore/colpali")
         | 
| 646 | 
            -
                    
         | 
| 647 | 
            -
                # # if(copali_rerank):
         | 
| 648 | 
            -
                # #     st.session_state.input_copali_rerank = True
         | 
| 649 | 
            -
                # # else:
         | 
| 650 | 
            -
                # #     st.session_state.input_copali_rerank = False
         | 
| 651 | 
            -
                    
         | 
| 652 | 
            -
                    
         | 
|  | |
| 33 | 
             
            import warnings
         | 
| 34 |  | 
| 35 | 
             
            warnings.filterwarnings("ignore", category=DeprecationWarning)
         | 
|  | |
|  | |
|  | |
|  | |
| 36 | 
             
            st.set_page_config(
         | 
|  | |
| 37 | 
             
                layout="wide",
         | 
| 38 | 
             
                page_icon="images/opensearch_mark_default.png"
         | 
| 39 | 
             
            )
         | 
|  | |
| 42 | 
             
            AI_ICON = "images/opensearch-twitter-card.png"
         | 
| 43 | 
             
            REGENERATE_ICON = "images/regenerate.png"
         | 
| 44 | 
             
            s3_bucket_ = "pdf-repo-uploads"
         | 
| 45 | 
            +
                        
         | 
| 46 | 
             
            polly_client = boto3.Session(
         | 
| 47 | 
             
                        region_name='us-east-1').client('polly')
         | 
| 48 |  | 
| 49 | 
             
            # Check if the user ID is already stored in the session state
         | 
| 50 | 
             
            if 'user_id' in st.session_state:
         | 
| 51 | 
             
                user_id = st.session_state['user_id']
         | 
| 52 | 
            +
                
         | 
|  | |
| 53 | 
             
            # If the user ID is not yet stored in the session state, generate a random UUID
         | 
| 54 | 
             
            else:
         | 
| 55 | 
             
                user_id = str(uuid.uuid4())
         | 
|  | |
| 73 |  | 
| 74 | 
             
            if "answers__" not in st.session_state:
         | 
| 75 | 
             
                st.session_state.answers__ = []
         | 
|  | |
|  | |
|  | |
| 76 |  | 
| 77 | 
             
            if "input_is_rerank" not in st.session_state:
         | 
| 78 | 
             
                st.session_state.input_is_rerank = True
         | 
|  | |
| 83 | 
             
            if "input_table_with_sql" not in st.session_state:
         | 
| 84 | 
             
                st.session_state.input_table_with_sql = False
         | 
| 85 |  | 
|  | |
| 86 | 
             
            if "inputs_" not in st.session_state:
         | 
| 87 | 
             
                st.session_state.inputs_ = {}
         | 
| 88 |  | 
| 89 | 
             
            if "input_shopping_query" not in st.session_state:
         | 
| 90 | 
            +
                st.session_state.input_shopping_query="get me shoes suitable for trekking"
         | 
| 91 |  | 
| 92 |  | 
| 93 | 
             
            if "input_rag_searchType" not in st.session_state:
         | 
| 94 | 
             
                st.session_state.input_rag_searchType = ["Sparse Search"]
         | 
| 95 |  | 
|  | |
|  | |
|  | |
| 96 | 
             
            region = 'us-east-1'
         | 
|  | |
| 97 | 
             
            output = []
         | 
| 98 | 
             
            service = 'es'
         | 
| 99 |  | 
|  | |
| 108 | 
             
                </style>
         | 
| 109 | 
             
                """,unsafe_allow_html=True)
         | 
| 110 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 111 | 
             
            def write_logo():
         | 
| 112 | 
             
                col1, col2, col3 = st.columns([5, 1, 5])
         | 
| 113 | 
             
                with col2:
         | 
|  | |
| 119 | 
             
                    st.page_link("app.py", label=":orange[Home]", icon="🏠")
         | 
| 120 | 
             
                    st.header("AI Shopping assistant",divider='rainbow')
         | 
| 121 |  | 
|  | |
|  | |
| 122 | 
             
                with col2:
         | 
| 123 | 
             
                    st.write("")
         | 
| 124 | 
             
                    st.write("")
         | 
|  | |
| 135 | 
             
                st.session_state.input_shopping_query=""
         | 
| 136 | 
             
                st.session_state.session_id_ = str(uuid.uuid1())
         | 
| 137 | 
             
                bedrock_agent.delete_memory()
         | 
| 138 | 
            +
             | 
|  | |
|  | |
|  | |
|  | |
| 139 |  | 
| 140 |  | 
| 141 | 
             
            def handle_input():
         | 
|  | |
|  | |
|  | |
| 142 | 
             
                if(st.session_state.input_shopping_query==''):
         | 
| 143 | 
             
                    return ""
         | 
| 144 | 
             
                inputs = {}
         | 
|  | |
| 147 | 
             
                        inputs[key.removeprefix('input_')] = st.session_state[key]
         | 
| 148 | 
             
                st.session_state.inputs_ = inputs
         | 
| 149 |  | 
|  | |
|  | |
|  | |
|  | |
| 150 | 
             
                question_with_id = {
         | 
| 151 | 
             
                    'question': inputs["shopping_query"],
         | 
| 152 | 
             
                    'id': len(st.session_state.questions__)
         | 
|  | |
| 165 | 
             
                st.session_state.input_shopping_query=""
         | 
| 166 |  | 
| 167 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 168 |  | 
| 169 | 
             
            def write_user_message(md):
         | 
| 170 | 
             
                col1, col2 = st.columns([3,97])
         | 
|  | |
| 172 | 
             
                with col1:
         | 
| 173 | 
             
                    st.image(USER_ICON, use_column_width='always')
         | 
| 174 | 
             
                with col2:
         | 
|  | |
|  | |
| 175 | 
             
                    st.markdown("<div style='color:#e28743';font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;'>"+md['question']+"</div>", unsafe_allow_html = True)
         | 
| 176 |  | 
| 177 |  | 
|  | |
| 188 | 
             
                    ans_ = answer['answer']
         | 
| 189 | 
             
                    span_ans = ans_.replace('<question>',"<span style='fontSize:18px;color:#f37709;fontStyle:italic;'>").replace("</question>","</span>")
         | 
| 190 | 
             
                    st.markdown("<p>"+span_ans+"</p>",unsafe_allow_html = True)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 191 | 
             
                    if(answer['last_tool']['name'] in ["generate_images","get_relevant_items_for_image","get_relevant_items_for_text","retrieve_with_hybrid_search","retrieve_with_keyword_search","get_any_general_recommendation"]):
         | 
| 192 | 
             
                        use_interim_results = True
         | 
| 193 | 
             
                        src_dict =json.loads(answer['last_tool']['response'].replace("'",'"'))
         | 
|  | |
|  | |
|  | |
|  | |
| 194 | 
             
                    if(use_interim_results and answer['last_tool']['name']!= 'generate_images' and answer['last_tool']['name']!= 'get_any_general_recommendation'):
         | 
| 195 | 
             
                        key_ = answer['last_tool']['name']
         | 
| 196 |  | 
|  | |
| 206 | 
             
                            if(index ==1):
         | 
| 207 | 
             
                                with img_col2:
         | 
| 208 | 
             
                                    st.image(resizedImg,use_column_width = True,caption = item['title'])
         | 
| 209 | 
            +
                                    
         | 
|  | |
|  | |
| 210 | 
             
                    if(answer['last_tool']['name'] == "generate_images" or answer['last_tool']['name'] == "get_any_general_recommendation"):   
         | 
| 211 | 
             
                        st.write("<br>",unsafe_allow_html = True)
         | 
| 212 | 
             
                        gen_img_col1, gen_img_col2,gen_img_col2 = st.columns([30,30,30])
         | 
|  | |
| 222 | 
             
                        with gen_img_col1:
         | 
| 223 | 
             
                            st.image(resizedImg,caption = "Generated image for "+key.split(".")[0],use_column_width = True)
         | 
| 224 | 
             
                        st.write("<br>",unsafe_allow_html = True)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 225 | 
             
                colu1,colu2,colu3 = st.columns([4,82,20])
         | 
| 226 | 
             
                if(answer['source']!={}):
         | 
| 227 | 
             
                    with colu2:
         | 
| 228 | 
             
                        with st.expander("Agent Traces:"):
         | 
| 229 | 
             
                            st.write(answer['source'])
         | 
| 230 | 
            +
                   
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 231 |  | 
| 232 | 
             
            #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
         | 
| 233 | 
             
            def write_chat_message(md, q,index):
         | 
|  | |
|  | |
| 234 | 
             
                chat = st.container()
         | 
| 235 | 
             
                with chat:
         | 
|  | |
|  | |
| 236 | 
             
                    render_answer(q,md,index)
         | 
| 237 |  | 
| 238 | 
             
            def render_all():  
         | 
|  | |
| 248 |  | 
| 249 | 
             
            st.markdown("")
         | 
| 250 | 
             
            col_2, col_3 = st.columns([75,20])
         | 
| 251 | 
            +
             
         | 
|  | |
|  | |
|  | |
| 252 | 
             
            with col_2:
         | 
|  | |
| 253 | 
             
                input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_shopping_query")
         | 
| 254 | 
             
            with col_3:
         | 
|  | |
|  | |
|  | |
| 255 | 
             
                play = st.button("Go",on_click=handle_input,key = "play")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        pages/Semantic_Search.py
    CHANGED
    
    | @@ -24,24 +24,18 @@ import base64 | |
| 24 | 
             
            import shutil
         | 
| 25 | 
             
            import re
         | 
| 26 | 
             
            from requests.auth import HTTPBasicAuth
         | 
| 27 | 
            -
            #import utilities.re_ranker as re_ranker
         | 
| 28 | 
             
            # from nltk.stem import PorterStemmer
         | 
| 29 | 
             
            # from nltk.tokenize import word_tokenize
         | 
| 30 | 
             
            import query_rewrite
         | 
| 31 | 
             
            import amazon_rekognition
         | 
|  | |
| 32 | 
             
            #from st_click_detector import click_detector
         | 
| 33 | 
             
            import llm_eval
         | 
| 34 | 
             
            import all_search_execute
         | 
| 35 | 
             
            import warnings
         | 
| 36 |  | 
| 37 | 
             
            warnings.filterwarnings("ignore", category=DeprecationWarning)
         | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
             
            st.set_page_config(
         | 
| 43 | 
            -
                #page_title="Semantic Search using OpenSearch",
         | 
| 44 | 
            -
                #layout="wide",
         | 
| 45 | 
             
                page_icon="images/opensearch_mark_default.png"
         | 
| 46 | 
             
            )
         | 
| 47 | 
             
            parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
         | 
| @@ -58,11 +52,6 @@ st.markdown(""" | |
| 58 | 
             
            #ps = PorterStemmer()
         | 
| 59 |  | 
| 60 | 
             
            st.session_state.REGION = 'us-east-1'
         | 
| 61 | 
            -
             | 
| 62 | 
            -
             | 
| 63 | 
            -
            #from langchain.callbacks.base import BaseCallbackHandler
         | 
| 64 | 
            -
             | 
| 65 | 
            -
             | 
| 66 | 
             
            USER_ICON = "images/user.png"
         | 
| 67 | 
             
            AI_ICON = "images/opensearch-twitter-card.png"
         | 
| 68 | 
             
            REGENERATE_ICON = "images/regenerate.png"
         | 
| @@ -170,12 +159,6 @@ if "input_ndcg" not in st.session_state: | |
| 170 | 
             
            if "gen_image_str" not in st.session_state:
         | 
| 171 | 
             
                st.session_state.gen_image_str=""
         | 
| 172 |  | 
| 173 | 
            -
            # if "input_searchType" not in st.session_state:
         | 
| 174 | 
            -
            #     st.session_state.input_searchType = ['Keyword Search']
         | 
| 175 | 
            -
                
         | 
| 176 | 
            -
            # if "input_must" not in st.session_state:
         | 
| 177 | 
            -
            #     st.session_state.input_must = ["Category","Price","Gender","Style"]
         | 
| 178 | 
            -
                
         | 
| 179 | 
             
            if "input_NormType" not in st.session_state:
         | 
| 180 | 
             
                st.session_state.input_NormType = "min_max"
         | 
| 181 |  | 
| @@ -261,25 +244,8 @@ if(search_all_type==True): | |
| 261 | 
             
                'Multimodal Search',
         | 
| 262 | 
             
                'NeuralSparse Search',
         | 
| 263 | 
             
                ]
         | 
| 264 | 
            -
             | 
| 265 | 
            -
             | 
| 266 | 
            -
            #     html("""
         | 
| 267 | 
            -
            #     <script>
         | 
| 268 | 
            -
            #         // Locate elements
         | 
| 269 | 
            -
            #         var decoration = window.parent.document.querySelectorAll('[data-testid="stDecoration"]')[0];
         | 
| 270 | 
            -
            #         decoration.style.height = "3.0rem";
         | 
| 271 | 
            -
            #         decoration.style.right = "45px";
         | 
| 272 | 
            -
            #         // Adjust text decorations
         | 
| 273 | 
            -
            #         decoration.innerText = "Semantic Search with OpenSearch!"; // Replace with your desired text
         | 
| 274 | 
            -
            #         decoration.style.fontWeight = "bold";
         | 
| 275 | 
            -
            #         decoration.style.display = "flex";
         | 
| 276 | 
            -
            #         decoration.style.justifyContent = "center";
         | 
| 277 | 
            -
            #         decoration.style.alignItems = "center";
         | 
| 278 | 
            -
            #         decoration.style.fontWeight = "bold";
         | 
| 279 | 
            -
            #         decoration.style.backgroundImage = url('/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp/images/service_logo.png'); // Remove background image
         | 
| 280 | 
            -
            #         decoration.style.backgroundSize = "unset"; // Remove background size
         | 
| 281 | 
            -
            #     </script>
         | 
| 282 | 
            -
            # """, width=0, height=0)
         | 
| 283 |  | 
| 284 |  | 
| 285 |  | 
| @@ -448,31 +414,12 @@ def handle_input(): | |
| 448 |  | 
| 449 |  | 
| 450 | 
             
                inputs = {}
         | 
| 451 | 
            -
                # if(st.session_state.input_imageUpload == 'yes'):
         | 
| 452 | 
            -
                #     st.session_state.input_searchType = 'Multi-modal Search'
         | 
| 453 | 
            -
                # if(st.session_state.input_sparse == 'enabled' or st.session_state.input_is_rewrite_query == 'enabled'):
         | 
| 454 | 
            -
                #     st.session_state.input_searchType = 'Keyword Search'
         | 
| 455 | 
             
                if(st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType):
         | 
| 456 | 
             
                    old_rekog_label = st.session_state.input_rekog_label
         | 
| 457 | 
             
                    st.session_state.input_rekog_label = amazon_rekognition.extract_image_metadata(st.session_state.bytes_for_rekog)
         | 
| 458 | 
             
                    if(st.session_state.input_text == ""):
         | 
| 459 | 
             
                        st.session_state.input_text = st.session_state.input_rekog_label
         | 
| 460 |  | 
| 461 | 
            -
                # if(st.session_state.input_imageUpload == 'yes'):
         | 
| 462 | 
            -
                #     if(st.session_state.input_searchType!='Multi-modal Search'):
         | 
| 463 | 
            -
                #         if(st.session_state.input_searchType=='Keyword Search'):
         | 
| 464 | 
            -
                #             if(st.session_state.input_rekognition != 'enabled'):
         | 
| 465 | 
            -
                #                 st.error('For Keyword Search using images, enable "Enrich metadata for Images" in the left panel',icon = "🚨")
         | 
| 466 | 
            -
                #                 #st.session_state.input_rekognition = 'enabled'
         | 
| 467 | 
            -
                #                 st.switch_page('pages/1_Semantic_Search.py')
         | 
| 468 | 
            -
                #                 #st.stop()
         | 
| 469 | 
            -
                                
         | 
| 470 | 
            -
                #         else:
         | 
| 471 | 
            -
                #             st.error('Please set the search type as "Keyword Search (enabling Enrich metadata for Images) or Multi-modal Search"',icon = "🚨")
         | 
| 472 | 
            -
                #             #st.session_state.input_searchType='Multi-modal Search'
         | 
| 473 | 
            -
                #             st.switch_page('pages/1_Semantic_Search.py')
         | 
| 474 | 
            -
                #             #st.stop()
         | 
| 475 | 
            -
                            
         | 
| 476 |  | 
| 477 | 
             
                weightage = {}
         | 
| 478 | 
             
                st.session_state.weights_ = []
         | 
| @@ -511,44 +458,13 @@ def handle_input(): | |
| 511 | 
             
                            else:
         | 
| 512 | 
             
                                weightage[original_key] = 0.0
         | 
| 513 | 
             
                                st.session_state[key] = 0.0
         | 
| 514 | 
            -
             | 
| 515 | 
            -
                    
         | 
| 516 | 
            -
               
         | 
| 517 | 
            -
                                    
         | 
| 518 | 
            -
                                    
         | 
| 519 | 
            -
                                    
         | 
| 520 | 
            -
                                
         | 
| 521 | 
            -
                          
         | 
| 522 | 
            -
                    
         | 
| 523 | 
            -
                    
         | 
| 524 | 
            -
                            
         | 
| 525 | 
            -
             | 
| 526 | 
            -
                            
         | 
| 527 | 
             
                inputs['weightage']=weightage
         | 
| 528 | 
             
                st.session_state.input_weightage = weightage
         | 
| 529 |  | 
| 530 | 
            -
                print("====================")
         | 
| 531 | 
            -
                print(st.session_state.weights_)
         | 
| 532 | 
            -
                print(st.session_state.input_weightage )
         | 
| 533 | 
            -
                print("====================")
         | 
| 534 | 
            -
                    #print("***************************")
         | 
| 535 | 
            -
                    #print(sum(weights_))
         | 
| 536 | 
            -
                    # if(sum(st.session_state.weights_)!=100):
         | 
| 537 | 
            -
                    #     st.warning('The total weight of selected search type(s) should be equal to 100',icon = "🚨")
         | 
| 538 | 
            -
                    #     refresh = st.button("Re-Enter")
         | 
| 539 | 
            -
                    #     if(refresh):
         | 
| 540 | 
            -
                    #         st.switch_page('pages/1_Semantic_Search.py')
         | 
| 541 | 
            -
                    #         st.stop()
         | 
| 542 | 
            -
                        
         | 
| 543 | 
            -
                            
         | 
| 544 | 
            -
                        #         #st.session_state.input_rekognition = 'enabled'
         | 
| 545 | 
            -
                    #     st.rerun()
         | 
| 546 | 
            -
                    
         | 
| 547 | 
            -
                    
         | 
| 548 |  | 
| 549 | 
             
                st.session_state.inputs_ = inputs
         | 
| 550 |  | 
| 551 | 
            -
                #st.write(inputs) 
         | 
| 552 | 
             
                question_with_id = {
         | 
| 553 | 
             
                    'question': inputs["text"],
         | 
| 554 | 
             
                    'id': len(st.session_state.questions)
         | 
| @@ -567,19 +483,15 @@ def handle_input(): | |
| 567 |  | 
| 568 | 
             
                if(st.session_state.input_is_rewrite_query == 'enabled' or (st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType)):
         | 
| 569 | 
             
                    query_rewrite.get_new_query_res(st.session_state.input_text)
         | 
| 570 | 
            -
                     | 
| 571 | 
            -
                    print(st.session_state.input_rewritten_query)
         | 
| 572 | 
            -
                    print("-------------------")
         | 
| 573 | 
             
                else:
         | 
| 574 | 
             
                    st.session_state.input_rewritten_query = ""
         | 
| 575 |  | 
| 576 | 
            -
             | 
| 577 | 
            -
                #     ans__ = amazon_rekognition.call(st.session_state.input_text,st.session_state.input_rekog_label)
         | 
| 578 | 
            -
                # else:
         | 
| 579 | 
             
                ans__ = all_search_execute.handler(inputs, st.session_state['session_id'])
         | 
| 580 |  | 
| 581 | 
             
                st.session_state.answers.append({
         | 
| 582 | 
            -
                    'answer': ans__ | 
| 583 | 
             
                    'search_type':inputs['searchType'],
         | 
| 584 | 
             
                    'id': len(st.session_state.questions)
         | 
| 585 | 
             
                })
         | 
| @@ -587,21 +499,8 @@ def handle_input(): | |
| 587 | 
             
                st.session_state.answers_none_rank = st.session_state.answers
         | 
| 588 | 
             
                if(st.session_state.input_evaluate == "enabled"):
         | 
| 589 | 
             
                    llm_eval.eval(st.session_state.questions, st.session_state.answers)
         | 
| 590 | 
            -
                 | 
| 591 | 
            -
                #st.session_state.input_searchType=st.session_state.input_searchType
         | 
| 592 | 
            -
             | 
| 593 | 
             
            def write_top_bar():
         | 
| 594 | 
            -
                # st.markdown("""
         | 
| 595 | 
            -
                # <style>
         | 
| 596 | 
            -
                # [data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
         | 
| 597 | 
            -
                #     gap: 0rem;
         | 
| 598 | 
            -
                # }
         | 
| 599 | 
            -
                # </style>
         | 
| 600 | 
            -
                # """,unsafe_allow_html=True)
         | 
| 601 | 
            -
                #print("top bar")
         | 
| 602 | 
            -
                # st.title(':mag: AI powered OpenSearch')
         | 
| 603 | 
            -
                # st.write("")
         | 
| 604 | 
            -
                # st.write("")
         | 
| 605 | 
             
                col1, col2,col3,col4  = st.columns([2.5,35,8,7])
         | 
| 606 | 
             
                with col1:
         | 
| 607 | 
             
                    st.image(TEXT_ICON, use_column_width='always')
         | 
| @@ -630,9 +529,6 @@ def write_top_bar(): | |
| 630 | 
             
                                st.markdown("<div style = 'height:43px'></div>",unsafe_allow_html=True)
         | 
| 631 | 
             
                                st.button("Generate",disabled=False,key = "generate",on_click = generate_images, args=(tab1,"default_img"))
         | 
| 632 |  | 
| 633 | 
            -
                            # image_select = st.select_slider(
         | 
| 634 | 
            -
                            #     "Select a image",
         | 
| 635 | 
            -
                            #     options=["Image 1","Image 2","Image 3"], value = None, disabled = st.session_state.radio_disabled,key = "image_select")
         | 
| 636 | 
             
                            image_select = st.radio("Choose one image", ["Image 1","Image 2","Image 3"],index=None, horizontal = True,key = 'image_select',disabled = st.session_state.radio_disabled)
         | 
| 637 | 
             
                            st.markdown("""
         | 
| 638 | 
             
                                        <style>
         | 
| @@ -642,25 +538,10 @@ def write_top_bar(): | |
| 642 | 
             
                                        </style>
         | 
| 643 | 
             
                                        """,unsafe_allow_html=True)
         | 
| 644 | 
             
                            if(st.session_state.image_select is not None and st.session_state.image_select !="" and len(st.session_state.img_gen)!=0):
         | 
| 645 | 
            -
                                print("image_select")
         | 
| 646 | 
            -
                                print("------------")
         | 
| 647 | 
            -
                                print(st.session_state.image_select)
         | 
| 648 | 
             
                                st.session_state.input_rad_1 = st.session_state.image_select.split(" ")[1]
         | 
| 649 | 
             
                            else:
         | 
| 650 | 
             
                                st.session_state.input_rad_1 = ""
         | 
| 651 | 
            -
                             | 
| 652 | 
            -
                            # with rad1:
         | 
| 653 | 
            -
                            #     btn1 = st.button("choose image 1", disabled = st.session_state.radio_disabled)
         | 
| 654 | 
            -
                            # with rad2:
         | 
| 655 | 
            -
                            #     btn2 = st.button("choose image 2", disabled = st.session_state.radio_disabled)
         | 
| 656 | 
            -
                            # with rad3:
         | 
| 657 | 
            -
                            #     btn3 = st.button("choose image 3", disabled = st.session_state.radio_disabled)
         | 
| 658 | 
            -
                            # if(btn1):
         | 
| 659 | 
            -
                            #     st.session_state.input_rad_1 = "1" 
         | 
| 660 | 
            -
                            # if(btn2):
         | 
| 661 | 
            -
                            #     st.session_state.input_rad_1 = "2" 
         | 
| 662 | 
            -
                            # if(btn3):
         | 
| 663 | 
            -
                            #     st.session_state.input_rad_1 = "3" 
         | 
| 664 |  | 
| 665 |  | 
| 666 | 
             
                    generate_images(tab1,gen_images)   
         | 
| @@ -669,19 +550,11 @@ def write_top_bar(): | |
| 669 | 
             
                    with tab2:
         | 
| 670 | 
             
                        st.session_state.img_doc = st.file_uploader(
         | 
| 671 | 
             
                        "Upload image", accept_multiple_files=False,type = ['png', 'jpg'])
         | 
| 672 | 
            -
                        
         | 
| 673 | 
            -
                
         | 
| 674 | 
            -
                    
         | 
| 675 | 
            -
                    
         | 
| 676 | 
            -
             | 
| 677 | 
             
                return clear,tab1
         | 
| 678 |  | 
| 679 | 
             
            clear,tab_ = write_top_bar()
         | 
| 680 |  | 
| 681 | 
             
            if clear:
         | 
| 682 | 
            -
                
         | 
| 683 | 
            -
                
         | 
| 684 | 
            -
                print("clear1")
         | 
| 685 | 
             
                st.session_state.questions = []
         | 
| 686 | 
             
                st.session_state.answers = []
         | 
| 687 |  | 
| @@ -697,18 +570,7 @@ if clear: | |
| 697 | 
             
                    st.session_state.input_rad_1 = ""
         | 
| 698 |  | 
| 699 |  | 
| 700 | 
            -
             | 
| 701 | 
            -
                    # with placeholder1.container():
         | 
| 702 | 
            -
                    #     generate_images(tab_,st.session_state.image_prompt)
         | 
| 703 | 
            -
                    
         | 
| 704 | 
            -
                    
         | 
| 705 | 
            -
                #st.session_state.input_text=""
         | 
| 706 | 
            -
                # st.session_state.input_searchType="Conversational Search (RAG)"
         | 
| 707 | 
            -
                # st.session_state.input_temperature = "0.001"
         | 
| 708 | 
            -
                # st.session_state.input_topK = 200
         | 
| 709 | 
            -
                # st.session_state.input_topP = 0.95
         | 
| 710 | 
            -
                # st.session_state.input_maxTokens = 1024
         | 
| 711 | 
            -
             | 
| 712 | 
             
            col1, col3, col4 = st.columns([70,18,12])
         | 
| 713 |  | 
| 714 | 
             
            with col1:
         | 
| @@ -732,7 +594,7 @@ with col4: | |
| 732 | 
             
                evaluate = st.toggle(' ', key = 'evaluate', disabled = False) #help = "Checking this box will use LLM to evaluate results as relevant and irrelevant. \n\n This option increases the latency")
         | 
| 733 | 
             
                if(evaluate):
         | 
| 734 | 
             
                    st.session_state.input_evaluate = "enabled"
         | 
| 735 | 
            -
                     | 
| 736 | 
             
                else:
         | 
| 737 | 
             
                    st.session_state.input_evaluate = "disabled"
         | 
| 738 |  | 
| @@ -740,11 +602,7 @@ with col4: | |
| 740 | 
             
            if(search_all_type == True or 1==1):
         | 
| 741 | 
             
                with st.sidebar:
         | 
| 742 | 
             
                    st.page_link("app.py", label=":orange[Home]", icon="🏠")
         | 
| 743 | 
            -
                     | 
| 744 | 
            -
                    #st.warning('Note: After changing any of the below settings, click "SEARCH" button or 🔄 to apply the changes', icon="⚠️")
         | 
| 745 | 
            -
                    #st.header('     :gear: :orange[Fine-tune Search]')
         | 
| 746 | 
            -
                    #st.write("Note: After changing any of the below settings, click 'SEARCH' button or '🔄' to apply the changes")
         | 
| 747 | 
            -
                    #st.subheader(':blue[Keyword Search]')
         | 
| 748 |  | 
| 749 | 
             
                    ########################## enable for query_rewrite ########################
         | 
| 750 | 
             
                    rewrite_query = st.checkbox('Auto-apply filters', key = 'query_rewrite', disabled = False, help = "Checking this box will use LLM to rewrite your query. \n\n Here your natural language query is transformed into OpenSearch query with added filters and attributes")
         | 
| @@ -754,6 +612,8 @@ if(search_all_type == True or 1==1): | |
| 754 | 
             
                            key = 'input_must',
         | 
| 755 | 
             
                           )
         | 
| 756 | 
             
                    ########################## enable for query_rewrite ########################
         | 
|  | |
|  | |
| 757 | 
             
                    ####### Filters   #########
         | 
| 758 |  | 
| 759 | 
             
                    st.subheader(':blue[Filters]')
         | 
| @@ -776,25 +636,6 @@ if(search_all_type == True or 1==1): | |
| 776 |  | 
| 777 |  | 
| 778 | 
             
                    clear_filter = st.button("Clear Filters",on_click=clear_filter)
         | 
| 779 | 
            -
                    
         | 
| 780 | 
            -
                        
         | 
| 781 | 
            -
            #             filter_place_holder = st.container()
         | 
| 782 | 
            -
            #             with filter_place_holder:
         | 
| 783 | 
            -
            #                 st.selectbox("Select one Category", ("accessories", "books","floral","furniture","hot_dispensed","jewelry","tools","apparel","cold_dispensed","food_service","groceries","housewares","outdoors","salty_snacks","videos","beauty","electronics","footwear","homedecor","instruments","seasonal"),index = None,key = "input_category")
         | 
| 784 | 
            -
            #                 st.selectbox("Select one Gender", ("male","female"),index = None,key = "input_gender")
         | 
| 785 | 
            -
            #                 st.slider("Select a range of price", 0, 2000, (0, 0),50, key = "input_price")
         | 
| 786 | 
            -
                         
         | 
| 787 | 
            -
            #             st.session_state.input_category=None
         | 
| 788 | 
            -
            #             st.session_state.input_gender=None
         | 
| 789 | 
            -
            #             st.session_state.input_price=(0,0)
         | 
| 790 | 
            -
                        
         | 
| 791 | 
            -
                    print("--------------------filters---------------")    
         | 
| 792 | 
            -
                    print(st.session_state.input_gender)
         | 
| 793 | 
            -
                    print(st.session_state.input_manual_filter)
         | 
| 794 | 
            -
                    print("--------------------filters---------------") 
         | 
| 795 | 
            -
                    
         | 
| 796 | 
            -
                    
         | 
| 797 | 
            -
                    
         | 
| 798 | 
             
                    ####### Filters   #########
         | 
| 799 |  | 
| 800 | 
             
                    if('NeuralSparse Search' in st.session_state.search_types):
         | 
| @@ -802,111 +643,21 @@ if(search_all_type == True or 1==1): | |
| 802 | 
             
                        sparse_filter = st.slider('Keep only sparse tokens with weight >=', 0.0, 1.0, 0.5,0.1,key = 'input_sparse_filter', help = 'Use this slider to set the minimum weight that the sparse vector token weights should meet, rest are filtered out')
         | 
| 803 |  | 
| 804 |  | 
| 805 | 
            -
                    #sql_query = st.checkbox('Re-write as SQL query', key = 'sql_rewrite', disabled = True, help = "In Progress")
         | 
| 806 | 
             
                    st.session_state.input_is_rewrite_query = 'disabled'
         | 
| 807 | 
             
                    st.session_state.input_is_sql_query = 'disabled'
         | 
| 808 |  | 
| 809 | 
             
                    ########################## enable for query_rewrite ########################
         | 
| 810 | 
             
                    if rewrite_query:
         | 
| 811 | 
            -
                        #st.write(st.session_state.inputs_)
         | 
| 812 | 
             
                        st.session_state.input_is_rewrite_query = 'enabled'
         | 
| 813 | 
            -
             | 
| 814 | 
            -
                    #     #st.write(st.session_state.inputs_)
         | 
| 815 | 
            -
                    #     st.session_state.input_is_sql_query = 'enabled'
         | 
| 816 | 
            -
                    ########################## enable for sql conversion ########################
         | 
| 817 | 
            -
                    
         | 
| 818 | 
            -
                    
         | 
| 819 | 
            -
                    #st.markdown('---')
         | 
| 820 | 
            -
                    #st.header('Fine-tune keyword Search', divider='rainbow')
         | 
| 821 | 
            -
                    #st.subheader('Note: The below selection applies only when the Search type is set to Keyword Search')
         | 
| 822 | 
            -
                       
         | 
| 823 | 
            -
                     
         | 
| 824 | 
            -
                    # st.markdown("<u>Enrich metadata for :</u>",unsafe_allow_html=True) 
         | 
| 825 | 
            -
                    
         | 
| 826 | 
            -
             | 
| 827 | 
            -
                    
         | 
| 828 | 
            -
                    # c3,c4 = st.columns([10,90])
         | 
| 829 | 
            -
                    # with c4:
         | 
| 830 | 
            -
                    #     rekognition = st.checkbox('Images', key = 'rekognition', help = "Checking this box will use AI to extract metadata for images that are present in query and documents")
         | 
| 831 | 
            -
                    # if rekognition:
         | 
| 832 | 
            -
                    #     #st.write(st.session_state.inputs_)
         | 
| 833 | 
            -
                    #     st.session_state.input_rekognition = 'enabled'
         | 
| 834 | 
            -
                    # else:
         | 
| 835 | 
            -
                    #     st.session_state.input_rekognition = "disabled"
         | 
| 836 | 
            -
             | 
| 837 | 
            -
                    #st.markdown('---')
         | 
| 838 | 
            -
                    #st.header('Fine-tune Hybrid Search', divider='rainbow')
         | 
| 839 | 
            -
                    #st.subheader('Note: The below parameters apply only when the Search type is set to Hybrid Search')
         | 
| 840 | 
            -
                    
         | 
| 841 | 
            -
                    
         | 
| 842 | 
            -
                    
         | 
| 843 | 
            -
                    
         | 
| 844 | 
            -
                    
         | 
| 845 | 
            -
                    
         | 
| 846 | 
            -
                    
         | 
| 847 | 
            -
                    #st.write("---")
         | 
| 848 | 
            -
                    #if(st.session_state.max_selections == "None"):
         | 
| 849 | 
             
                    st.subheader(':blue[Hybrid Search]')
         | 
| 850 | 
            -
                    # st.selectbox('Select the Hybrid Search type',
         | 
| 851 | 
            -
                    #  ("OpenSearch Hybrid Query","Reciprocal Rank Fusion"),key = 'input_hybridType')
         | 
| 852 | 
            -
                    # equal_weight = st.button("Give equal weights to selected searches")
         | 
| 853 | 
            -
             | 
| 854 | 
            -
             | 
| 855 | 
            -
             | 
| 856 | 
            -
             | 
| 857 | 
            -
             | 
| 858 | 
            -
             | 
| 859 | 
            -
                    #st.warning('Weight of each of the selected search type should be greater than 0 and the total weight of all the selected search type(s) should be equal to 100',icon = "⚠️")
         | 
| 860 | 
            -
             | 
| 861 | 
            -
             | 
| 862 | 
            -
                    #st.markdown("<p style = 'font-size:14.5px;font-style:italic;'>Set Weights</p>",unsafe_allow_html=True)
         | 
| 863 | 
            -
             | 
| 864 | 
             
                    with st.expander("Set query Weightage:"):
         | 
| 865 | 
             
                        st.number_input("Keyword %", min_value=0, max_value=100, value=100, step=5,  key='input_Keyword-weight', help=None)
         | 
| 866 | 
             
                        st.number_input("Vector %", min_value=0, max_value=100, value=0, step=5,  key='input_Vector-weight', help=None)
         | 
| 867 | 
             
                        st.number_input("Multimodal %", min_value=0, max_value=100, value=0, step=5,  key='input_Multimodal-weight', help=None)
         | 
| 868 | 
             
                        st.number_input("NeuralSparse %", min_value=0, max_value=100, value=0, step=5,  key='input_NeuralSparse-weight', help=None)
         | 
| 869 |  | 
| 870 | 
            -
             | 
| 871 | 
            -
                    #     counter = 0
         | 
| 872 | 
            -
                    #     num_search = len(st.session_state.input_searchType)
         | 
| 873 | 
            -
                    #     weight_type = ["input_Keyword-weight","input_Vector-weight","input_Multimodal-weight","input_NeuralSparse-weight"]
         | 
| 874 | 
            -
                    #     for type in weight_type:
         | 
| 875 | 
            -
                    #         if(type.split("-")[0].replace("input_","")+ " Search" in st.session_state.input_searchType):
         | 
| 876 | 
            -
                    #             print("ssssssssssss")
         | 
| 877 | 
            -
                    #             counter = counter +1
         | 
| 878 | 
            -
                    #             extra_weight = 100%num_search
         | 
| 879 | 
            -
                    #             if(counter == num_search):
         | 
| 880 | 
            -
                    #                 cal_weight = math.trunc(100/num_search)+extra_weight
         | 
| 881 | 
            -
                    #             else:
         | 
| 882 | 
            -
                    #                 cal_weight = math.trunc(100/num_search)
         | 
| 883 | 
            -
                    #             st.session_state[weight_type] = cal_weight
         | 
| 884 | 
            -
                    #         else:
         | 
| 885 | 
            -
                    #             st.session_state[weight_type] = 0
         | 
| 886 | 
            -
                    #weight = st.slider('Weight for Vector Search', 0.0, 1.0, 0.5,0.1,key = 'input_weight', help = 'Use this slider to set the weightage for keyword and vector search, higher values of the slider indicate the increased weightage for semantic search.\n\n This applies only when the search type is set to Hybrid Search')
         | 
| 887 | 
            -
                    # st.selectbox('Select the Normalisation type',
         | 
| 888 | 
            -
                    # ('min_max',
         | 
| 889 | 
            -
                    # 'l2'
         | 
| 890 | 
            -
                    # ),
         | 
| 891 | 
            -
                    #st.write("---")
         | 
| 892 | 
            -
                    # key = 'input_NormType',
         | 
| 893 | 
            -
                    # disabled = True,
         | 
| 894 | 
            -
                    # help = "Select the type of Normalisation to be applied on the two sets of scores"
         | 
| 895 | 
            -
                    # ) 
         | 
| 896 | 
            -
             | 
| 897 | 
            -
                    # st.selectbox('Select the Score Combination type',
         | 
| 898 | 
            -
                    # ('arithmetic_mean','geometric_mean','harmonic_mean'
         | 
| 899 | 
            -
                    # ),
         | 
| 900 | 
            -
                
         | 
| 901 | 
            -
                    # key = 'input_CombineType',
         | 
| 902 | 
            -
                    # disabled = True,
         | 
| 903 | 
            -
                    # help = "Select the Combination strategy to be used while combining the two scores of the two search queries for every document"
         | 
| 904 | 
            -
                    # )  
         | 
| 905 | 
            -
             | 
| 906 | 
            -
                    #st.markdown('---')
         | 
| 907 | 
            -
             | 
| 908 | 
            -
                    #st.header('Select the ML Model for text embedding', divider='rainbow')
         | 
| 909 | 
            -
                    #st.subheader('Note: The below selection applies only when the Search type is set to Vector or Hybrid Search')
         | 
| 910 | 
             
                    if(st.session_state.re_ranker == "true"):
         | 
| 911 | 
             
                        st.subheader(':blue[Re-ranking]')
         | 
| 912 | 
             
                        reranker = st.selectbox('Choose a Re-Ranker',
         | 
| @@ -916,41 +667,19 @@ if(search_all_type == True or 1==1): | |
| 916 |  | 
| 917 | 
             
                        key = 'input_reranker',
         | 
| 918 | 
             
                        help = 'Select the Re-Ranker type, select "None" to apply no re-ranking of the results',
         | 
| 919 | 
            -
                        #on_change = re_ranker.re_rank,
         | 
| 920 | 
             
                        args=(st.session_state.questions, st.session_state.answers)
         | 
| 921 |  | 
| 922 | 
             
                        )
         | 
| 923 | 
            -
             | 
| 924 | 
            -
                    # st.subheader('Text Embeddings Model')
         | 
| 925 | 
            -
                    # model_type = st.selectbox('Select the Text Embeddings Model',
         | 
| 926 | 
            -
                    # ('Titan-Embed-Text-v1','GPT-J-6B'
         | 
| 927 | 
            -
                    
         | 
| 928 | 
            -
                    # ),
         | 
| 929 | 
            -
                
         | 
| 930 | 
            -
                    # key = 'input_modelType',
         | 
| 931 | 
            -
                    # help = "Select the Text embedding model, this applies only for the vector and hybrid search"
         | 
| 932 | 
            -
                    # )
         | 
| 933 | 
            -
             | 
| 934 | 
            -
                    #st.markdown('---')
         | 
| 935 | 
            -
             | 
| 936 | 
            -
                    
         | 
| 937 | 
            -
             | 
| 938 | 
            -
                    
         | 
| 939 | 
            -
             | 
| 940 | 
            -
                
         | 
| 941 | 
            -
             | 
| 942 | 
            -
            #st.markdown('---')
         | 
| 943 |  | 
| 944 |  | 
| 945 | 
             
            def write_user_message(md,ans):
         | 
| 946 | 
            -
                #print(ans)
         | 
| 947 | 
             
                ans = ans["answer"][0]
         | 
| 948 | 
             
                col1, col2, col3 = st.columns([3,40,20])
         | 
| 949 |  | 
| 950 | 
             
                with col1:
         | 
| 951 | 
             
                    st.image(USER_ICON, use_column_width='always')
         | 
| 952 | 
             
                with col2:
         | 
| 953 | 
            -
                    #st.warning(md['question'])
         | 
| 954 | 
             
                    st.markdown("<div style='fontSize:15px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'>Input Text: </div><div style='fontSize:25px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;color:#e28743'>"+md['question']+"</div>", unsafe_allow_html = True)
         | 
| 955 | 
             
                    if('query_sparse' in ans):
         | 
| 956 | 
             
                        with st.expander("Expanded Query:"):
         | 
| @@ -1011,10 +740,7 @@ def render_answer(answer,index): | |
| 1011 | 
             
                            span_color = "red"
         | 
| 1012 | 
             
                        st.markdown("<span style='fontSize:20px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 20px;font-family:Courier New;color:#e28743'>Relevance:" +str('%.3f'%(st.session_state.input_ndcg)) + "</span><span style='font-size:30px;font-weight:bold;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-weight:bold;font-family:Courier New;color:"+span_color+"'> "+st.session_state.ndcg_increase.split("~")[1]+"</span>", unsafe_allow_html = True)
         | 
| 1013 |  | 
| 1014 | 
            -
             | 
| 1015 | 
            -
                        #st.markdown("<span style='font-size:30px;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-family:Courier New;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[1]+"</span>",unsafe_allow_html = True)
         | 
| 1016 | 
            -
                    
         | 
| 1017 | 
            -
                
         | 
| 1018 |  | 
| 1019 | 
             
                placeholder_no_results  = st.empty()
         | 
| 1020 |  | 
| @@ -1030,12 +756,7 @@ def render_answer(answer,index): | |
| 1030 | 
             
                        continue
         | 
| 1031 |  | 
| 1032 |  | 
| 1033 | 
            -
                    # imgdata = base64.b64decode(ans['image_binary'])
         | 
| 1034 | 
             
                    format_ = ans['image_url'].split(".")[-1]
         | 
| 1035 | 
            -
                   
         | 
| 1036 | 
            -
                    #urllib.request.urlretrieve(ans['image_url'], "/home/ubuntu/res_images/"+str(i)+"_."+format_) 
         | 
| 1037 | 
            -
             | 
| 1038 | 
            -
                    
         | 
| 1039 | 
             
                    Image.MAX_IMAGE_PIXELS = 100000000
         | 
| 1040 |  | 
| 1041 | 
             
                    width = 500
         | 
| @@ -1066,23 +787,6 @@ def render_answer(answer,index): | |
| 1066 | 
             
                                desc__ = ans['desc'].split(" ")
         | 
| 1067 |  | 
| 1068 | 
             
                                final_desc = "<p>"
         | 
| 1069 | 
            -
                                
         | 
| 1070 | 
            -
                                ###### stemming and highlighting
         | 
| 1071 | 
            -
                                
         | 
| 1072 | 
            -
                                # ans_text = ans['desc']
         | 
| 1073 | 
            -
                                # query_text = st.session_state.input_text
         | 
| 1074 | 
            -
             | 
| 1075 | 
            -
                                # ans_text_stemmed = set(stem_(ans_text))
         | 
| 1076 | 
            -
                                # query_text_stemmed = set(stem_(query_text))
         | 
| 1077 | 
            -
             | 
| 1078 | 
            -
                                # common = ans_text_stemmed.intersection( query_text_stemmed)
         | 
| 1079 | 
            -
                                # #unique = set(document_1_words).symmetric_difference(  )
         | 
| 1080 | 
            -
             | 
| 1081 | 
            -
                                # desc__stemmed = stem_(desc__)
         | 
| 1082 | 
            -
             | 
| 1083 | 
            -
                                # for word_ in desc__stemmed:
         | 
| 1084 | 
            -
                                #     if(word_ in common):
         | 
| 1085 | 
            -
             | 
| 1086 |  | 
| 1087 | 
             
                                for word in desc__:
         | 
| 1088 | 
             
                                    if(re.sub('[^A-Za-z0-9]+', '', word) in res__):
         | 
| @@ -1104,16 +808,8 @@ def render_answer(answer,index): | |
| 1104 | 
             
                                            filtered_sparse[key] = round(sparse_[key], 2)
         | 
| 1105 | 
             
                                    st.write(filtered_sparse)
         | 
| 1106 | 
             
                            with st.expander("Document Metadata:",expanded = False):
         | 
| 1107 | 
            -
                                # if("rekog" in ans):
         | 
| 1108 | 
            -
                                #     div_size = [50,50]
         | 
| 1109 | 
            -
                                # else:
         | 
| 1110 | 
            -
                                #     div_size = [99,1]
         | 
| 1111 | 
            -
                                # div1,div2 = st.columns(div_size)
         | 
| 1112 | 
            -
                                # with div1:
         | 
| 1113 | 
            -
                                    
         | 
| 1114 | 
             
                                st.write(":green[default:]")
         | 
| 1115 | 
             
                                st.json({"category:":ans['category'],"price":str(ans['price']),"gender_affinity":ans['gender_affinity'],"style":ans['style']},expanded = True)
         | 
| 1116 | 
            -
                                #with div2:
         | 
| 1117 | 
             
                                if("rekog" in ans):
         | 
| 1118 | 
             
                                    st.write(":green[enriched:]")
         | 
| 1119 | 
             
                                    st.json(ans['rekog'],expanded = True)
         | 
| @@ -1128,18 +824,7 @@ def render_answer(answer,index): | |
| 1128 | 
             
                                            st.write(":x:")
         | 
| 1129 |  | 
| 1130 | 
             
                    i = i+1
         | 
| 1131 | 
            -
                 | 
| 1132 | 
            -
                #     if(st.session_state.input_evaluate == "enabled"):
         | 
| 1133 | 
            -
                #         st.markdown("<div style='fontSize:12px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;font-weight:bold;height: fit-content;border-radius: 20px;font-family:Courier New;color:#e28743'>DCG: " +str('%.3f'%(st.session_state.input_ndcg)) + "</div>", unsafe_allow_html = True)
         | 
| 1134 | 
            -
                # with col_2_b:
         | 
| 1135 | 
            -
                #     span_color = "white"
         | 
| 1136 | 
            -
                #     if("↑" in st.session_state.ndcg_increase):
         | 
| 1137 | 
            -
                #         span_color = "green"
         | 
| 1138 | 
            -
                #     if("↓" in st.session_state.ndcg_increase):
         | 
| 1139 | 
            -
                #         span_color = "red"
         | 
| 1140 | 
            -
                #     st.markdown("<span style='font-size:30px;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-family:Courier New;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[1]+"</span>",unsafe_allow_html = True)
         | 
| 1141 | 
            -
                        
         | 
| 1142 | 
            -
                        
         | 
| 1143 | 
             
                with col_3:
         | 
| 1144 | 
             
                    if(index == len(st.session_state.questions)):
         | 
| 1145 |  | 
| @@ -1155,7 +840,6 @@ def render_answer(answer,index): | |
| 1155 | 
             
                                st.session_state.questions.pop()
         | 
| 1156 |  | 
| 1157 | 
             
                                handle_input()
         | 
| 1158 | 
            -
                                #re_ranker.re_rank(st.session_state.questions, st.session_state.answers)
         | 
| 1159 | 
             
                                with placeholder.container():
         | 
| 1160 | 
             
                                    render_all()
         | 
| 1161 |  | 
| @@ -1169,9 +853,6 @@ def render_answer(answer,index): | |
| 1169 | 
             
                        except:
         | 
| 1170 | 
             
                            pass  
         | 
| 1171 |  | 
| 1172 | 
            -
                        print("------------------------")
         | 
| 1173 | 
            -
                        #print(st.session_state)
         | 
| 1174 | 
            -
             | 
| 1175 | 
             
                        placeholder__ = st.empty()
         | 
| 1176 |  | 
| 1177 | 
             
                        placeholder__.button("🔄",key=rdn_key,on_click=on_button_click, help = "This will regenerate the responses with new settings that you entered, Note: To see difference in responses, you should change any of the applicable settings")#,type="primary",use_column_width=True)
         | 
| @@ -1196,8 +877,6 @@ def render_all(): | |
| 1196 | 
             
                index = 0
         | 
| 1197 | 
             
                for (q, a) in zip(st.session_state.questions, st.session_state.answers):
         | 
| 1198 | 
             
                    index = index +1
         | 
| 1199 | 
            -
                    #print("answers----")
         | 
| 1200 | 
            -
                    #print(a)
         | 
| 1201 | 
             
                    ans_ = st.session_state.answers[0]
         | 
| 1202 | 
             
                    write_user_message(q,ans_)
         | 
| 1203 | 
             
                    write_chat_message(a, q,index)
         | 
| @@ -1206,6 +885,4 @@ placeholder = st.empty() | |
| 1206 | 
             
            with placeholder.container():
         | 
| 1207 | 
             
              render_all()
         | 
| 1208 |  | 
| 1209 | 
            -
              #generate_images("",st.session_state.image_prompt)
         | 
| 1210 | 
            -
             | 
| 1211 | 
             
            st.markdown("")
         | 
|  | |
| 24 | 
             
            import shutil
         | 
| 25 | 
             
            import re
         | 
| 26 | 
             
            from requests.auth import HTTPBasicAuth
         | 
|  | |
| 27 | 
             
            # from nltk.stem import PorterStemmer
         | 
| 28 | 
             
            # from nltk.tokenize import word_tokenize
         | 
| 29 | 
             
            import query_rewrite
         | 
| 30 | 
             
            import amazon_rekognition
         | 
| 31 | 
            +
            from streamlit.components.v1 import html
         | 
| 32 | 
             
            #from st_click_detector import click_detector
         | 
| 33 | 
             
            import llm_eval
         | 
| 34 | 
             
            import all_search_execute
         | 
| 35 | 
             
            import warnings
         | 
| 36 |  | 
| 37 | 
             
            warnings.filterwarnings("ignore", category=DeprecationWarning)
         | 
|  | |
|  | |
|  | |
|  | |
| 38 | 
             
            st.set_page_config(
         | 
|  | |
|  | |
| 39 | 
             
                page_icon="images/opensearch_mark_default.png"
         | 
| 40 | 
             
            )
         | 
| 41 | 
             
            parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
         | 
|  | |
| 52 | 
             
            #ps = PorterStemmer()
         | 
| 53 |  | 
| 54 | 
             
            st.session_state.REGION = 'us-east-1'
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 55 | 
             
            USER_ICON = "images/user.png"
         | 
| 56 | 
             
            AI_ICON = "images/opensearch-twitter-card.png"
         | 
| 57 | 
             
            REGENERATE_ICON = "images/regenerate.png"
         | 
|  | |
| 159 | 
             
            if "gen_image_str" not in st.session_state:
         | 
| 160 | 
             
                st.session_state.gen_image_str=""
         | 
| 161 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 162 | 
             
            if "input_NormType" not in st.session_state:
         | 
| 163 | 
             
                st.session_state.input_NormType = "min_max"
         | 
| 164 |  | 
|  | |
| 244 | 
             
                'Multimodal Search',
         | 
| 245 | 
             
                'NeuralSparse Search',
         | 
| 246 | 
             
                ]
         | 
| 247 | 
            +
             | 
| 248 | 
            +
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 249 |  | 
| 250 |  | 
| 251 |  | 
|  | |
| 414 |  | 
| 415 |  | 
| 416 | 
             
                inputs = {}
         | 
|  | |
|  | |
|  | |
|  | |
| 417 | 
             
                if(st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType):
         | 
| 418 | 
             
                    old_rekog_label = st.session_state.input_rekog_label
         | 
| 419 | 
             
                    st.session_state.input_rekog_label = amazon_rekognition.extract_image_metadata(st.session_state.bytes_for_rekog)
         | 
| 420 | 
             
                    if(st.session_state.input_text == ""):
         | 
| 421 | 
             
                        st.session_state.input_text = st.session_state.input_rekog_label
         | 
| 422 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 423 |  | 
| 424 | 
             
                weightage = {}
         | 
| 425 | 
             
                st.session_state.weights_ = []
         | 
|  | |
| 458 | 
             
                            else:
         | 
| 459 | 
             
                                weightage[original_key] = 0.0
         | 
| 460 | 
             
                                st.session_state[key] = 0.0
         | 
| 461 | 
            +
                
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 462 | 
             
                inputs['weightage']=weightage
         | 
| 463 | 
             
                st.session_state.input_weightage = weightage
         | 
| 464 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 465 |  | 
| 466 | 
             
                st.session_state.inputs_ = inputs
         | 
| 467 |  | 
|  | |
| 468 | 
             
                question_with_id = {
         | 
| 469 | 
             
                    'question': inputs["text"],
         | 
| 470 | 
             
                    'id': len(st.session_state.questions)
         | 
|  | |
| 483 |  | 
| 484 | 
             
                if(st.session_state.input_is_rewrite_query == 'enabled' or (st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType)):
         | 
| 485 | 
             
                    query_rewrite.get_new_query_res(st.session_state.input_text)
         | 
| 486 | 
            +
                    
         | 
|  | |
|  | |
| 487 | 
             
                else:
         | 
| 488 | 
             
                    st.session_state.input_rewritten_query = ""
         | 
| 489 |  | 
| 490 | 
            +
             | 
|  | |
|  | |
| 491 | 
             
                ans__ = all_search_execute.handler(inputs, st.session_state['session_id'])
         | 
| 492 |  | 
| 493 | 
             
                st.session_state.answers.append({
         | 
| 494 | 
            +
                    'answer': ans__,
         | 
| 495 | 
             
                    'search_type':inputs['searchType'],
         | 
| 496 | 
             
                    'id': len(st.session_state.questions)
         | 
| 497 | 
             
                })
         | 
|  | |
| 499 | 
             
                st.session_state.answers_none_rank = st.session_state.answers
         | 
| 500 | 
             
                if(st.session_state.input_evaluate == "enabled"):
         | 
| 501 | 
             
                    llm_eval.eval(st.session_state.questions, st.session_state.answers)
         | 
| 502 | 
            +
                
         | 
|  | |
|  | |
| 503 | 
             
            def write_top_bar():
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 504 | 
             
                col1, col2,col3,col4  = st.columns([2.5,35,8,7])
         | 
| 505 | 
             
                with col1:
         | 
| 506 | 
             
                    st.image(TEXT_ICON, use_column_width='always')
         | 
|  | |
| 529 | 
             
                                st.markdown("<div style = 'height:43px'></div>",unsafe_allow_html=True)
         | 
| 530 | 
             
                                st.button("Generate",disabled=False,key = "generate",on_click = generate_images, args=(tab1,"default_img"))
         | 
| 531 |  | 
|  | |
|  | |
|  | |
| 532 | 
             
                            image_select = st.radio("Choose one image", ["Image 1","Image 2","Image 3"],index=None, horizontal = True,key = 'image_select',disabled = st.session_state.radio_disabled)
         | 
| 533 | 
             
                            st.markdown("""
         | 
| 534 | 
             
                                        <style>
         | 
|  | |
| 538 | 
             
                                        </style>
         | 
| 539 | 
             
                                        """,unsafe_allow_html=True)
         | 
| 540 | 
             
                            if(st.session_state.image_select is not None and st.session_state.image_select !="" and len(st.session_state.img_gen)!=0):
         | 
|  | |
|  | |
|  | |
| 541 | 
             
                                st.session_state.input_rad_1 = st.session_state.image_select.split(" ")[1]
         | 
| 542 | 
             
                            else:
         | 
| 543 | 
             
                                st.session_state.input_rad_1 = ""
         | 
| 544 | 
            +
                            
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 545 |  | 
| 546 |  | 
| 547 | 
             
                    generate_images(tab1,gen_images)   
         | 
|  | |
| 550 | 
             
                    with tab2:
         | 
| 551 | 
             
                        st.session_state.img_doc = st.file_uploader(
         | 
| 552 | 
             
                        "Upload image", accept_multiple_files=False,type = ['png', 'jpg'])
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 553 | 
             
                return clear,tab1
         | 
| 554 |  | 
| 555 | 
             
            clear,tab_ = write_top_bar()
         | 
| 556 |  | 
| 557 | 
             
            if clear:
         | 
|  | |
|  | |
|  | |
| 558 | 
             
                st.session_state.questions = []
         | 
| 559 | 
             
                st.session_state.answers = []
         | 
| 560 |  | 
|  | |
| 570 | 
             
                    st.session_state.input_rad_1 = ""
         | 
| 571 |  | 
| 572 |  | 
| 573 | 
            +
                   
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 574 | 
             
            col1, col3, col4 = st.columns([70,18,12])
         | 
| 575 |  | 
| 576 | 
             
            with col1:
         | 
|  | |
| 594 | 
             
                evaluate = st.toggle(' ', key = 'evaluate', disabled = False) #help = "Checking this box will use LLM to evaluate results as relevant and irrelevant. \n\n This option increases the latency")
         | 
| 595 | 
             
                if(evaluate):
         | 
| 596 | 
             
                    st.session_state.input_evaluate = "enabled"
         | 
| 597 | 
            +
                    
         | 
| 598 | 
             
                else:
         | 
| 599 | 
             
                    st.session_state.input_evaluate = "disabled"
         | 
| 600 |  | 
|  | |
| 602 | 
             
            if(search_all_type == True or 1==1):
         | 
| 603 | 
             
                with st.sidebar:
         | 
| 604 | 
             
                    st.page_link("app.py", label=":orange[Home]", icon="🏠")
         | 
| 605 | 
            +
                    
         | 
|  | |
|  | |
|  | |
|  | |
| 606 |  | 
| 607 | 
             
                    ########################## enable for query_rewrite ########################
         | 
| 608 | 
             
                    rewrite_query = st.checkbox('Auto-apply filters', key = 'query_rewrite', disabled = False, help = "Checking this box will use LLM to rewrite your query. \n\n Here your natural language query is transformed into OpenSearch query with added filters and attributes")
         | 
|  | |
| 612 | 
             
                            key = 'input_must',
         | 
| 613 | 
             
                           )
         | 
| 614 | 
             
                    ########################## enable for query_rewrite ########################
         | 
| 615 | 
            +
             | 
| 616 | 
            +
             | 
| 617 | 
             
                    ####### Filters   #########
         | 
| 618 |  | 
| 619 | 
             
                    st.subheader(':blue[Filters]')
         | 
|  | |
| 636 |  | 
| 637 |  | 
| 638 | 
             
                    clear_filter = st.button("Clear Filters",on_click=clear_filter)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 639 | 
             
                    ####### Filters   #########
         | 
| 640 |  | 
| 641 | 
             
                    if('NeuralSparse Search' in st.session_state.search_types):
         | 
|  | |
| 643 | 
             
                        sparse_filter = st.slider('Keep only sparse tokens with weight >=', 0.0, 1.0, 0.5,0.1,key = 'input_sparse_filter', help = 'Use this slider to set the minimum weight that the sparse vector token weights should meet, rest are filtered out')
         | 
| 644 |  | 
| 645 |  | 
|  | |
| 646 | 
             
                    st.session_state.input_is_rewrite_query = 'disabled'
         | 
| 647 | 
             
                    st.session_state.input_is_sql_query = 'disabled'
         | 
| 648 |  | 
| 649 | 
             
                    ########################## enable for query_rewrite ########################
         | 
| 650 | 
             
                    if rewrite_query:
         | 
|  | |
| 651 | 
             
                        st.session_state.input_is_rewrite_query = 'enabled'
         | 
| 652 | 
            +
                   
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 653 | 
             
                    st.subheader(':blue[Hybrid Search]')
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 654 | 
             
                    with st.expander("Set query Weightage:"):
         | 
| 655 | 
             
                        st.number_input("Keyword %", min_value=0, max_value=100, value=100, step=5,  key='input_Keyword-weight', help=None)
         | 
| 656 | 
             
                        st.number_input("Vector %", min_value=0, max_value=100, value=0, step=5,  key='input_Vector-weight', help=None)
         | 
| 657 | 
             
                        st.number_input("Multimodal %", min_value=0, max_value=100, value=0, step=5,  key='input_Multimodal-weight', help=None)
         | 
| 658 | 
             
                        st.number_input("NeuralSparse %", min_value=0, max_value=100, value=0, step=5,  key='input_NeuralSparse-weight', help=None)
         | 
| 659 |  | 
| 660 | 
            +
                   
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 661 | 
             
                    if(st.session_state.re_ranker == "true"):
         | 
| 662 | 
             
                        st.subheader(':blue[Re-ranking]')
         | 
| 663 | 
             
                        reranker = st.selectbox('Choose a Re-Ranker',
         | 
|  | |
| 667 |  | 
| 668 | 
             
                        key = 'input_reranker',
         | 
| 669 | 
             
                        help = 'Select the Re-Ranker type, select "None" to apply no re-ranking of the results',
         | 
|  | |
| 670 | 
             
                        args=(st.session_state.questions, st.session_state.answers)
         | 
| 671 |  | 
| 672 | 
             
                        )
         | 
| 673 | 
            +
                   
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 674 |  | 
| 675 |  | 
| 676 | 
             
            def write_user_message(md,ans):
         | 
|  | |
| 677 | 
             
                ans = ans["answer"][0]
         | 
| 678 | 
             
                col1, col2, col3 = st.columns([3,40,20])
         | 
| 679 |  | 
| 680 | 
             
                with col1:
         | 
| 681 | 
             
                    st.image(USER_ICON, use_column_width='always')
         | 
| 682 | 
             
                with col2:
         | 
|  | |
| 683 | 
             
                    st.markdown("<div style='fontSize:15px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'>Input Text: </div><div style='fontSize:25px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;color:#e28743'>"+md['question']+"</div>", unsafe_allow_html = True)
         | 
| 684 | 
             
                    if('query_sparse' in ans):
         | 
| 685 | 
             
                        with st.expander("Expanded Query:"):
         | 
|  | |
| 740 | 
             
                            span_color = "red"
         | 
| 741 | 
             
                        st.markdown("<span style='fontSize:20px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 20px;font-family:Courier New;color:#e28743'>Relevance:" +str('%.3f'%(st.session_state.input_ndcg)) + "</span><span style='font-size:30px;font-weight:bold;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-weight:bold;font-family:Courier New;color:"+span_color+"'> "+st.session_state.ndcg_increase.split("~")[1]+"</span>", unsafe_allow_html = True)
         | 
| 742 |  | 
| 743 | 
            +
                       
         | 
|  | |
|  | |
|  | |
| 744 |  | 
| 745 | 
             
                placeholder_no_results  = st.empty()
         | 
| 746 |  | 
|  | |
| 756 | 
             
                        continue
         | 
| 757 |  | 
| 758 |  | 
|  | |
| 759 | 
             
                    format_ = ans['image_url'].split(".")[-1]
         | 
|  | |
|  | |
|  | |
|  | |
| 760 | 
             
                    Image.MAX_IMAGE_PIXELS = 100000000
         | 
| 761 |  | 
| 762 | 
             
                    width = 500
         | 
|  | |
| 787 | 
             
                                desc__ = ans['desc'].split(" ")
         | 
| 788 |  | 
| 789 | 
             
                                final_desc = "<p>"
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 790 |  | 
| 791 | 
             
                                for word in desc__:
         | 
| 792 | 
             
                                    if(re.sub('[^A-Za-z0-9]+', '', word) in res__):
         | 
|  | |
| 808 | 
             
                                            filtered_sparse[key] = round(sparse_[key], 2)
         | 
| 809 | 
             
                                    st.write(filtered_sparse)
         | 
| 810 | 
             
                            with st.expander("Document Metadata:",expanded = False):
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 811 | 
             
                                st.write(":green[default:]")
         | 
| 812 | 
             
                                st.json({"category:":ans['category'],"price":str(ans['price']),"gender_affinity":ans['gender_affinity'],"style":ans['style']},expanded = True)
         | 
|  | |
| 813 | 
             
                                if("rekog" in ans):
         | 
| 814 | 
             
                                    st.write(":green[enriched:]")
         | 
| 815 | 
             
                                    st.json(ans['rekog'],expanded = True)
         | 
|  | |
| 824 | 
             
                                            st.write(":x:")
         | 
| 825 |  | 
| 826 | 
             
                    i = i+1
         | 
| 827 | 
            +
                
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 828 | 
             
                with col_3:
         | 
| 829 | 
             
                    if(index == len(st.session_state.questions)):
         | 
| 830 |  | 
|  | |
| 840 | 
             
                                st.session_state.questions.pop()
         | 
| 841 |  | 
| 842 | 
             
                                handle_input()
         | 
|  | |
| 843 | 
             
                                with placeholder.container():
         | 
| 844 | 
             
                                    render_all()
         | 
| 845 |  | 
|  | |
| 853 | 
             
                        except:
         | 
| 854 | 
             
                            pass  
         | 
| 855 |  | 
|  | |
|  | |
|  | |
| 856 | 
             
                        placeholder__ = st.empty()
         | 
| 857 |  | 
| 858 | 
             
                        placeholder__.button("🔄",key=rdn_key,on_click=on_button_click, help = "This will regenerate the responses with new settings that you entered, Note: To see difference in responses, you should change any of the applicable settings")#,type="primary",use_column_width=True)
         | 
|  | |
| 877 | 
             
                index = 0
         | 
| 878 | 
             
                for (q, a) in zip(st.session_state.questions, st.session_state.answers):
         | 
| 879 | 
             
                    index = index +1
         | 
|  | |
|  | |
| 880 | 
             
                    ans_ = st.session_state.answers[0]
         | 
| 881 | 
             
                    write_user_message(q,ans_)
         | 
| 882 | 
             
                    write_chat_message(a, q,index)
         | 
|  | |
| 885 | 
             
            with placeholder.container():
         | 
| 886 | 
             
              render_all()
         | 
| 887 |  | 
|  | |
|  | |
| 888 | 
             
            st.markdown("")
         | 
    	
        semantic_search/amazon_rekognition.py
    CHANGED
    
    | @@ -24,12 +24,7 @@ def extract_image_metadata(img): | |
| 24 | 
             
                MaxLabels = 10,
         | 
| 25 | 
             
                MinConfidence = 80.0,
         | 
| 26 | 
             
                Settings = { 
         | 
| 27 | 
            -
                 | 
| 28 | 
            -
                #          "LabelCategoryExclusionFilters": [ "string" ],
         | 
| 29 | 
            -
                #          "LabelCategoryInclusionFilters": [ "string" ],
         | 
| 30 | 
            -
                #          "LabelExclusionFilters": [ "string" ],
         | 
| 31 | 
            -
                #          "LabelInclusionFilters": [ "string" ]
         | 
| 32 | 
            -
                #       },
         | 
| 33 | 
             
                    "ImageProperties": { 
         | 
| 34 | 
             
                        "MaxDominantColors": 5
         | 
| 35 | 
             
                    }
         | 
| @@ -76,20 +71,12 @@ def extract_image_metadata(img): | |
| 76 | 
             
                objects = " ".join(set(objects))
         | 
| 77 | 
             
                categories = " ".join(set(categories))
         | 
| 78 | 
             
                colors = " ".join(set(colors))
         | 
| 79 | 
            -
                
         | 
| 80 | 
            -
                print("^^^^^^^^^^^^^^^^^^")
         | 
| 81 | 
            -
                print(colors+ " " + objects + " " + categories)
         | 
| 82 | 
            -
                
         | 
| 83 | 
             
                return colors+ " " + objects + " " + categories
         | 
| 84 |  | 
| 85 | 
             
            def call(a,b):
         | 
| 86 | 
            -
                print("'''''''''''''''''''''''")
         | 
| 87 | 
            -
                print(b)
         | 
| 88 | 
            -
                
         | 
| 89 | 
             
                if(st.session_state.input_is_rewrite_query == 'enabled' and st.session_state.input_rewritten_query!=""):
         | 
| 90 |  | 
| 91 |  | 
| 92 | 
            -
                    #st.session_state.input_rewritten_query['query']['bool']['should'].pop()
         | 
| 93 | 
             
                    st.session_state.input_rewritten_query['query']['bool']['should'].append( {
         | 
| 94 | 
             
                                "simple_query_string": {
         | 
| 95 |  | 
| @@ -112,36 +99,4 @@ def call(a,b): | |
| 112 | 
             
                        }
         | 
| 113 | 
             
                    st.session_state.input_rewritten_query = rekog_query
         | 
| 114 |  | 
| 115 | 
            -
                 | 
| 116 | 
            -
                #     body = rekog_query,
         | 
| 117 | 
            -
                #     index = 'demo-retail-rekognition'
         | 
| 118 | 
            -
                #     #pipeline = 'RAG-Search-Pipeline'
         | 
| 119 | 
            -
                # )
         | 
| 120 | 
            -
                
         | 
| 121 | 
            -
                
         | 
| 122 | 
            -
                # hits = response['hits']['hits']
         | 
| 123 | 
            -
                # print("rewrite-------------------------")
         | 
| 124 | 
            -
                # arr = []
         | 
| 125 | 
            -
                # for doc in hits:
         | 
| 126 | 
            -
                #     # if('b5/b5319e00' in doc['_source']['image_s3_url'] ):
         | 
| 127 | 
            -
                #     #     filter_out +=1
         | 
| 128 | 
            -
                #     #     continue
         | 
| 129 | 
            -
                    
         | 
| 130 | 
            -
                #     res_ = {"desc":doc['_source']['text'].replace(doc['_source']['metadata']['rekog_all']," ^^^ " +doc['_source']['metadata']['rekog_all']),
         | 
| 131 | 
            -
                #             "image_url":doc['_source']['metadata']['image_s3_url']}
         | 
| 132 | 
            -
                #     if('highlight' in doc):
         | 
| 133 | 
            -
                #         res_['highlight'] = doc['highlight']['text']
         | 
| 134 | 
            -
                #     # if('caption_embedding' in doc['_source']):
         | 
| 135 | 
            -
                #     #     res_['sparse'] = doc['_source']['caption_embedding']
         | 
| 136 | 
            -
                #     # if('query_sparse' in response_ and len(arr) ==0 ):
         | 
| 137 | 
            -
                #     #     res_['query_sparse'] = response_["query_sparse"]
         | 
| 138 | 
            -
                #     res_['id'] = doc['_id']
         | 
| 139 | 
            -
                #     res_['score'] = doc['_score']
         | 
| 140 | 
            -
                #     res_['title'] = doc['_source']['text']
         | 
| 141 | 
            -
                #     res_['rekog'] = {'color':doc['_source']['metadata']['rekog_color'],'category': doc['_source']['metadata']['rekog_categories'],'objects':doc['_source']['metadata']['rekog_objects']}
         | 
| 142 | 
            -
                       
         | 
| 143 | 
            -
                #     arr.append(res_)
         | 
| 144 | 
            -
                        
         | 
| 145 | 
            -
             | 
| 146 | 
            -
                
         | 
| 147 | 
            -
                # return arr
         | 
|  | |
| 24 | 
             
                MaxLabels = 10,
         | 
| 25 | 
             
                MinConfidence = 80.0,
         | 
| 26 | 
             
                Settings = { 
         | 
| 27 | 
            +
                
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 28 | 
             
                    "ImageProperties": { 
         | 
| 29 | 
             
                        "MaxDominantColors": 5
         | 
| 30 | 
             
                    }
         | 
|  | |
| 71 | 
             
                objects = " ".join(set(objects))
         | 
| 72 | 
             
                categories = " ".join(set(categories))
         | 
| 73 | 
             
                colors = " ".join(set(colors))
         | 
|  | |
|  | |
|  | |
|  | |
| 74 | 
             
                return colors+ " " + objects + " " + categories
         | 
| 75 |  | 
| 76 | 
             
            def call(a,b):
         | 
|  | |
|  | |
|  | |
| 77 | 
             
                if(st.session_state.input_is_rewrite_query == 'enabled' and st.session_state.input_rewritten_query!=""):
         | 
| 78 |  | 
| 79 |  | 
|  | |
| 80 | 
             
                    st.session_state.input_rewritten_query['query']['bool']['should'].append( {
         | 
| 81 | 
             
                                "simple_query_string": {
         | 
| 82 |  | 
|  | |
| 99 | 
             
                        }
         | 
| 100 | 
             
                    st.session_state.input_rewritten_query = rekog_query
         | 
| 101 |  | 
| 102 | 
            +
                
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
    	
        utilities/invoke_models.py
    CHANGED
    
    | @@ -24,17 +24,6 @@ bedrock_runtime_client = get_bedrock_client() | |
| 24 |  | 
| 25 |  | 
| 26 |  | 
| 27 | 
            -
            # def generate_image_captions_ml():
         | 
| 28 | 
            -
            #     model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
         | 
| 29 | 
            -
            #     feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
         | 
| 30 | 
            -
            #     tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
         | 
| 31 | 
            -
             | 
| 32 | 
            -
            #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 33 | 
            -
            #     model.to(device)
         | 
| 34 | 
            -
            #     max_length = 16
         | 
| 35 | 
            -
            #     num_beams = 4
         | 
| 36 | 
            -
            #     gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
         | 
| 37 | 
            -
             | 
| 38 | 
             
            def invoke_model(input):
         | 
| 39 | 
             
                response = bedrock_runtime_client.invoke_model(
         | 
| 40 | 
             
                    body=json.dumps({
         | 
| @@ -100,56 +89,7 @@ def invoke_llm_model(input,is_stream): | |
| 100 |  | 
| 101 | 
             
                    return (json.loads(res))['content'][0]['text']
         | 
| 102 |  | 
| 103 | 
            -
             | 
| 104 | 
            -
                    # body=json.dumps({
         | 
| 105 | 
            -
                    #     "prompt": input,
         | 
| 106 | 
            -
                    #     "max_tokens_to_sample": 300,
         | 
| 107 | 
            -
                    #     "temperature": 0.5,
         | 
| 108 | 
            -
                    #     "top_k": 250,
         | 
| 109 | 
            -
                    #     "top_p": 1,
         | 
| 110 | 
            -
                    #     "stop_sequences": [
         | 
| 111 | 
            -
                    #         "\n\nHuman:"
         | 
| 112 | 
            -
                    #     ],
         | 
| 113 | 
            -
                    #     # "anthropic_version": "bedrock-2023-05-31"
         | 
| 114 | 
            -
                    # }),
         | 
| 115 | 
            -
                    # modelId="anthropic.claude-v2:1",
         | 
| 116 | 
            -
                    # accept="application/json",
         | 
| 117 | 
            -
                    # contentType="application/json",
         | 
| 118 | 
            -
                    # )
         | 
| 119 | 
            -
                    # stream = response.get('body')
         | 
| 120 | 
            -
                    
         | 
| 121 | 
            -
                    # return stream
         | 
| 122 | 
            -
                    
         | 
| 123 | 
            -
                # else:
         | 
| 124 | 
            -
                #     response = bedrock_runtime_client.invoke_model_with_response_stream( 
         | 
| 125 | 
            -
                #         modelId= "anthropic.claude-3-sonnet-20240229-v1:0",
         | 
| 126 | 
            -
                #         contentType = "application/json",
         | 
| 127 | 
            -
                #         accept = "application/json",
         | 
| 128 | 
            -
               
         | 
| 129 | 
            -
                #         body = json.dumps({
         | 
| 130 | 
            -
                #                     "anthropic_version": "bedrock-2023-05-31",
         | 
| 131 | 
            -
                #                     "max_tokens": 1024,
         | 
| 132 | 
            -
                #                     "temperature": 0.0001,
         | 
| 133 | 
            -
                #                     "top_k": 150,
         | 
| 134 | 
            -
                #                     "top_p": 0.7,
         | 
| 135 | 
            -
                #                     "stop_sequences": [
         | 
| 136 | 
            -
                #                         "\n\nHuman:"
         | 
| 137 | 
            -
                #                     ],
         | 
| 138 | 
            -
                #                     "messages": [
         | 
| 139 | 
            -
                #                     {
         | 
| 140 | 
            -
                #                         "role": "user",
         | 
| 141 | 
            -
                #                         "content":input
         | 
| 142 | 
            -
                #                         }
         | 
| 143 | 
            -
                #                         ]
         | 
| 144 | 
            -
                #                     }
         | 
| 145 | 
            -
                                    
         | 
| 146 | 
            -
                #                      )
         | 
| 147 | 
            -
                #         )
         | 
| 148 | 
            -
                    
         | 
| 149 | 
            -
                #     stream = response.get('body')
         | 
| 150 | 
            -
                    
         | 
| 151 | 
            -
                #     return stream
         | 
| 152 | 
            -
                    
         | 
| 153 | 
             
            def read_from_table(file,question):
         | 
| 154 | 
             
                print("started table analysis:")
         | 
| 155 | 
             
                print("-----------------------")
         | 
| @@ -175,7 +115,6 @@ def read_from_table(file,question): | |
| 175 | 
             
                    df = pd.read_csv(file,skipinitialspace = True, on_bad_lines='skip',delimiter = "`")
         | 
| 176 | 
             
                else:
         | 
| 177 | 
             
                    df = file
         | 
| 178 | 
            -
                #df.fillna(method='pad', inplace=True)
         | 
| 179 | 
             
                agent = create_pandas_dataframe_agent(
         | 
| 180 | 
             
                         model, 
         | 
| 181 | 
             
                         df, 
         | 
| @@ -188,24 +127,7 @@ def read_from_table(file,question): | |
| 188 |  | 
| 189 | 
             
            def generate_image_captions_llm(base64_string,question):
         | 
| 190 |  | 
| 191 | 
            -
             | 
| 192 | 
            -
                # MODEL_NAME = "claude-3-opus-20240229"
         | 
| 193 | 
            -
                    
         | 
| 194 | 
            -
                # message_list = [
         | 
| 195 | 
            -
                # {
         | 
| 196 | 
            -
                #     "role": 'user',
         | 
| 197 | 
            -
                #     "content": [
         | 
| 198 | 
            -
                #         {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": base64_string}},
         | 
| 199 | 
            -
                #         {"type": "text", "text": "What is in the image ?"}
         | 
| 200 | 
            -
                #     ]
         | 
| 201 | 
            -
                # }
         | 
| 202 | 
            -
                # ]
         | 
| 203 | 
            -
             | 
| 204 | 
            -
                # response = ant_client.messages.create(
         | 
| 205 | 
            -
                # model=MODEL_NAME,
         | 
| 206 | 
            -
                # max_tokens=2048,
         | 
| 207 | 
            -
                # messages=message_list
         | 
| 208 | 
            -
                # )
         | 
| 209 | 
             
                response = bedrock_runtime_client.invoke_model( 
         | 
| 210 | 
             
                        modelId= "anthropic.claude-3-haiku-20240307-v1:0",
         | 
| 211 | 
             
                        contentType = "application/json",
         | 
| @@ -234,9 +156,5 @@ def generate_image_captions_llm(base64_string,question): | |
| 234 | 
             
                                    }
         | 
| 235 | 
             
                                    ]
         | 
| 236 | 
             
                                     }))
         | 
| 237 | 
            -
                #print(response)
         | 
| 238 | 
             
                response_body = json.loads(response.get("body").read())['content'][0]['text']
         | 
| 239 | 
            -
             | 
| 240 | 
            -
                #print(response_body)
         | 
| 241 | 
            -
                
         | 
| 242 | 
             
                return response_body
         | 
|  | |
| 24 |  | 
| 25 |  | 
| 26 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 27 | 
             
            def invoke_model(input):
         | 
| 28 | 
             
                response = bedrock_runtime_client.invoke_model(
         | 
| 29 | 
             
                    body=json.dumps({
         | 
|  | |
| 89 |  | 
| 90 | 
             
                    return (json.loads(res))['content'][0]['text']
         | 
| 91 |  | 
| 92 | 
            +
                  
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 93 | 
             
            def read_from_table(file,question):
         | 
| 94 | 
             
                print("started table analysis:")
         | 
| 95 | 
             
                print("-----------------------")
         | 
|  | |
| 115 | 
             
                    df = pd.read_csv(file,skipinitialspace = True, on_bad_lines='skip',delimiter = "`")
         | 
| 116 | 
             
                else:
         | 
| 117 | 
             
                    df = file
         | 
|  | |
| 118 | 
             
                agent = create_pandas_dataframe_agent(
         | 
| 119 | 
             
                         model, 
         | 
| 120 | 
             
                         df, 
         | 
|  | |
| 127 |  | 
| 128 | 
             
            def generate_image_captions_llm(base64_string,question):
         | 
| 129 |  | 
| 130 | 
            +
               
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 131 | 
             
                response = bedrock_runtime_client.invoke_model( 
         | 
| 132 | 
             
                        modelId= "anthropic.claude-3-haiku-20240307-v1:0",
         | 
| 133 | 
             
                        contentType = "application/json",
         | 
|  | |
| 156 | 
             
                                    }
         | 
| 157 | 
             
                                    ]
         | 
| 158 | 
             
                                     }))
         | 
|  | |
| 159 | 
             
                response_body = json.loads(response.get("body").read())['content'][0]['text']
         | 
|  | |
|  | |
|  | |
| 160 | 
             
                return response_body
         | 
    	
        utilities/re_ranker.py
    DELETED
    
    | @@ -1,127 +0,0 @@ | |
| 1 | 
            -
            import boto3
         | 
| 2 | 
            -
            from botocore.exceptions import ClientError
         | 
| 3 | 
            -
            import pprint
         | 
| 4 | 
            -
            import time
         | 
| 5 | 
            -
            import streamlit as st
         | 
| 6 | 
            -
            from sentence_transformers import CrossEncoder
         | 
| 7 | 
            -
             | 
| 8 | 
            -
            #model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=512)
         | 
| 9 | 
            -
            ####### Add this Kendra Rescore ranking
         | 
| 10 | 
            -
            #kendra_ranking = boto3.client("kendra-ranking",region_name = 'us-east-1')
         | 
| 11 | 
            -
            #print("Create a rescore execution plan.")
         | 
| 12 | 
            -
             | 
| 13 | 
            -
            # Provide a name for the rescore execution plan
         | 
| 14 | 
            -
            #name = "MyRescoreExecutionPlan"
         | 
| 15 | 
            -
            # Set your required additional capacity units
         | 
| 16 | 
            -
            # Don't set capacity units if you don't require more than 1 unit given by default
         | 
| 17 | 
            -
            #capacity_units = 2
         | 
| 18 | 
            -
             | 
| 19 | 
            -
            # try:
         | 
| 20 | 
            -
            #     rescore_execution_plan_response = kendra_ranking.create_rescore_execution_plan(
         | 
| 21 | 
            -
            #         Name = name,
         | 
| 22 | 
            -
            #         CapacityUnits = {"RescoreCapacityUnits":capacity_units}
         | 
| 23 | 
            -
            #     )
         | 
| 24 | 
            -
             | 
| 25 | 
            -
            #     pprint.pprint(rescore_execution_plan_response)
         | 
| 26 | 
            -
             | 
| 27 | 
            -
            #     rescore_execution_plan_id = rescore_execution_plan_response["Id"]
         | 
| 28 | 
            -
             | 
| 29 | 
            -
            #     print("Wait for Amazon Kendra to create the rescore execution plan.")
         | 
| 30 | 
            -
             | 
| 31 | 
            -
            #     while True:
         | 
| 32 | 
            -
            #         # Get the details of the rescore execution plan, such as the status
         | 
| 33 | 
            -
            #         rescore_execution_plan_description = kendra_ranking.describe_rescore_execution_plan(
         | 
| 34 | 
            -
            #             Id = rescore_execution_plan_id
         | 
| 35 | 
            -
            #         )
         | 
| 36 | 
            -
            #         # When status is not CREATING quit.
         | 
| 37 | 
            -
            #         status = rescore_execution_plan_description["Status"]
         | 
| 38 | 
            -
            #         print(" Creating rescore execution plan. Status: "+status)
         | 
| 39 | 
            -
            #         time.sleep(60)
         | 
| 40 | 
            -
            #         if status != "CREATING":
         | 
| 41 | 
            -
            #             break
         | 
| 42 | 
            -
             | 
| 43 | 
            -
            # except ClientError as e:
         | 
| 44 | 
            -
            #         print("%s" % e)
         | 
| 45 | 
            -
             | 
| 46 | 
            -
            # print("Program ends.")
         | 
| 47 | 
            -
            #########################
         | 
| 48 | 
            -
             | 
| 49 | 
            -
            @st.cache_resource
         | 
| 50 | 
            -
            def re_rank(self_, rerank_type, search_type, question, answers):
         | 
| 51 | 
            -
               
         | 
| 52 | 
            -
                ans = []
         | 
| 53 | 
            -
                ids = []
         | 
| 54 | 
            -
                ques_ans = []
         | 
| 55 | 
            -
                query = question[0]['question']
         | 
| 56 | 
            -
                for i in answers[0]['answer']:
         | 
| 57 | 
            -
                    if(self_ == "search"):
         | 
| 58 | 
            -
                        
         | 
| 59 | 
            -
                        ans.append({
         | 
| 60 | 
            -
                                "Id": i['id'],
         | 
| 61 | 
            -
                                "Body": i["desc"],
         | 
| 62 | 
            -
                                "OriginalScore": i['score'],
         | 
| 63 | 
            -
                                "Title":i["desc"]
         | 
| 64 | 
            -
                                })
         | 
| 65 | 
            -
                        ids.append(i['id'])
         | 
| 66 | 
            -
                        ques_ans.append((query,i["desc"]))
         | 
| 67 | 
            -
                    
         | 
| 68 | 
            -
                    else:
         | 
| 69 | 
            -
                        ans.append({'text':i})
         | 
| 70 | 
            -
                        
         | 
| 71 | 
            -
                        ques_ans.append((query,i))
         | 
| 72 | 
            -
                    
         | 
| 73 | 
            -
                        
         | 
| 74 | 
            -
             | 
| 75 | 
            -
                re_ranked = [{}]
         | 
| 76 | 
            -
                ####### Add this Kendra Rescore ranking
         | 
| 77 | 
            -
                # if(rerank_type == 'Kendra Rescore'):
         | 
| 78 | 
            -
                #     rescore_response = kendra_ranking.rescore(
         | 
| 79 | 
            -
                #         RescoreExecutionPlanId = 'b2a4d4f3-98ff-4e17-8b69-4c61ed7d91eb',
         | 
| 80 | 
            -
                #         SearchQuery = query,
         | 
| 81 | 
            -
                #         Documents = ans
         | 
| 82 | 
            -
                #     )
         | 
| 83 | 
            -
                #     re_ranked[0]['answer']=[]
         | 
| 84 | 
            -
                #     for result in rescore_response["ResultItems"]:
         | 
| 85 | 
            -
             | 
| 86 | 
            -
                #         pos_ = ids.index(result['DocumentId'])
         | 
| 87 | 
            -
             | 
| 88 | 
            -
                #         re_ranked[0]['answer'].append(answers[0]['answer'][pos_])
         | 
| 89 | 
            -
                #     re_ranked[0]['search_type']=search_type,
         | 
| 90 | 
            -
                #     re_ranked[0]['id'] = len(question)
         | 
| 91 | 
            -
                #     return re_ranked
         | 
| 92 | 
            -
                    
         | 
| 93 | 
            -
                # if(rerank_type == 'Cross Encoder'):
         | 
| 94 | 
            -
             | 
| 95 | 
            -
                #     scores = model.predict(
         | 
| 96 | 
            -
                #                 ques_ans
         | 
| 97 | 
            -
                #                     )
         | 
| 98 | 
            -
                    
         | 
| 99 | 
            -
                #     index__ = 0
         | 
| 100 | 
            -
                #     for i in ans:
         | 
| 101 | 
            -
                #         i['new_score'] = scores[index__]
         | 
| 102 | 
            -
                #         index__ = index__+1
         | 
| 103 | 
            -
             | 
| 104 | 
            -
                #     ans_sorted = sorted(ans, key=lambda d: d['new_score'],reverse=True) 
         | 
| 105 | 
            -
                    
         | 
| 106 | 
            -
                    
         | 
| 107 | 
            -
                #     def retreive_only_text(item):
         | 
| 108 | 
            -
                #         return item['text']
         | 
| 109 | 
            -
                        
         | 
| 110 | 
            -
                #     if(self_ == 'rag'):
         | 
| 111 | 
            -
                #         return list(map(retreive_only_text, ans_sorted)) 
         | 
| 112 | 
            -
             | 
| 113 | 
            -
                   
         | 
| 114 | 
            -
                #     re_ranked[0]['answer']=[]
         | 
| 115 | 
            -
                #     for j in ans_sorted:
         | 
| 116 | 
            -
                #         pos_ = ids.index(j['Id'])
         | 
| 117 | 
            -
                #         re_ranked[0]['answer'].append(answers[0]['answer'][pos_])
         | 
| 118 | 
            -
                #     re_ranked[0]['search_type']= search_type,
         | 
| 119 | 
            -
                #     re_ranked[0]['id'] = len(question)
         | 
| 120 | 
            -
                #     return re_ranked
         | 
| 121 | 
            -
             | 
| 122 | 
            -
             | 
| 123 | 
            -
                
         | 
| 124 | 
            -
             | 
| 125 | 
            -
             | 
| 126 | 
            -
             | 
| 127 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
