Kieran Gookey committed on
Commit
242bba0
·
1 Parent(s): df26c41

Tried different approach

Browse files
Files changed (1) hide show
  1. app.py +83 -35
app.py CHANGED
@@ -10,52 +10,100 @@ from llama_index.vector_stores.types import MetadataFilters, ExactMatchFilter
10
 
11
  inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]
12
 
13
- llm = HuggingFaceInferenceAPI(
14
- model_name="mistralai/Mistral-7B-Instruct-v0.2", token=inference_api_key)
15
 
16
- embed_model = HuggingFaceInferenceAPIEmbedding(
17
- model_name="Gooly/gte-small-en-fine-tuned-e-commerce",
18
- token=inference_api_key,
19
- model_kwargs={"device": ""},
20
- encode_kwargs={"normalize_embeddings": True},
21
- )
22
-
23
- service_context = ServiceContext.from_defaults(
24
- embed_model=embed_model, llm=llm)
25
 
26
  html_file = st.file_uploader("Upload a html file", type=["html"])
27
 
28
- if html_file is not None:
29
- stringio = StringIO(html_file.getvalue().decode("utf-8"))
30
- string_data = stringio.read()
31
- with st.expander("Uploaded HTML"):
32
- st.write(string_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- document_id = str(uuid.uuid4())
35
 
36
- document = Document(text=string_data)
37
- document.metadata["id"] = document_id
38
- documents = [document]
39
 
40
- filters = MetadataFilters(
41
- filters=[ExactMatchFilter(key="id", value=document_id)])
42
 
43
- index = VectorStoreIndex.from_documents(
44
- documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
45
 
46
- retriever = index.as_retriever()
47
 
48
- ranked_nodes = retriever.retrieve(
49
- "Get me all the information about the product")
50
 
51
- with st.expander("Ranked Nodes"):
52
- for node in ranked_nodes:
53
- st.write(node.node.get_content(), "-> Score:", node.score)
54
 
55
- query_engine = index.as_query_engine(
56
- filters=filters, service_context=service_context)
57
 
58
- response = query_engine.query(
59
- "Get me all the information about the product")
60
 
61
- st.write(response)
 
10
 
11
  inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]
12
 
13
+ embed_model_name = st.text_input(
14
+ 'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce")
15
 
16
+ llm_model_name = st.text_input(
17
+ 'Embed Model name', "mistralai/Mistral-7B-Instruct-v0.2")
 
 
 
 
 
 
 
18
 
19
  html_file = st.file_uploader("Upload a html file", type=["html"])
20
 
21
+ if st.button('Start Pipeline'):
22
+ if html_file is not None and embed_model_name is not None and llm_model_name is not None:
23
+ st.write('Running Pipeline')
24
+ llm = HuggingFaceInferenceAPI(
25
+ model_name=llm_model_name, token=inference_api_key)
26
+
27
+ embed_model = HuggingFaceInferenceAPIEmbedding(
28
+ model_name=embed_model_name,
29
+ token=inference_api_key,
30
+ model_kwargs={"device": ""},
31
+ encode_kwargs={"normalize_embeddings": True},
32
+ )
33
+
34
+ service_context = ServiceContext.from_defaults(
35
+ embed_model=embed_model, llm=llm)
36
+
37
+ stringio = StringIO(html_file.getvalue().decode("utf-8"))
38
+ string_data = stringio.read()
39
+ with st.expander("Uploaded HTML"):
40
+ st.write(string_data)
41
+
42
+ document_id = str(uuid.uuid4())
43
+
44
+ document = Document(text=string_data)
45
+ document.metadata["id"] = document_id
46
+ documents = [document]
47
+
48
+ filters = MetadataFilters(
49
+ filters=[ExactMatchFilter(key="id", value=document_id)])
50
+
51
+ index = VectorStoreIndex.from_documents(
52
+ documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
53
+
54
+ retriever = index.as_retriever()
55
+
56
+ ranked_nodes = retriever.retrieve(
57
+ "Get me all the information about the product")
58
+
59
+ with st.expander("Ranked Nodes"):
60
+ for node in ranked_nodes:
61
+ st.write(node.node.get_content(), "-> Score:", node.score)
62
+
63
+ query_engine = index.as_query_engine(
64
+ filters=filters, service_context=service_context)
65
+
66
+ response = query_engine.query(
67
+ "Get me all the information about the product")
68
+
69
+ st.write(response)
70
+
71
+ else:
72
+ st.error('Please fill in all the fields')
73
+ else:
74
+ st.write('Press start to begin')
75
+
76
+ # if html_file is not None:
77
+ # stringio = StringIO(html_file.getvalue().decode("utf-8"))
78
+ # string_data = stringio.read()
79
+ # with st.expander("Uploaded HTML"):
80
+ # st.write(string_data)
81
 
82
+ # document_id = str(uuid.uuid4())
83
 
84
+ # document = Document(text=string_data)
85
+ # document.metadata["id"] = document_id
86
+ # documents = [document]
87
 
88
+ # filters = MetadataFilters(
89
+ # filters=[ExactMatchFilter(key="id", value=document_id)])
90
 
91
+ # index = VectorStoreIndex.from_documents(
92
+ # documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)
93
 
94
+ # retriever = index.as_retriever()
95
 
96
+ # ranked_nodes = retriever.retrieve(
97
+ # "Get me all the information about the product")
98
 
99
+ # with st.expander("Ranked Nodes"):
100
+ # for node in ranked_nodes:
101
+ # st.write(node.node.get_content(), "-> Score:", node.score)
102
 
103
+ # query_engine = index.as_query_engine(
104
+ # filters=filters, service_context=service_context)
105
 
106
+ # response = query_engine.query(
107
+ # "Get me all the information about the product")
108
 
109
+ # st.write(response)