sepidnes committed
Commit 64cdb49 · verified · 1 Parent(s): 902a480

Delete app.py

Files changed (1)
  1. app.py +0 -139
app.py DELETED
@@ -1,139 +0,0 @@
- import os
- from typing import List
- from chainlit.types import AskFileResponse
- from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
- from aimakerspace.openai_utils.prompts import (
-     UserRolePrompt,
-     SystemRolePrompt,
-     AssistantRolePrompt,
- )
- from aimakerspace.openai_utils.embedding import EmbeddingModel
- from aimakerspace.vectordatabase import VectorDatabase
- from aimakerspace.openai_utils.chatmodel import ChatOpenAI
- import chainlit as cl
-
- system_template = """\
- Use the following context to answer a user's question. If you cannot find the answer in the context, say you don't know the answer."""
- system_role_prompt = SystemRolePrompt(system_template)
-
- user_prompt_template = """\
- Context:
- {context}
-
- Question:
- {question}
- """
- user_role_prompt = UserRolePrompt(user_prompt_template)
-
- class RetrievalAugmentedQAPipeline:
-     def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
-         self.llm = llm
-         self.vector_db_retriever = vector_db_retriever
-
-     async def arun_pipeline(self, user_query: str):
-         context_list = self.vector_db_retriever.search_by_text(user_query, k=4)  # top-4 most similar chunks
-
-         context_prompt = ""
-         for context in context_list:
-             context_prompt += context[0] + "\n"  # each hit is a (text, score) pair; keep the text
-
-         formatted_system_prompt = system_role_prompt.create_message()
-
-         formatted_user_prompt = user_role_prompt.create_message(question=user_query, context=context_prompt)
-
-         async def generate_response():
-             async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
-                 yield chunk
-
-         return {"response": generate_response(), "context": context_list}
-
- text_splitter = CharacterTextSplitter()
-
-
- def process_file(file: AskFileResponse):
-     import tempfile
-     import shutil
-
-     print(f"Processing file: {file.name}")
-
-     # Create a temporary file with the correct extension
-     suffix = f".{file.name.split('.')[-1]}"
-     with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
-         # Copy the uploaded file content to the temporary file
-         shutil.copyfile(file.path, temp_file.name)
-         print(f"Created temporary file at: {temp_file.name}")
-
-     # Create the appropriate loader
-     if file.name.lower().endswith('.pdf'):
-         loader = PDFLoader(temp_file.name)
-     else:
-         loader = TextFileLoader(temp_file.name)
-
-     try:
-         # Load and split the documents
-         documents = loader.load_documents()
-         texts = text_splitter.split_texts(documents)
-         return texts
-     finally:
-         # Clean up the temporary file
-         try:
-             os.unlink(temp_file.name)
-         except Exception as e:
-             print(f"Error cleaning up temporary file: {e}")
-
-
- @cl.on_chat_start
- async def on_chat_start():
-     files = None
-
-     # Wait for the user to upload a file
-     while files is None:
-         files = await cl.AskFileMessage(
-             content="Please upload a Text or PDF file to begin!",
-             accept=["text/plain", "application/pdf"],
-             max_size_mb=2,
-             timeout=180,
-         ).send()
-
-     file = files[0]
-
-     msg = cl.Message(
-         content=f"Processing `{file.name}`..."
-     )
-     await msg.send()
-
-     # Load and split the uploaded file
-     texts = process_file(file)
-
-     print(f"Processing {len(texts)} text chunks")
-
-     # Build the in-memory vector store from the chunks
-     vector_db = VectorDatabase()
-     vector_db = await vector_db.abuild_from_list(texts)
-
-     chat_openai = ChatOpenAI()
-
-     # Create the RAG pipeline
-     retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
-         vector_db_retriever=vector_db,
-         llm=chat_openai
-     )
-
-     # Let the user know that the system is ready
-     msg.content = f"Processing `{file.name}` done. You can now ask questions!"
-     await msg.update()
-
-     cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
-
-
- @cl.on_message
- async def main(message):
-     chain = cl.user_session.get("chain")
-
-     msg = cl.Message(content="")
-     result = await chain.arun_pipeline(message.content)
-
-     async for stream_resp in result["response"]:  # stream tokens to the UI as they arrive
-         await msg.stream_token(stream_resp)
-
-     await msg.send()