mitulagr2 commited on
Commit
794ae55
·
1 Parent(s): 04305be
Files changed (2) hide show
  1. app/main.py +0 -11
  2. app/rag.py +4 -60
app/main.py CHANGED
@@ -23,29 +23,18 @@ app = FastAPI(middleware=middleware)
23
 
24
  files_dir = os.path.expanduser("~/wtp_be_files/")
25
  session_assistant = ChatPDF()
26
- # session_messages = []
27
-
28
- # async def stream_generator(agent_text_stream):
29
- # for text in agent_text_stream:
30
- # print(text)
31
- # yield text
32
 
33
  @app.get("/query")
34
  async def process_input(text: str):
35
  if text and len(text.strip()) > 0:
36
  text = text.strip()
37
  print("PRINTING STREAM")
38
- # agent_text_stream = session_assistant.ask(text)
39
- # print(agent_text_stream)
40
- # session_messages.append((text, True))
41
- # session_messages.append((agent_text, False))
42
  return StreamingResponse(session_assistant.ask(text), media_type='text/event-stream')
43
 
44
 
45
  @app.post("/upload")
46
  def upload(files: list[UploadFile]):
47
  session_assistant.clear()
48
- # session_messages = []
49
 
50
  try:
51
  os.makedirs(files_dir)
 
23
 
24
  files_dir = os.path.expanduser("~/wtp_be_files/")
25
  session_assistant = ChatPDF()
 
 
 
 
 
 
26
 
27
  @app.get("/query")
28
  async def process_input(text: str):
29
  if text and len(text.strip()) > 0:
30
  text = text.strip()
31
  print("PRINTING STREAM")
 
 
 
 
32
  return StreamingResponse(session_assistant.ask(text), media_type='text/event-stream')
33
 
34
 
35
  @app.post("/upload")
36
  def upload(files: list[UploadFile]):
37
  session_assistant.clear()
 
38
 
39
  try:
40
  os.makedirs(files_dir)
app/rag.py CHANGED
@@ -1,25 +1,18 @@
1
  import os
2
  import logging
3
-
4
  from llama_index.core import (
5
  SimpleDirectoryReader,
6
  VectorStoreIndex,
7
  StorageContext,
8
  Settings,
9
  get_response_synthesizer)
10
- # from llama_index.core.query_engine import RetrieverQueryEngine, TransformQueryEngine
11
  from llama_index.core.node_parser import SentenceSplitter
12
  from llama_index.core.schema import TextNode, MetadataMode
13
- # from llama_index.core.retrievers import VectorIndexRetriever
14
-
15
- # from llama_index.core.response_synthesizers import ResponseMode
16
- # from transformers import AutoTokenizer
17
  from llama_index.core.vector_stores import VectorStoreQuery
18
- from llama_index.vector_stores.qdrant import QdrantVectorStore
19
- from qdrant_client import QdrantClient
20
-
21
  from llama_index.llms.llama_cpp import LlamaCPP
22
  from llama_index.embeddings.fastembed import FastEmbedEmbedding
 
 
23
 
24
 
25
  QDRANT_API_URL = os.getenv('QDRANT_API_URL')
@@ -31,38 +24,12 @@ logger = logging.getLogger(__name__)
31
  class ChatPDF:
32
  query_engine = None
33
 
34
- model_url = "https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q4_k_m.gguf"
35
- # model_url = "https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q8_0.gguf"
36
- # model_url = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf"
37
-
38
- # def messages_to_prompt(messages):
39
- # prompt = ""
40
- # for message in messages:
41
- # if message.role == 'system':
42
- # prompt += f"<|system|>\n{message.content}</s>\n"
43
- # elif message.role == 'user':
44
- # prompt += f"<|user|>\n{message.content}</s>\n"
45
- # elif message.role == 'assistant':
46
- # prompt += f"<|assistant|>\n{message.content}</s>\n"
47
-
48
- # if not prompt.startswith("<|system|>\n"):
49
- # prompt = "<|system|>\n</s>\n" + prompt
50
-
51
- # prompt = prompt + "<|assistant|>\n"
52
-
53
- # return prompt
54
-
55
- # def completion_to_prompt(completion):
56
- # return f"<|system|>\n</s>\n<|user|>\n{completion}</s>\n<|assistant|>\n"
57
-
58
-
59
  def __init__(self):
