jacob-braun-mn committed
Commit 1d0a57d
1 Parent(s): e01ed2b

Add basic files. Model next.

Files changed (4)
  1. Dockerfile +11 -0
  2. helpers.py +103 -0
  3. main.py +56 -0
  4. requirements.txt +14 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
+ FROM python:3.9-slim
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ COPY . .
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
helpers.py ADDED
@@ -0,0 +1,103 @@
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.schema.document import Document
+ from langchain_community.vectorstores import FAISS
+ from typing import List, Dict, Union
+ import time
+
+
+ def store_doc(fullText, chunkLen, embeddingModel):
+     # Split the raw text into token-based chunks, tagging each with viewer metadata.
+     text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+         model_name="gpt-4o",
+         chunk_size=chunkLen,
+         chunk_overlap=0,
+         is_separator_regex=False,
+     )
+
+     splits = text_splitter.create_documents([fullText])
+     for i, split in enumerate(splits):
+         split.metadata['idx'] = i
+         split.metadata['highlight'] = False
+         split.metadata['similarity_rank'] = None
+         split.metadata['similarity'] = None
+
+     print(f"SPLIT LENGTH: {len(splits)}")
+     print("Embedding and storing documents in memory...")
+     if len(splits) > 100:
+         # Embed in batches of 100, pausing briefly between batches.
+         db = FAISS.from_documents(splits[:100], embeddingModel)
+         print(f"Docs 1 - 100 added to db. Total docs: {len(splits)}")
+         for i in range(100, len(splits), 100):
+             db.add_documents(splits[i:i+100])
+             print(f"Docs {i} - {min(i+100, len(splits))} added to db. Total docs: {len(splits)}")
+             time.sleep(2)
+     else:
+         db = FAISS.from_documents(splits, embeddingModel)
+         print(f"Docs 1 - {len(splits)} added to db. Total docs: {len(splits)}")
+
+     return db, splits
+
+ def transform_documents(splits: List[Document]) -> List[Dict[str, Union[int, str, None]]]:
+     result = []
+     combined_non_highlight_content = ""
+
+     for doc in splits:
+         highlight_idx = doc.metadata['similarity_rank'] if doc.metadata['highlight'] else None
+         if highlight_idx is not None:
+             # Flush any accumulated non-highlight text before emitting the highlight.
+             if combined_non_highlight_content:
+                 result.append(
+                     {"highlight_idx": None, "page_content": combined_non_highlight_content.strip()}
+                 )
+                 combined_non_highlight_content = ""
+
+             result.append({"highlight_idx": highlight_idx, "page_content": doc.page_content.strip()})
+         else:
+             # Accumulate consecutive non-highlight chunks so they render as one block.
+             combined_non_highlight_content += doc.page_content + " "
+
+     if combined_non_highlight_content:
+         result.append(
+             {"highlight_idx": None, "page_content": combined_non_highlight_content.strip()}
+         )
+
+     return result
+
+ def get_relevant_docs(splits, userQuery, db, topK):
+     print("Searching for relevant documents...")
+     docs = db.similarity_search_with_relevance_scores(query=userQuery, k=topK)
+
+     highlights = []
+     for i, (doc, score) in enumerate(docs):
+         doc.metadata['similarity'] = score
+         doc.metadata['similarity_rank'] = i
+         doc.metadata['highlight'] = True
+         highlights.append({
+             'page_content': doc.page_content,
+             'similarity': score,
+             'similarity_rank': i,
+             'highlight': True})
+
+     docviewer_text = transform_documents(splits)
+
+     return highlights, docviewer_text
+
+
+ def get_answer(highlights, question, model_pipe):
+     instructions = """
+     # INSTRUCTIONS\n\nYou are a helpful assistant that reviews relevant sections of clinical notes to answer user questions. First, review the text provided under the # Highlighted Sections in the user message to familiarize yourself with the content. Then, read the user question under the # Question section and think step by step through what information you need to answer their question. Next, review the provided Highlighted Sections for context again and find the relevant information for the user's question. Finally, synthesize that relevant information to answer the user's question. Keep your answer fully grounded in the facts from the Highlighted Sections and reply at a 10th grade reading level. Keep your answer as concise as possible and only use relevant information from the provided documents. If the Highlighted Sections do not contain the necessary facts to answer the user's question, please respond with 'I didn't find the necessary information. Please try rephrasing your question or providing additional text.' Provide your summary in markdown format but do not use H1 (#) or H2 (##) headers.
+     """
+
+     documents = "# Highlighted Sections\n\n"
+     for i, highlight in enumerate(highlights):
+         documents += f"## Highlight {i+1}\n\n"
+         documents += highlight['page_content'] + "\n\n"
+
+     question = "# Question\n\n" + question + "\n\n"
+
+     reminder = "REMEMBER: Please keep your answer concise and fully grounded in the facts from the provided Highlighted Sections. Do not provide your own opinion or add information that is not supported by the Highlighted Sections. Provide your answer in markdown format but do not use H1 (#) or H2 (##) headers."
+
+     messages = [
+         {"role": "system", "content": instructions},
+         {"role": "user", "content": documents + question + reminder}
+     ]
+
+     response = model_pipe(messages, max_length=4096, temperature=0.7, num_return_sequences=1)
+
+     return response[0]['generated_text'][-1]['content']
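
To make the output shape of transform_documents concrete, here is a small illustrative sketch; the three Document objects and their metadata values are hypothetical:

    from langchain.schema.document import Document
    from helpers import transform_documents

    # A three-chunk note where only the middle chunk was a search hit.
    splits = [
        Document(page_content="Intro text.",
                 metadata={"idx": 0, "highlight": False, "similarity_rank": None, "similarity": None}),
        Document(page_content="Relevant finding.",
                 metadata={"idx": 1, "highlight": True, "similarity_rank": 0, "similarity": 0.91}),
        Document(page_content="Closing text.",
                 metadata={"idx": 2, "highlight": False, "similarity_rank": None, "similarity": None}),
    ]

    print(transform_documents(splits))
    # [{'highlight_idx': None, 'page_content': 'Intro text.'},
    #  {'highlight_idx': 0, 'page_content': 'Relevant finding.'},
    #  {'highlight_idx': None, 'page_content': 'Closing text.'}]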
main.py ADDED
@@ -0,0 +1,56 @@
+ from fastapi import FastAPI, File, UploadFile, Form
+ from fastapi.middleware.cors import CORSMiddleware
+ from dotenv import load_dotenv
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from transformers import pipeline
+ from helpers import store_doc, get_relevant_docs, get_answer
+ import os
+
+ load_dotenv(override=True)
+
+ app = FastAPI()
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ model_pipe = pipeline("text-generation", model="./models/Qwen2-1dot5B-Instruct")
+ embeddings = HuggingFaceEmbeddings(model_name="all-mpnet-base-v2")
+
+ @app.post("/search_document")
+ async def retrieve_record(textFile: UploadFile = File(...),
+                           userQuery: str = Form(...),
+                           chunkLen: int = Form(...),
+                           aiSummaryEnabled: bool = Form(...),
+                           topK: int = Form(...)):
+
+     fileContent = await textFile.read()
+     fullText = fileContent.decode("utf-8")
+
+     db, splits = store_doc(fullText,
+                            chunkLen,
+                            embeddings)
+
+     highlights, docviewer_text = get_relevant_docs(splits,
+                                                    userQuery,
+                                                    db,
+                                                    topK)
+
+     if aiSummaryEnabled:
+         model_answer = get_answer(highlights,
+                                   userQuery,
+                                   model_pipe)
+     else:
+         model_answer = None
+
+     response_obj = {
+         'highlights': highlights,
+         'docviewerText': docviewer_text,
+         'modelSummary': model_answer,
+     }
+
+     return response_obj
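
A minimal client-side smoke test for the endpoint might look like this; it assumes the requests package and a local plain-text file note.txt, and the form field names mirror the Form(...) parameters above:

    import requests

    with open("note.txt", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/search_document",
            files={"textFile": f},
            data={
                "userQuery": "What medications is the patient taking?",
                "chunkLen": 256,
                "aiSummaryEnabled": "true",
                "topK": 5,
            },
        )

    print(resp.json()["modelSummary"])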
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ pandas
+ numpy
+ openai
+ tiktoken
+ fastapi
+ uvicorn
+ python-multipart  # needed for FastAPI's Form/File parameters in main.py
+ python-dotenv  # needed for load_dotenv in main.py
+ langchain
+ langchain-community
+ langchain-openai
+ langchain-huggingface
+ langchainhub
+ faiss-cpu
+ typing-extensions
+ sentence-transformers
+ transformers
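
For local development outside Docker, pip install -r requirements.txt plus the embedding and generation models should be enough to start the server with uvicorn main:app --reload.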