PCFISH committed on
Commit
f3b7497
β€’
1 Parent(s): af69459

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -91
app.py CHANGED
@@ -29,54 +29,42 @@ def get_pdf_text(pdf_docs):
29
  # μ•„λž˜ ν…μŠ€νŠΈ μΆ”μΆœ ν•¨μˆ˜λ₯Ό μž‘μ„±
30
 
31
  def get_text_file(docs):
32
- if docs.type == 'text/plain':
33
- # ν…μŠ€νŠΈ 파일 (.txt)μ—μ„œ ν…μŠ€νŠΈλ₯Ό μΆ”μΆœν•˜λŠ” ν•¨μˆ˜
34
- return [docs.getvalue().decode('utf-8')]
35
- else:
36
- st.warning("Unsupported file type for get_text_file")
 
 
 
37
 
38
  def get_csv_file(docs):
39
- if docs.type == 'text/csv':
40
- # CSV 파일 (.csv)μ—μ„œ ν…μŠ€νŠΈλ₯Ό μΆ”μΆœν•˜λŠ” ν•¨μˆ˜
41
- csv_loader = CSVLoader(docs)
42
- csv_data = csv_loader.load()
43
- # CSV 파일의 각 행을 λ¬Έμžμ—΄λ‘œ λ³€ν™˜ν•˜μ—¬ λ°˜ν™˜
44
- return [' '.join(map(str, row)) for row in csv_data]
45
- else:
46
- st.warning("Unsupported file type for get_csv_file")
47
 
48
  def get_json_file(docs):
49
- if docs.type == 'application/json':
50
- # JSON 파일 (.json)μ—μ„œ ν…μŠ€νŠΈλ₯Ό μΆ”μΆœν•˜λŠ” ν•¨μˆ˜
51
- json_loader = JSONLoader(docs)
52
- json_data = json_loader.load()
53
- # JSON 파일의 각 ν•­λͺ©μ„ λ¬Έμžμ—΄λ‘œ λ³€ν™˜ν•˜μ—¬ λ°˜ν™˜
54
- return [json.dumps(item) for item in json_data]
55
- else:
56
- st.warning("Unsupported file type for get_json_file")
57
 
58
 
59
  # λ¬Έμ„œλ“€μ„ μ²˜λ¦¬ν•˜μ—¬ ν…μŠ€νŠΈ 청크둜 λ‚˜λˆ„λŠ” ν•¨μˆ˜μž…λ‹ˆλ‹€.
60
  def get_text_chunks(documents):
61
  text_splitter = RecursiveCharacterTextSplitter(
62
- chunk_size=1000,
63
- chunk_overlap=200,
64
- length_function=len
65
  )
66
 
67
- # 각 λ¬Έμ„œμ˜ λ‚΄μš©μ„ λ¦¬μŠ€νŠΈμ— μΆ”κ°€
68
- texts = []
69
- for doc in documents:
70
- if hasattr(doc, 'page_content'):
71
- # λ¬Έμ„œ 객체인 κ²½μš°μ—λ§Œ μΆ”κ°€
72
- texts.append(doc.page_content)
73
- elif isinstance(doc, str):
74
- # λ¬Έμžμ—΄μΈ 경우 κ·ΈλŒ€λ‘œ μΆ”κ°€
75
- texts.append(doc)
76
-
77
- # λ‚˜λˆˆ 청크λ₯Ό λ°˜ν™˜
78
- return text_splitter.split_documents(texts)
79
-
80
 
81
 
82
  # ν…μŠ€νŠΈ μ²­ν¬λ“€λ‘œλΆ€ν„° 벑터 μŠ€ν† μ–΄λ₯Ό μƒμ„±ν•˜λŠ” ν•¨μˆ˜μž…λ‹ˆλ‹€.
@@ -90,30 +78,19 @@ def get_vectorstore(text_chunks):
90
 
91
 
92
  def get_conversation_chain(vectorstore):