60
  self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=24)
61
 
62
  logger.info("initializing the vector store related objects")
63
  # client = QdrantClient(host="localhost", port=6333)
64
  client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
65
- # client = QdrantClient(":memory:")
66
  self.vector_store = QdrantVectorStore(
67
  client=client,
68
  collection_name="rag_documents",
@@ -75,7 +42,7 @@ class ChatPDF:
75
  )
76
 
77
  llm = LlamaCPP(
78
- model_url=self.model_url,
79
  temperature=0.1,
80
  max_new_tokens=256,
81
  context_window=3900,
@@ -132,36 +99,13 @@ class ChatPDF:
132
  transformations=Settings.transformations,
133
  )
134
 
135
- # logger.info("configure retriever")
136
- # retriever = VectorIndexRetriever(
137
- # index=index,
138
- # similarity_top_k=6,
139
- # # vector_store_query_mode="hybrid"
140
- # )
141
-
142
- # logger.info("configure response synthesizer")
143
- # response_synthesizer = get_response_synthesizer(
144
- # # streaming=True,
145
- # response_mode=ResponseMode.COMPACT,
146
- # )
147
-
148
- # logger.info("assemble query engine")
149
- # self.query_engine = RetrieverQueryEngine(
150
- # retriever=retriever,
151
- # response_synthesizer=response_synthesizer,
152
- # )
153
-
154
  self.query_engine = index.as_query_engine(
155
  streaming=True,
156
- # similarity_top_k=6,
157
  )
158
 
159
  async def ask(self, query: str):
160
- if not self.query_engine:
161
- return "Please, add a PDF document first."
162
-
163
  logger.info("retrieving the response to the query")
164
- # response = self.query_engine.query(str_or_query_bundle=query)
165
  streaming_response = self.query_engine.query(query)
166
  print(streaming_response)
167
  # return streaming_response.response_gen
 
1
  import os
2
  import logging
 
3
  from llama_index.core import (
4
  SimpleDirectoryReader,
5
  VectorStoreIndex,
6
  StorageContext,
7
  Settings,
8
  get_response_synthesizer)
 
9
  from llama_index.core.node_parser import SentenceSplitter
10
  from llama_index.core.schema import TextNode, MetadataMode
 
 
 
 
11
  from llama_index.core.vector_stores import VectorStoreQuery
 
 
 
12
  from llama_index.llms.llama_cpp import LlamaCPP
13
  from llama_index.embeddings.fastembed import FastEmbedEmbedding
14
+ from llama_index.vector_stores.qdrant import QdrantVectorStore
15
+ from qdrant_client import QdrantClient
16
 
17
 
18
  QDRANT_API_URL = os.getenv('QDRANT_API_URL')
 
24
  class ChatPDF:
25
  query_engine = None
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def __init__(self):
28
  self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=24)
29
 
30
  logger.info("initializing the vector store related objects")
31
  # client = QdrantClient(host="localhost", port=6333)
32
  client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
 
33
  self.vector_store = QdrantVectorStore(
34
  client=client,
35
  collection_name="rag_documents",
 
42
  )
43
 
44
  llm = LlamaCPP(
45
+ model_url="https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q4_k_m.gguf",
46
  temperature=0.1,
47
  max_new_tokens=256,
48
  context_window=3900,
 
99
  transformations=Settings.transformations,
100
  )
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  self.query_engine = index.as_query_engine(
103
  streaming=True,
104
+ similarity_top_k=6,
105
  )
106
 
107
  async def ask(self, query: str):
 
 
 
108
  logger.info("retrieving the response to the query")
 
109
  streaming_response = self.query_engine.query(query)
110
  print(streaming_response)
111
  # return streaming_response.response_gen