import os
import shutil
import time
from tempfile import NamedTemporaryFile
from typing import Any, List, Mapping, Optional

import streamlit as st
from gradio_client import Client
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import TextLoader
from langchain.embeddings import (
    HuggingFaceInstructEmbeddings,
    OpenAIEmbeddings,
    SentenceTransformerEmbeddings,
)
from langchain.llms.base import LLM
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
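# Assumption: a pre-0.1 LangChain release (these modules still live under the top-level
# `langchain` package), plus streamlit, gradio_client, faiss-cpu and the
# sentence-transformers / InstructorEmbedding packages required by
# HuggingFaceInstructEmbeddings.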
# Start from a clean temporary directory; uploaded FAISS index files are stored here.
shutil.rmtree("tempDir", ignore_errors=True)
os.makedirs("tempDir", exist_ok=True)

# CSS for the chat bubbles, injected with st.write(..., unsafe_allow_html=True).
css = '''
<style>
.chat-message {
    padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex
}
.chat-message.user {
    background-color: #2b313e
}
.chat-message.bot {
    background-color: #475063
}
.chat-message .avatar {
    width: 20%;
}
.chat-message .avatar img {
    max-width: 78px;
    max-height: 78px;
    border-radius: 50%;
    object-fit: cover;
}
.chat-message .message {
    width: 80%;
    padding: 0 1.5rem;
    color: #fff;
}
</style>
'''

# HTML templates for the assistant and user chat bubbles; {{MSG}} is replaced at render time.
bot_template = '''
<div class="chat-message bot">
    <div class="avatar">
        <img src="https://i.ibb.co/cN0nmSj/Screenshot-2023-05-28-at-02-37-21.png" style="max-height: 78px; max-width: 78px; border-radius: 50%; object-fit: cover;">
    </div>
    <div class="message">{{MSG}}</div>
</div>
'''

user_template = '''
<div class="chat-message user">
    <div class="avatar">
        <img src="https://cdn-icons-png.flaticon.com/512/149/149071.png">
    </div>
    <div class="message">{{MSG}}</div>
</div>
'''


def save_uploadedfile(uploadedfile):
    """Write an uploaded file into tempDir and show a success message."""
    with open(os.path.join("tempDir", uploadedfile.name), "wb") as f:
        f.write(uploadedfile.getbuffer())
    return st.success("Saved File:{} to tempDir".format(uploadedfile.name))


def ricerca_llama(domanda):
    """Query a hosted Llama 2 Space with the RAG prompt and return the cleaned answer."""
    # Replace the default LangChain QA preamble with a "university professor" persona
    # that must answer strictly from the retrieved context.
    domanda = domanda.replace(
        "Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.",
        "Act as if you were a university professor. The following are excerpts from your lectures; emulate the way of speaking/writing and respond to the question at the end. If there is no answer to the question in the context, simply say that you don't know how to answer. NEVER MAKE UP AN ANSWER; USE ONLY AND EXCLUSIVELY THE CONTEXT.",
    )
    client = Client("https://ysharma-explore-llamav2-with-tgi.hf.space/")
    # Positional arguments: prompt, system prompt, then the generation settings
    # expected by the Space's /chat endpoint.
    risultato = client.predict(
        str(domanda),
        "you are a university professor, use appropriate language to answer students' questions .",
        0.0,
        2000,
        0.1,
        1.1,
        api_name="/chat",
    )
    print(domanda)
    # Keep only the text before the first "<", dropping any trailing markup or special tokens.
    risultato = str(risultato).split("<")[0]
    return risultato


class CustomLLM(LLM):
    """Minimal LangChain LLM wrapper that delegates generation to ricerca_llama."""

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        risultato = ricerca_llama(prompt)
        return risultato


def handle_user_input(user_question):
    """Run the question through the conversational chain and render the chat history."""
    response = st.session_state.conversation({"question": user_question})
    st.session_state.chat_history = response["chat_history"]
    # Even indices are user turns, odd indices are assistant turns.
    for i, message in enumerate(st.session_state.chat_history):
        if i % 2 == 0:
            st.write(user_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
        else:
            st.write(bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)


def get_conversation_chain(vectorstore):
    """Build a ConversationalRetrievalChain over the loaded FAISS index."""
    llm = CustomLLM()
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        memory=memory,
    )
    return conversation_chain

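# Hedged sketch (not part of the original app): the sidebar upload expects the files
# produced by FAISS.save_local() (index.faiss / index.pkl), built with the same
# embedding model used in main(). A helper along these lines, using a hypothetical
# lecture transcript "lezioni.txt", could generate them offline; it is never called here.
def build_faiss_index(text_path="lezioni.txt", out_dir="faiss_index"):
    # Load the raw text and split it into overlapping chunks for retrieval.
    docs = TextLoader(text_path, encoding="utf-8").load()
    chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
    # Embed the chunks with the same model the app loads the index with, then persist it.
    embeddings = HuggingFaceInstructEmbeddings(model_name="thenlper/gte-base")
    index = FAISS.from_documents(chunks, embeddings)
    index.save_local(out_dir)  # upload index.faiss and index.pkl from out_dir in the sidebar
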
def main():
    st.set_page_config(page_title="chat with unipv")

    if "conversation" not in st.session_state:
        st.session_state.conversation = None
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = None

    st.write(css, unsafe_allow_html=True)
    # "Ask your professor a question"
    user_input = st.text_input("fai una domanda al tuo professore ")
    if user_input:
        handle_user_input(user_input)

    with st.sidebar:
        st.subheader("Your faiss index")
        documents = st.file_uploader("upload your faiss index here ", accept_multiple_files=True)
        if st.button("Procedi"):  # "Proceed"
            with st.spinner("sto processando i tuoi dati"):  # "processing your data"
                # Save the uploaded index files into tempDir before loading them.
                for document in documents:
                    save_uploadedfile(document)
                # Brief pause so all uploads are flushed to disk.
                time.sleep(1)
                # The embedding model must match the one used to build the uploaded index.
                embeddings = HuggingFaceInstructEmbeddings(model_name="thenlper/gte-base")
                new_db = FAISS.load_local("tempDir", embeddings)
                st.session_state.conversation = get_conversation_chain(new_db)

if __name__ == "__main__":
    main()
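# Hedged usage note (assuming this script is saved as app.py):
#     streamlit run app.py
# Then upload the index.faiss and index.pkl files of a previously built FAISS index in
# the sidebar, press "Procedi", and ask questions in the text box.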