93
- print(f"DEBUG: session_state.conversation before initialization: {st.session_state.conversation}")
94
-
95
- try:
96
- if st.session_state.conversation is None:
97
- gpt_model_name = 'gpt-3.5-turbo'
98
- llm = ChatOpenAI(model_name=gpt_model_name)
99
-
100
- # λŒ€ν™” 기둝을 μ €μž₯ν•˜κΈ° μœ„ν•œ λ©”λͺ¨λ¦¬λ₯Ό μƒμ„±ν•©λ‹ˆλ‹€.
101
- memory = ConversationBufferMemory(
102
- memory_key='chat_history', return_messages=True)
103
- # λŒ€ν™” 검색 체인을 μƒμ„±ν•©λ‹ˆλ‹€.
104
- conversation_chain = ConversationalRetrievalChain.from_llm(
105
- llm=llm,
106
- retriever=vectorstore.as_retriever(),
107
- memory=memory
108
- )
109
- st.session_state.conversation = conversation_chain
110
-
111
- except Exception as e:
112
- print(f"Error during conversation initialization: {e}")
113
-
114
- print(f"DEBUG: session_state.conversation after initialization: {st.session_state.conversation}")
115
-
116
- return st.session_state.conversation if st.session_state.conversation else ConversationalRetrievalChain()
117
 
118
  # μ‚¬μš©μž μž…λ ₯을 μ²˜λ¦¬ν•˜λŠ” ν•¨μˆ˜μž…λ‹ˆλ‹€.
119
  def handle_userinput(user_question):
@@ -133,12 +110,13 @@ def handle_userinput(user_question):
133
 
134
  def main():
135
  load_dotenv()
136
- st.set_page_config(page_title="Chat with multiple Files :)",
137
  page_icon=":books:")
138
  st.write(css, unsafe_allow_html=True)
139
 
140
- if "conversation" not in st.session_state or st.session_state.conversation is None:
141
  st.session_state.conversation = None
 
142
  st.session_state.chat_history = None
143
 
144
  st.header("Chat with multiple Files :")
@@ -153,35 +131,36 @@ def main():
153
 
154
  st.subheader("Your documents")
155
  docs = st.file_uploader(
156
- "Upload your documents here and click on 'Process'", accept_multiple_files=True)
157
  if st.button("Process"):
158
- with st.spinner("Processing"):
159
- # λ¬Έμ„œμ—μ„œ μΆ”μΆœν•œ ν…μŠ€νŠΈλ₯Ό 담을 리슀트
160
- doc_list = []
161
-
162
- for file in docs:
163
- if file.type == 'text/plain':
164
- # .txt 파일의 경우
165
- doc_list.extend(get_text_file(file))
166
- elif file.type == 'text/csv':
167
- # .csv 파일의 경우
168
- doc_list.extend(get_csv_file(file))
169
- elif file.type == 'application/json':
170
- # .json 파일의 경우
171
- doc_list.extend(get_json_file(file))
172
- elif file.type in ['application/octet-stream', 'application/pdf']:
173
- # .pdf 파일의 경우
174
- doc_list.extend(get_pdf_text(file))
175
-
176
- # ν…μŠ€νŠΈ 청크둜 λ‚˜λˆ„κΈ°
177
- text_chunks = get_text_chunks(doc_list)
178
-
179
- # 벑터 μŠ€ν† μ–΄ 생성
180
- vectorstore = get_vectorstore(text_chunks)
181
-
182
- # λŒ€ν™” 체인 생성
183
- st.session_state.conversation = get_conversation_chain(vectorstore)
 
184
 
185
 
186
  if __name__ == '__main__':
187
- main()
 
29
  # μ•„λž˜ ν…μŠ€νŠΈ μΆ”μΆœ ν•¨μˆ˜λ₯Ό μž‘μ„±
30
 
31
def get_text_file(docs):
    """Extract UTF-8 text from uploaded plain-text files.

    Args:
        docs: iterable of uploaded file objects exposing ``type`` and
            ``getvalue()`` (e.g. Streamlit ``UploadedFile``).

    Returns:
        list[str]: one decoded string per ``text/plain`` file, in upload
        order; files of any other MIME type are silently skipped.
    """
    return [
        uploaded.getvalue().decode('utf-8')
        for uploaded in docs
        if uploaded.type == 'text/plain'
    ]
38
+
39
+
40
 
41
def get_csv_file(docs):
    """Extract parsed rows from uploaded CSV files.

    Args:
        docs: iterable of uploaded file objects exposing ``type`` and
            ``getvalue()`` (e.g. Streamlit ``UploadedFile``).

    Returns:
        list[list[str]]: every parsed row of each ``text/csv`` file,
        concatenated in upload order; other file types are skipped.
    """
    import io  # local import keeps this fix self-contained

    csv_list = []
    for file in docs:
        if file.type == 'text/csv':
            text = file.getvalue().decode('utf-8')
            # Feed the reader a StringIO opened with newline='' so that
            # newlines embedded in quoted fields survive parsing; the
            # previous splitlines() approach stripped them, which the csv
            # module documents as producing incorrect results.
            csv_list.extend(csv.reader(io.StringIO(text, newline='')))
    return csv_list
 
 
