import streamlit as st
from streamlit.logger import get_logger
import datasets
import pandas as pd
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage
from sentence_transformers import util
from torch import tensor
from io import StringIO
LOGGER = get_logger(__name__)
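
# Overview: this Streamlit app lets the user upload a plain-text file, embeds every line
# with the multilingual-e5-large model, and serves semantic search over those embeddings.
# Optionally, the top hits are re-ranked by an OpenAI chat model.
# Launch locally with `streamlit run app.py`.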

@st.cache_data
def get_df(uploaded_file) -> object:
    if uploaded_file is None:
        return None
    stringio = StringIO(uploaded_file.getvalue().decode("utf-8"))
    string_data = stringio.read()
    df = pd.DataFrame(string_data.split('\n'), columns=['text'])
    return df
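
# Each line of the uploaded file becomes one row in the 'text' column, i.e. one
# searchable passage once get_embeddings() has been applied.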

@st.cache_data
def get_embeddings(df, _embeddings_model) -> object:
    df['embeddings'] = df['text'].apply(lambda x: _embeddings_model.embed_query('passage: ' + x))
    return df
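
# The leading underscore in _embeddings_model tells st.cache_data not to hash the model
# object itself, so the cache key depends only on the uploaded DataFrame.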

@st.cache_resource
def get_model() -> object:
    model_name = "intfloat/multilingual-e5-large"
    model_kwargs = {'device': 'cpu'}  # 'cpu' or 'cuda'
    encode_kwargs = {'normalize_embeddings': True}
    embeddings_model = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs
    )
    return embeddings_model
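
# Note: the e5 model family is trained with the "passage: " / "query: " prefixes used in
# get_embeddings() and get_results(); with normalize_embeddings=True the vectors are unit
# length, so cosine similarity and dot product produce the same ranking.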

@st.cache_resource
def get_chat_api(api_key: str):
    chat = ChatOpenAI(model="gpt-3.5-turbo-16k", api_key=api_key)
    return chat
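
# st.cache_resource keys the cached client on the api_key argument, so entering a
# different key in the sidebar creates a fresh ChatOpenAI instance.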

def get_results(embeddings_model, query, df, num_of_results) -> pd.DataFrame:
    embeddings = embeddings_model.embed_query('query: ' + query)
    hits = util.semantic_search(tensor(embeddings), tensor(df['embeddings'].tolist()), top_k=num_of_results)
    hit_list = [hit['corpus_id'] for hit in hits[0]]
    return df.iloc[hit_list]
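
# util.semantic_search returns one list of hits per query; each hit is a dict with
# 'corpus_id' and 'score' keys, already sorted by descending score, so hits[0] holds the
# ranked matches for the single query above.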

def get_llm_results(query, chat, results):
    # Ask the chat model to score each candidate answer for relevance to the query.
    prompt_template = PromptTemplate.from_template(
        """
        Your mission is to rank the given answers based on their relevance to the given question.
        Provide a relevancy score between 0 (not relevant) and 1 (highly relevant) for each possible answer.
        Return the results in the following JSON format: {{"answer": "score", "answer": "score"}}, where answer is the possible answer's text and score is its relevancy score.
        The question is: {query}
        The possible answers are:
        {answers}
        """)

    messages = [
        SystemMessage(content="""
        You're a helpful assistant.
        Return a JSON formatted string.
        """),
        HumanMessage(content=prompt_template.format(query=query, answers='\n'.join(results['text'].head(10).tolist()))),
    ]

    response = chat.invoke(messages)
    llm_results_df = pd.read_json(response.content, orient='index')
    llm_results_df.rename(columns={0: 'score'}, inplace=True)
    llm_results_df.sort_values(by='score', ascending=False, inplace=True)
    return llm_results_df
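
# The model is expected to return a JSON object mapping each answer's text to its score;
# pd.read_json(..., orient='index') turns that into a one-column DataFrame, which is then
# renamed to 'score' and sorted. A malformed (non-JSON) response will make read_json raise.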

def run():
    st.set_page_config(
        page_title=" ื—ื™ืคื•ืฉ ืกืžื ื˜ื™",  # "Semantic search"
        page_icon="",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    st.write("# ื—ื™ืคื•ืฉ ื—ื›ื ")  # "Smart search"
    # "You can upload any text file, wait for the index to be built, and then search in free language."
    st.write('ื ื™ืชืŸ ืœื”ืขืœื•ืช ื›ืœ ืงื•ื‘ืฅ ื˜ืงืกื˜, ืœื”ืžืชื™ืŸ ืœื™ืฆื™ืจืช ื”ืื™ื ื“ืงืก ื•ืœืื—ืจ ืžื›ืŸ ืœื—ืคืฉ ื‘ืฉืคื” ื—ื•ืคืฉื™ืช')
    # "Building the index may take a few minutes, depending on the file size."
    st.write('ื™ืฆื™ืจืช ื”ืื™ื ื“ืงืก ืขืฉื•ื™ื” ืœืงื—ืช ืžืกืคืจ ื“ืงื•ืช, ื•ืชืœื•ื™ื” ื‘ื’ื•ื“ืœ ื”ืงื•ื‘ืฅ')

    # Streamlit reruns the script automatically when a new file is uploaded, so no
    # on_change callback is needed (re-entering run() would also call set_page_config twice).
    uploaded_file = st.file_uploader('ื”ืขืœื” ืงื•ื‘ืฅ', type=['txt'])  # "Upload a file"

    embeddings_model = get_model()
    df = get_df(uploaded_file)
    if df is None:
        st.write("ืœื ื”ื•ืขืœื” ืงื•ื‘ืฅ")  # "No file was uploaded"
    else:
        df = get_embeddings(df, embeddings_model)

    user_input = st.text_input('ื›ืชื•ื‘ ื›ืืŸ ืืช ืฉืืœืชืš', placeholder='')  # "Type your question here"
    num_of_results = st.sidebar.slider('ืžืกืคืจ ื”ืชื•ืฆืื•ืช ืฉื‘ืจืฆื•ื ืš ืœื”ืฆื™ื’:', 1, 25, 5)  # "Number of results to display:"
    use_llm = st.sidebar.checkbox("ื”ืฉืชืžืฉ ื‘ืžื•ื“ืœ ืฉืคื” ื›ื“ื™ ืœืฉืคืจ ืชื•ืฆืื•ืช", False)  # "Use a language model to improve results"
    openAikey = st.sidebar.text_input("OpenAI API key", type="password")

    if (st.button('ื—ืคืฉ') or user_input) and user_input != "" and df is not None:  # button label: "Search"
        results = get_results(embeddings_model, user_input, df, num_of_results)
        if use_llm:
            if not openAikey:
                st.write("ืœื ื”ื•ื›ื ืก ืžืคืชื— ืฉืœ OpenAI")  # "No OpenAI key was entered"
            else:
                chat = get_chat_api(openAikey)
                llm_results = get_llm_results(user_input, chat, results)
                st.write(llm_results)
        else:
            st.write(results['text'].head(10))


if __name__ == "__main__":
    run()