sivan22 committed
Commit 3d33fb5
1 Parent(s): b067875

init from PC

Files changed (5)
  1. __init__.py +13 -0
  2. app.py +138 -0
  3. requirements.txt +9 -0
  4. run.bat +2 -0
  5. utils.py +28 -0
__init__.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
app.py ADDED
@@ -0,0 +1,138 @@
+ import streamlit as st
+ from streamlit.logger import get_logger
+ import datasets
+ import pandas as pd
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
+ from langchain_openai import ChatOpenAI
+ from langchain_core.prompts import PromptTemplate
+ from langchain_core.messages import HumanMessage, SystemMessage
+ from sentence_transformers import util
+ from torch import tensor
+ from io import StringIO
+
+
+ LOGGER = get_logger(__name__)
+
+
+ @st.cache_data
+ def get_df(uploaded_file) -> object:
+     # Read the uploaded text file into a DataFrame with one row per line.
+     if uploaded_file is None:
+         return None
+     stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
+     string_data = stringio.read()
+     df = pd.DataFrame(string_data.split('\n'), columns=['text'])
+     return df
+
+ @st.cache_data
+ def get_embeddings(df, _embeddings_model) -> object:
+     # E5 models expect a "passage: " prefix on indexed documents.
+     df['embeddings'] = df['text'].apply(lambda x: _embeddings_model.embed_query('passage: ' + x))
+     return df
+
+ @st.cache_resource
+ def get_model() -> object:
+     model_name = "intfloat/multilingual-e5-large"
+     model_kwargs = {'device': 'cuda'}  # 'cpu' or 'cuda'
+     encode_kwargs = {'normalize_embeddings': True}
+     embeddings_model = HuggingFaceEmbeddings(
+         model_name=model_name,
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs
+     )
+     return embeddings_model
+
+ @st.cache_resource
+ def get_chat_api(api_key: str):
+     chat = ChatOpenAI(model="gpt-3.5-turbo-16k", api_key=api_key)
+     return chat
+
+
+ def get_results(embeddings_model, input, df, num_of_results) -> pd.DataFrame:
+     # E5 models expect a "query: " prefix on search queries.
+     embeddings = embeddings_model.embed_query('query: ' + input)
+     hits = util.semantic_search(tensor(embeddings), tensor(df['embeddings'].tolist()), top_k=num_of_results)
+     hit_list = [hit['corpus_id'] for hit in hits[0]]
+     return df.iloc[hit_list]
+
+ def get_llm_results(query, chat, results):
+
+     prompt_template = PromptTemplate.from_template(
+         """
+         Your mission is to rank the given answers based on their relevance to the given question.
+         Provide a relevancy score between 0 (not relevant) and 1 (highly relevant) for each possible answer.
+         The results should be in the following JSON format: "answer": "score", "answer": "score", where answer is the possible answer's text and score is the relevancy score.
+
+         The question is: {query}
+
+         The possible answers are:
+         {answers}
+         """)
+
+     messages = [
+         SystemMessage(content="""
+             You're a helpful assistant.
+             Return a JSON formatted string.
+             """),
+         HumanMessage(content=prompt_template.format(query=query, answers=str.join('\n', results['text'].head(10).tolist()))),
+     ]
+
+     response = chat.invoke(messages)
+     # Wrap the raw JSON string in StringIO; newer pandas deprecates passing literal JSON to read_json.
+     llm_results_df = pd.read_json(StringIO(response.content), orient='index')
+     llm_results_df.rename(columns={0: 'score'}, inplace=True)
+     llm_results_df.sort_values(by='score', ascending=False, inplace=True)
+     return llm_results_df
+
+
+ def run():
+     st.set_page_config(
+         page_title=" חיפוש סמנטי",  # "Semantic search"
+         page_icon="",
+         layout="wide",
+         initial_sidebar_state="expanded"
+     )
+
+     st.write("# חיפוש חכם ")  # "Smart search"
+     # "You can upload any text file, wait for the index to be built, and then search in free language."
+     st.write('ניתן להעלות כל קובץ טקסט, להמתין ליצירת האינדקס ולאחר מכן לחפש בשפה חופשית')
+     # "Building the index may take a few minutes, depending on the file size."
+     st.write('יצירת האינדקס עשויה לקחת מספר דקות, ותלויה בגודל הקובץ')
+
+     # Streamlit reruns the script automatically on upload, so no on_change callback is needed here.
+     uploaded_file = st.file_uploader('העלה קובץ', type=['txt'])  # "Upload a file"
+
+     embeddings_model = get_model()
+     df = get_df(uploaded_file)
+     if df is None:
+         st.write("לא הועלה קובץ")  # "No file was uploaded"
+     else:
+         df = get_embeddings(df, embeddings_model)
+
+     user_input = st.text_input('כתוב כאן את שאלתך', placeholder='')  # "Write your question here"
+     num_of_results = st.sidebar.slider('מספר התוצאות שברצונך להציג:', 1, 25, 5)  # "Number of results to display:"
+     use_llm = st.sidebar.checkbox("השתמש במודל שפה כדי לשפר תוצאות", False)  # "Use a language model to improve results"
+     openAikey = st.sidebar.text_input("OpenAI API key", type="password")
+
+     if (st.button('חפש') or user_input) and user_input != "" and df is not None:  # "Search"
+
+         results = get_results(embeddings_model, user_input, df, num_of_results)
+
+         if use_llm:
+             if openAikey is None or openAikey == "":
+                 st.write("לא הוכנס מפתח של OpenAI")  # "No OpenAI key was entered"
+             else:
+                 chat = get_chat_api(openAikey)
+                 llm_results = get_llm_results(user_input, chat, results)
+                 st.write(llm_results)
+         else:
+             st.write(results.head(10))
+
+
+ if __name__ == "__main__":
+     run()
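The search flow in app.py (embed each line with a "passage: " prefix, embed the query with a "query: " prefix, then rank with util.semantic_search) can be exercised outside Streamlit. Below is a minimal sketch, assuming the same intfloat/multilingual-e5-large model; the toy corpus and query are illustrative placeholders, not part of the commit.

# Minimal sketch of the embed-and-rank flow used by get_results(); corpus and query are placeholders.
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from sentence_transformers import util
from torch import tensor

embeddings_model = HuggingFaceEmbeddings(
    model_name="intfloat/multilingual-e5-large",
    model_kwargs={"device": "cpu"},               # app.py uses "cuda"; "cpu" also works
    encode_kwargs={"normalize_embeddings": True},
)

corpus = ["first line of the uploaded file", "second line", "third line"]
corpus_embeddings = [embeddings_model.embed_query("passage: " + t) for t in corpus]

query_embedding = embeddings_model.embed_query("query: which line mentions a file?")
hits = util.semantic_search(tensor(query_embedding), tensor(corpus_embeddings), top_k=2)

# semantic_search returns one hit list per query, each hit a dict with 'corpus_id' and 'score'.
for hit in hits[0]:
    print(corpus[hit["corpus_id"]], hit["score"])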
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ pandas
+ streamlit
+ torch
+ transformers
+ datasets
+ langchain_huggingface
+ langchain_openai
+ langchain
+ sentence_transformers
run.bat ADDED
@@ -0,0 +1,2 @@
+ pip install -r requirements.txt
+ streamlit run app.py
utils.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ import textwrap
+
+ import streamlit as st
+
+
+ def show_code(demo):
+     """Showing the code of the demo."""
+     show_code = st.sidebar.checkbox("Show code", True)
+     if show_code:
+         # Showing the code of the demo.
+         st.markdown("## Code")
+         sourcelines, _ = inspect.getsourcelines(demo)
+         st.code(textwrap.dedent("".join(sourcelines[1:])))
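utils.py carries over the stock Streamlit demo helper: show_code() adds a "Show code" checkbox to the sidebar and, when ticked, renders the body of the function it is given. A minimal usage sketch follows; demo_page is an illustrative name, and app.py itself does not import utils.

# Hypothetical usage of utils.show_code(); demo_page is a placeholder page function.
import streamlit as st

from utils import show_code


def demo_page():
    st.write("Hello from the demo")


demo_page()
show_code(demo_page)  # sidebar "Show code" checkbox; prints demo_page's body under a "## Code" heading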