48
 
49
def get_json_file(docs):
    """Collect the top-level items of uploaded JSON files.

    Args:
        docs: iterable of uploaded file objects exposing ``type`` and a
            file-like read interface (e.g. Streamlit ``UploadedFile``).

    Returns:
        list: the top-level elements of every ``application/json`` file,
        concatenated in upload order. For a JSON array these are its
        items; for a JSON object, iteration yields its keys. Files of
        other MIME types are skipped.
    """
    collected = []
    for uploaded in docs:
        if uploaded.type != 'application/json':
            continue
        collected.extend(json.load(uploaded))
    return collected
 
 
56
 
57
 
58
  # λ¬Έμ„œλ“€μ„ μ²˜λ¦¬ν•˜μ—¬ ν…μŠ€νŠΈ 청크둜 λ‚˜λˆ„λŠ” ν•¨μˆ˜μž…λ‹ˆλ‹€.
59
def get_text_chunks(documents):
    """Split LangChain documents into overlapping text chunks.

    Args:
        documents: sequence of LangChain ``Document`` objects.

    Returns:
        The chunk documents produced by ``RecursiveCharacterTextSplitter``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,      # maximum characters per chunk
        chunk_overlap=200,    # characters shared by adjacent chunks
        length_function=len,  # measure size in characters
    )
    return splitter.split_documents(documents)
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
  # ν…μŠ€νŠΈ μ²­ν¬λ“€λ‘œλΆ€ν„° 벑터 μŠ€ν† μ–΄λ₯Ό μƒμ„±ν•˜λŠ” ν•¨μˆ˜μž…λ‹ˆλ‹€.
 
78
 
79
 
80
def get_conversation_chain(vectorstore):
    """Build a conversational retrieval chain over *vectorstore*.

    Args:
        vectorstore: a LangChain vector store providing ``as_retriever()``.

    Returns:
        ConversationalRetrievalChain: chain wiring a ``gpt-3.5-turbo``
        chat model, a message-returning buffer memory keyed by
        ``'chat_history'``, and the store's retriever.
    """
    chat_model = ChatOpenAI(model_name='gpt-3.5-turbo')
    # Buffer memory holds the running dialogue so follow-up questions
    # are answered with conversational context.
    chat_memory = ConversationBufferMemory(
        memory_key='chat_history', return_messages=True)
    return ConversationalRetrievalChain.from_llm(
        llm=chat_model,
        retriever=vectorstore.as_retriever(),
        memory=chat_memory,
    )
 
 
 
 
 
 
 
 
 
 
 
94
 
95
  # μ‚¬μš©μž μž…λ ₯을 μ²˜λ¦¬ν•˜λŠ” ν•¨μˆ˜μž…λ‹ˆλ‹€.
96
  def handle_userinput(user_question):
 
110
 
111
  def main():
112
  load_dotenv()
113
+ st.set_page_config(page_title="Chat with multiple Files",
114
  page_icon=":books:")
115
  st.write(css, unsafe_allow_html=True)
116
 
117
+ if "conversation" not in st.session_state:
118
  st.session_state.conversation = None
119
+ if "chat_history" not in st.session_state:
120
  st.session_state.chat_history = None
121
 
122
  st.header("Chat with multiple Files :")
 
131
 
132
  st.subheader("Your documents")
133
  docs = st.file_uploader(
134
+ "Upload your PDFs here and click on 'Process'", accept_multiple_files=True)
135
  if st.button("Process"):
136
+ with st.spinner("Processing"):
137
+ # get pdf text
138
+ doc_list = []
139
+
140
+ for file in docs:
141
+ print('file - type : ', file.type)
142
+ if file.type == 'text/plain':
143
+ # file is .txt
144
+ doc_list.extend(get_text_file([file]))
145
+ elif file.type in ['application/octet-stream', 'application/pdf']:
146
+ # file is .pdf
147
+ doc_list.extend(get_pdf_text(file))
148
+ elif file.type == 'text/csv':
149
+ # file is .csv
150
+ doc_list.extend(get_csv_file([file]))
151
+ elif file.type == 'application/json':
152
+ # file is .json
153
+ doc_list.extend(get_json_file([file]))
154
+
155
+ # get the text chunks
156
+ text_chunks = get_text_chunks(doc_list)
157
+
158
+ # create vector store
159
+ vectorstore = get_vectorstore(text_chunks)
160
+
161
+ # create conversation chain
162
+ st.session_state.conversation = get_conversation_chain(vectorstore)
163
 
164
 
165
  if __name__ == '__main__':
166
+ main()