Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	
		Kieran Gookey
		
	committed on
		
		
					Commit 
							
							·
						
						6ad144b
	
1
								Parent(s):
							
							277b244
								
Set a different embedding model
Browse files
    	
        app.py
    CHANGED
    
    | @@ -10,104 +10,146 @@ from llama_index.vector_stores.types import MetadataFilters, ExactMatchFilter | |
| 10 |  | 
| 11 | 
             
            inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]
         | 
| 12 |  | 
| 13 | 
            -
            embed_model_name = st.text_input(
         | 
| 14 | 
            -
             | 
| 15 |  | 
| 16 | 
            -
            llm_model_name = st.text_input(
         | 
| 17 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 18 |  | 
| 19 | 
             
            query = st.text_input(
         | 
| 20 | 
            -
                'Query', "What is the price of the product?" | 
|  | |
| 21 |  | 
| 22 | 
             
            html_file = st.file_uploader("Upload a html file", type=["html"])
         | 
| 23 |  | 
| 24 | 
            -
            if  | 
| 25 | 
            -
                 | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 29 |  | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 |  | 
| 37 | 
            -
             | 
| 38 | 
            -
             | 
| 39 |  | 
| 40 | 
            -
             | 
| 41 | 
            -
             | 
| 42 | 
            -
             | 
| 43 | 
            -
             | 
| 44 |  | 
| 45 | 
            -
             | 
| 46 |  | 
| 47 | 
            -
             | 
| 48 | 
            -
             | 
| 49 | 
            -
             | 
| 50 |  | 
| 51 | 
            -
             | 
| 52 | 
            -
             | 
| 53 |  | 
| 54 | 
            -
             | 
| 55 | 
            -
             | 
| 56 |  | 
| 57 | 
            -
             | 
| 58 |  | 
| 59 | 
            -
             | 
| 60 | 
            -
             | 
| 61 |  | 
| 62 | 
            -
             | 
| 63 | 
            -
             | 
| 64 | 
            -
             | 
| 65 |  | 
| 66 | 
            -
             | 
| 67 | 
            -
             | 
| 68 |  | 
| 69 | 
            -
             | 
| 70 |  | 
| 71 | 
            -
             | 
| 72 |  | 
| 73 | 
            -
             | 
| 74 |  | 
| 75 | 
            -
             | 
| 76 | 
            -
             | 
| 77 | 
            -
            else:
         | 
| 78 | 
            -
             | 
| 79 |  | 
| 80 | 
            -
            # if html_file is not None:
         | 
| 81 | 
            -
            #     stringio = StringIO(html_file.getvalue().decode("utf-8"))
         | 
| 82 | 
            -
            #     string_data = stringio.read()
         | 
| 83 | 
            -
            #     with st.expander("Uploaded HTML"):
         | 
| 84 | 
            -
            #         st.write(string_data)
         | 
| 85 |  | 
| 86 | 
            -
            #     document_id = str(uuid.uuid4())
         | 
| 87 |  | 
| 88 | 
            -
            #     document = Document(text=string_data)
         | 
| 89 | 
            -
            #     document.metadata["id"] = document_id
         | 
| 90 | 
            -
            #     documents = [document]
         | 
| 91 |  | 
| 92 | 
            -
            #     filters = MetadataFilters(
         | 
| 93 | 
            -
            #         filters=[ExactMatchFilter(key="id", value=document_id)])
         | 
| 94 |  | 
| 95 | 
            -
            #     index = VectorStoreIndex.from_documents(
         | 
| 96 | 
            -
            #         documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
         | 
| 97 |  | 
| 98 | 
            -
            #     retriever = index.as_retriever()
         | 
| 99 |  | 
| 100 | 
            -
            #     ranked_nodes = retriever.retrieve(
         | 
| 101 | 
            -
            #         "Get me all the information about the product")
         | 
| 102 |  | 
| 103 | 
            -
            #     with st.expander("Ranked Nodes"):
         | 
| 104 | 
            -
            #         for node in ranked_nodes:
         | 
| 105 | 
            -
            #             st.write(node.node.get_content(), "-> Score:", node.score)
         | 
| 106 |  | 
| 107 | 
            -
            #     query_engine = index.as_query_engine(
         | 
| 108 | 
            -
            #         filters=filters, service_context=service_context)
         | 
| 109 |  | 
| 110 | 
            -
            #     response = query_engine.query(
         | 
| 111 | 
            -
            #         "Get me all the information about the product")
         | 
| 112 |  | 
| 113 | 
            -
            #     st.write(response)
         | 
|  | |
| 10 |  | 
| 11 | 
             
            inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]
         | 
| 12 |  | 
| 13 | 
            +
            # embed_model_name = st.text_input(
         | 
| 14 | 
            +
            #     'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce")
         | 
| 15 |  | 
| 16 | 
            +
            # llm_model_name = st.text_input(
         | 
| 17 | 
            +
            #     'Embed Model name', "mistralai/Mistral-7B-Instruct-v0.2")
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            embed_model_name = "jinaai/jina-embedding-s-en-v1"
         | 
| 20 | 
            +
            llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            llm = HuggingFaceInferenceAPI(
         | 
| 23 | 
            +
                model_name=llm_model_name, token=inference_api_key)
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            embed_model = HuggingFaceInferenceAPIEmbedding(
         | 
| 26 | 
            +
                model_name=embed_model_name,
         | 
| 27 | 
            +
                token=inference_api_key,
         | 
| 28 | 
            +
                model_kwargs={"device": ""},
         | 
| 29 | 
            +
                encode_kwargs={"normalize_embeddings": True},
         | 
| 30 | 
            +
            )
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            service_context = ServiceContext.from_defaults(
         | 
| 33 | 
            +
                embed_model=embed_model, llm=llm)
         | 
| 34 |  | 
| 35 | 
             
            query = st.text_input(
         | 
| 36 | 
            +
                'Query', "What is the price of the product?"
         | 
| 37 | 
            +
            )
         | 
| 38 |  | 
| 39 | 
             
            html_file = st.file_uploader("Upload a html file", type=["html"])
         | 
| 40 |  | 
| 41 | 
            +
            if html_file is not None:
         | 
| 42 | 
            +
                stringio = StringIO(html_file.getvalue().decode("utf-8"))
         | 
| 43 | 
            +
                string_data = stringio.read()
         | 
| 44 | 
            +
                with st.expander("Uploaded HTML"):
         | 
| 45 | 
            +
                    st.write(string_data)
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                document_id = str(uuid.uuid4())
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                document = Document(text=string_data)
         | 
| 50 | 
            +
                document.metadata["id"] = document_id
         | 
| 51 | 
            +
                documents = [document]
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                filters = MetadataFilters(
         | 
| 54 | 
            +
                    filters=[ExactMatchFilter(key="id", value=document_id)])
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                index = VectorStoreIndex.from_documents(
         | 
| 57 | 
            +
                    documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                query_engine = index.as_query_engine(
         | 
| 60 | 
            +
                    filters=filters, service_context=service_context)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                response = query_engine.query(query)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                st.write(response.response)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            # if st.button('Start Pipeline'):
         | 
| 67 | 
            +
            #     if html_file is not None and embed_model_name is not None and llm_model_name is not None and query is not None:
         | 
| 68 | 
            +
            #         st.write('Running Pipeline')
         | 
| 69 | 
            +
            #         llm = HuggingFaceInferenceAPI(
         | 
| 70 | 
            +
            #             model_name=llm_model_name, token=inference_api_key)
         | 
| 71 |  | 
| 72 | 
            +
            #         embed_model = HuggingFaceInferenceAPIEmbedding(
         | 
| 73 | 
            +
            #             model_name=embed_model_name,
         | 
| 74 | 
            +
            #             token=inference_api_key,
         | 
| 75 | 
            +
            #             model_kwargs={"device": ""},
         | 
| 76 | 
            +
            #             encode_kwargs={"normalize_embeddings": True},
         | 
| 77 | 
            +
            #         )
         | 
| 78 |  | 
| 79 | 
            +
            #         service_context = ServiceContext.from_defaults(
         | 
| 80 | 
            +
            #             embed_model=embed_model, llm=llm)
         | 
| 81 |  | 
| 82 | 
            +
            #         stringio = StringIO(html_file.getvalue().decode("utf-8"))
         | 
| 83 | 
            +
            #         string_data = stringio.read()
         | 
| 84 | 
            +
            #         with st.expander("Uploaded HTML"):
         | 
| 85 | 
            +
            #             st.write(string_data)
         | 
| 86 |  | 
| 87 | 
            +
            #         document_id = str(uuid.uuid4())
         | 
| 88 |  | 
| 89 | 
            +
            #         document = Document(text=string_data)
         | 
| 90 | 
            +
            #         document.metadata["id"] = document_id
         | 
| 91 | 
            +
            #         documents = [document]
         | 
| 92 |  | 
| 93 | 
            +
            #         filters = MetadataFilters(
         | 
| 94 | 
            +
            #             filters=[ExactMatchFilter(key="id", value=document_id)])
         | 
| 95 |  | 
| 96 | 
            +
            #         index = VectorStoreIndex.from_documents(
         | 
| 97 | 
            +
            #             documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
         | 
| 98 |  | 
| 99 | 
            +
            #         retriever = index.as_retriever()
         | 
| 100 |  | 
| 101 | 
            +
            #         ranked_nodes = retriever.retrieve(
         | 
| 102 | 
            +
            #             query)
         | 
| 103 |  | 
| 104 | 
            +
            #         with st.expander("Ranked Nodes"):
         | 
| 105 | 
            +
            #             for node in ranked_nodes:
         | 
| 106 | 
            +
            #                 st.write(node.node.get_content(), "-> Score:", node.score)
         | 
| 107 |  | 
| 108 | 
            +
            #         query_engine = index.as_query_engine(
         | 
| 109 | 
            +
            #             filters=filters, service_context=service_context)
         | 
| 110 |  | 
| 111 | 
            +
            #         response = query_engine.query(query)
         | 
| 112 |  | 
| 113 | 
            +
            #         st.write(response.response)
         | 
| 114 |  | 
| 115 | 
            +
            #         st.write(response.source_nodes)
         | 
| 116 |  | 
| 117 | 
            +
            #     else:
         | 
| 118 | 
            +
            #         st.error('Please fill in all the fields')
         | 
| 119 | 
            +
            # else:
         | 
| 120 | 
            +
            #     st.write('Press start to begin')
         | 
| 121 |  | 
| 122 | 
            +
            # # if html_file is not None:
         | 
| 123 | 
            +
            # #     stringio = StringIO(html_file.getvalue().decode("utf-8"))
         | 
| 124 | 
            +
            # #     string_data = stringio.read()
         | 
| 125 | 
            +
            # #     with st.expander("Uploaded HTML"):
         | 
| 126 | 
            +
            # #         st.write(string_data)
         | 
| 127 |  | 
| 128 | 
            +
            # #     document_id = str(uuid.uuid4())
         | 
| 129 |  | 
| 130 | 
            +
            # #     document = Document(text=string_data)
         | 
| 131 | 
            +
            # #     document.metadata["id"] = document_id
         | 
| 132 | 
            +
            # #     documents = [document]
         | 
| 133 |  | 
| 134 | 
            +
            # #     filters = MetadataFilters(
         | 
| 135 | 
            +
            # #         filters=[ExactMatchFilter(key="id", value=document_id)])
         | 
| 136 |  | 
| 137 | 
            +
            # #     index = VectorStoreIndex.from_documents(
         | 
| 138 | 
            +
            # #         documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
         | 
| 139 |  | 
| 140 | 
            +
            # #     retriever = index.as_retriever()
         | 
| 141 |  | 
| 142 | 
            +
            # #     ranked_nodes = retriever.retrieve(
         | 
| 143 | 
            +
            # #         "Get me all the information about the product")
         | 
| 144 |  | 
| 145 | 
            +
            # #     with st.expander("Ranked Nodes"):
         | 
| 146 | 
            +
            # #         for node in ranked_nodes:
         | 
| 147 | 
            +
            # #             st.write(node.node.get_content(), "-> Score:", node.score)
         | 
| 148 |  | 
| 149 | 
            +
            # #     query_engine = index.as_query_engine(
         | 
| 150 | 
            +
            # #         filters=filters, service_context=service_context)
         | 
| 151 |  | 
| 152 | 
            +
            # #     response = query_engine.query(
         | 
| 153 | 
            +
            # #         "Get me all the information about the product")
         | 
| 154 |  | 
| 155 | 
            +
            # #     st.write(response)
         |