Pablo276 committed
Commit 31577a2
1 Parent(s): 8d4f54e

Update app.py

Files changed (1)
app.py +122 -15
app.py CHANGED
@@ -1,14 +1,19 @@
 import streamlit as st
-from langchain.vectorstores import faiss
 from langchain.text_splitter import CharacterTextSplitter
 from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings
-from langchain.vectorstores import FAISS
 from langchain.document_loaders import TextLoader
 from langchain.embeddings import SentenceTransformerEmbeddings
 from tempfile import NamedTemporaryFile
 import os
 import shutil
-
+from typing import Any, List, Mapping, Optional
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+from gradio_client import Client
+from langchain.memory import ConversationBufferMemory
+from langchain.chains import ConversationalRetrievalChain
+from langchain.vectorstores import FAISS
+import time
 try:
     shutil.rmtree("tempDir")
 except :
@@ -17,28 +22,130 @@ try:
     os.mkdir("tempDir")
 except:
     pass
+css = '''
+<style>
+.chat-message {
+    padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex
+}
+.chat-message.user {
+    background-color: #2b313e
+}
+.chat-message.bot {
+    background-color: #475063
+}
+.chat-message .avatar {
+    width: 20%;
+}
+.chat-message .avatar img {
+    max-width: 78px;
+    max-height: 78px;
+    border-radius: 50%;
+    object-fit: cover;
+}
+.chat-message .message {
+    width: 80%;
+    padding: 0 1.5rem;
+    color: #fff;
+}
+'''
+
+bot_template = '''
+<div class="chat-message bot">
+    <div class="avatar">
+        <img src="https://i.ibb.co/cN0nmSj/Screenshot-2023-05-28-at-02-37-21.png" style="max-height: 78px; max-width: 78px; border-radius: 50%; object-fit: cover;">
+    </div>
+    <div class="message">{{MSG}}</div>
+</div>
+'''
+
+user_template = '''
+<div class="chat-message user">
+    <div class="avatar">
+        <img src="https://cdn-icons-png.flaticon.com/512/149/149071.png">
+    </div>
+    <div class="message">{{MSG}}</div>
+</div>
+'''
+
+
 def save_uploadedfile(uploadedfile):
 
+
     with open(os.path.join("tempDir",uploadedfile.name),"wb") as f:
         f.write(uploadedfile.getbuffer())
     return st.success("Saved File:{} to tempDir".format(uploadedfile.name))
 
+def ricerca_llama(domanda):
+    client = Client("https://ysharma-explore-llamav2-with-tgi.hf.space/")
+    risultato = client.predict( str(domanda),"you are a university professor, use appropriate language to answer students' questions .",0.1,2000,0.1,1.1,api_name="/chat")
+    print(domanda)
+    risultato=str(risultato).split("<")[0]
+    return risultato
+
 
+
+class CustomLLM(LLM):
+    @property
+    def _llm_type(self) -> str:
+        return "custom"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        if stop is not None:
+            raise ValueError("stop kwargs are not permitted.")
+
+        # Chiamata alla tua funzione API
+        risultato = ricerca_llama(prompt)  # Assumendo che `prompt` sia la descrizione wiki
+
+        return risultato
+
+def hande_user_input(user_question):
+    response=st.session_state.conversation({"question":user_question})
+    st.session_state.chat_history= response["chat_history"]
+    for i, message in enumerate(st.session_state.chat_history):
+        if i % 2== 0:
+            st.write(user_template.replace("{{MSG}}",message.content),unsafe_allow_html=True)
+        else:
+            st.write(bot_template.replace("{{MSG}}",message.content),unsafe_allow_html=True)
+
+
+def get_conversation_chain(vectorstore):
+    llm=CustomLLM()
+    memory=ConversationBufferMemory(memory_key="chat_history",return_messages=True)
+    conversation_chain= ConversationalRetrievalChain.from_llm(llm=llm,
+                                                              retriever=vectorstore.as_retriever(),
+                                                              memory=memory)
+    return conversation_chain
 def main():
-    st.set_page_config(page_title="chet with unipv")
-    st.text_input("fai una domanda al tuo professore ")
+
+    st.set_page_config(page_title="chat with unipv")
+
+    if "conversation" not in st.session_state:
+        st.session_state.conversation= None
+    if "chat_history" not in st.session_state:
+        st.session_state.chat_history= None
+
+    st.write(css,unsafe_allow_html=True)
+    user_input=st.text_input("fai una domanda al tuo professore ")
+    if user_input:
+        hande_user_input(user_input)
+
     with st.sidebar:
-        st.subheader("Your_faiss_index")
+        st.subheader("Your faiss index")
         documents=st.file_uploader("upload your faiss index here ",accept_multiple_files=True)
         if st.button("Procedi"):
-            for document in documents:
-                save_uploadedfile(document)
-            #with st.spinner("sto processando i tuoi dati"):
-            print(documents)
-            query="chi è matteo salvini?"
-            embeddings= HuggingFaceInstructEmbeddings(model_name="thenlper/gte-base")
-            new_db = FAISS.load_local("tempDir", embeddings)
-            docs = new_db.similarity_search(query)
-            print(docs)
+            with st.spinner("sto processando i tuoi dati"):
+                for document in documents:
+                    save_uploadedfile(document)
+                time.sleep(1)
+                embeddings= HuggingFaceInstructEmbeddings(model_name="thenlper/gte-base")
+                new_db = FAISS.load_local("tempDir", embeddings)
+                st.session_state.conversation=get_conversation_chain(new_db)
+                #conversation=get_conversation_chain(new_db)
 if __name__=="__main__":
     main()
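
The sidebar only loads an index that already exists: FAISS.load_local("tempDir", embeddings) expects the uploaded files to be a saved FAISS index (an index.faiss / index.pkl pair) built with the same embedding model. A minimal sketch of how such an index could be produced offline; the source file notes.txt, the chunking parameters and the output directory are illustrative assumptions, not part of this commit:

# Sketch: build a FAISS index the app can load back. The input file, chunk
# sizes and output directory below are assumptions for illustration.
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.vectorstores import FAISS

docs = TextLoader("notes.txt").load()  # raw text file -> list of Documents
chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200).split_documents(docs)
embeddings = HuggingFaceInstructEmbeddings(model_name="thenlper/gte-base")  # same model as app.py
FAISS.from_documents(chunks, embeddings).save_local("my_index")  # writes index.faiss and index.pkl

The two files written by save_local are what would then be uploaded through the sidebar so the app can read them back from tempDir.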
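
Because CustomLLM simply forwards the prompt to ricerca_llama, the wrapper can be smoke-tested outside Streamlit before it is handed to ConversationalRetrievalChain. A minimal sketch, assuming app.py is importable from the working directory and the public Space above is reachable:

# Sketch: exercise the Gradio-backed LLM wrapper on its own.
from app import CustomLLM, ricerca_llama

print(ricerca_llama("What is a FAISS index?"))  # direct call to the hosted Space
llm = CustomLLM()
print(llm("What is a FAISS index?"))  # LLM.__call__ routes through _call -> ricerca_llama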