Upload 10 files
Browse fileslatest code to check citations fix
- .gitattributes +35 -35
- README.md +13 -13
- app.py +80 -0
- feed_to_llm.py +101 -0
- feed_to_llm_v2.py +102 -0
- full_chain.py +46 -0
- get_articles.py +140 -0
- get_keywords.py +63 -0
- requirements.txt +15 -0
- rerank.py +273 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
-
---
|
2 |
-
title: Tobacco Watcher Chat
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
---
|
2 |
+
title: Tobacco Watcher Chat
|
3 |
+
emoji: 🐨
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.25.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
import gradio as gr
|
3 |
+
from full_chain import get_response
|
4 |
+
import os
|
5 |
+
import logging
|
6 |
+
|
7 |
+
# Configure logging
|
8 |
+
logging.basicConfig(
|
9 |
+
level=logging.INFO,
|
10 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
11 |
+
handlers=[
|
12 |
+
logging.FileHandler('app.log'),
|
13 |
+
logging.StreamHandler()
|
14 |
+
]
|
15 |
+
)
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
# Initialize OpenAI client
|
19 |
+
try:
|
20 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
21 |
+
if not api_key:
|
22 |
+
raise ValueError("OPENAI_API_KEY environment variable not set")
|
23 |
+
client = openai.OpenAI(api_key=api_key)
|
24 |
+
logger.info("OpenAI client initialized successfully")
|
25 |
+
except Exception as e:
|
26 |
+
logger.error(f"Failed to initialize OpenAI client: {str(e)}")
|
27 |
+
raise
|
28 |
+
|
29 |
+
def create_hyperlink(url, title, domain):
|
30 |
+
"""Create HTML hyperlink with domain information."""
|
31 |
+
return f"<a href='{url}'>{title}</a> ({domain})"
|
32 |
+
|
33 |
+
def predict(message, history):
|
34 |
+
"""Process user message and return response with source links."""
|
35 |
+
try:
|
36 |
+
logger.info(f"Processing new query: {message}")
|
37 |
+
|
38 |
+
# Get response from the chain
|
39 |
+
responder, links, titles, domains = get_response(message, rerank_type="crossencoder")
|
40 |
+
logger.info(f"Received response with {len(links)} sources")
|
41 |
+
|
42 |
+
# Create hyperlinks for sources
|
43 |
+
formatted_links = [create_hyperlink(link, title, domain)
|
44 |
+
for link, title, domain in zip(links, titles, domains)]
|
45 |
+
|
46 |
+
# Combine response with sources
|
47 |
+
out = responder + "\n" + "\n".join(formatted_links)
|
48 |
+
|
49 |
+
logger.info("Response generated successfully")
|
50 |
+
return out
|
51 |
+
|
52 |
+
except Exception as e:
|
53 |
+
error_msg = f"Error processing query: {str(e)}"
|
54 |
+
logger.error(error_msg)
|
55 |
+
return f"An error occurred while processing your request: {str(e)}"
|
56 |
+
|
57 |
+
# Define example queries
|
58 |
+
EXAMPLE_QUERIES = [
|
59 |
+
"How many Americans Smoke?",
|
60 |
+
"What are some measures taken by the Indian Government to reduce the smoking population?",
|
61 |
+
"Does smoking negatively affect my health?"
|
62 |
+
]
|
63 |
+
|
64 |
+
# Initialize and launch Gradio interface
|
65 |
+
def main():
|
66 |
+
try:
|
67 |
+
interface = gr.ChatInterface(
|
68 |
+
predict,
|
69 |
+
examples=EXAMPLE_QUERIES,
|
70 |
+
title="Tobacco Information Assistant",
|
71 |
+
description="Ask questions about tobacco-related topics and get answers with reliable sources."
|
72 |
+
)
|
73 |
+
logger.info("Starting Gradio interface")
|
74 |
+
interface.launch()
|
75 |
+
except Exception as e:
|
76 |
+
logger.error(f"Failed to launch Gradio interface: {str(e)}")
|
77 |
+
raise
|
78 |
+
|
79 |
+
if __name__ == "__main__":
|
80 |
+
main()
|
feed_to_llm.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.chat_models import ChatOpenAI
|
2 |
+
|
3 |
+
from langchain.schema import (
|
4 |
+
HumanMessage,
|
5 |
+
SystemMessage
|
6 |
+
)
|
7 |
+
import tiktoken
|
8 |
+
import re
|
9 |
+
|
10 |
+
|
11 |
+
def num_tokens_from_string(string: str, encoder) -> int:
|
12 |
+
num_tokens = len(encoder.encode(string))
|
13 |
+
return num_tokens
|
14 |
+
|
15 |
+
|
16 |
+
def feed_articles_to_gpt_with_links(information, question):
|
17 |
+
prompt = "The following pieces of information includes relevant articles. \nUse the following sentences to answer question. \nIf you don't know the answer, just say that you don't know, don't try to make up an answer. "
|
18 |
+
prompt += "Please state the number of the article used to answer the question after your response\n"
|
19 |
+
end_prompt = "\n----------------\n"
|
20 |
+
prompt += end_prompt
|
21 |
+
content = ""
|
22 |
+
seperator = "<<<<>>>>"
|
23 |
+
|
24 |
+
token_count = 0
|
25 |
+
encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
26 |
+
token_count += num_tokens_from_string(prompt, encoder)
|
27 |
+
|
28 |
+
articles = [contents for score, contents, uuids, titles, domains in information]
|
29 |
+
uuids = [uuids for score, contents, uuids, titles, domains in information]
|
30 |
+
domains = [domains for score, contents, uuids, titles, domains in information]
|
31 |
+
|
32 |
+
for i in range(len(articles)):
|
33 |
+
addition = "Article " + str(i + 1) + ": " + articles[i] + seperator
|
34 |
+
addition += articles[i] + seperator
|
35 |
+
token_count += num_tokens_from_string(addition, encoder)
|
36 |
+
if token_count > 3500:
|
37 |
+
print(i)
|
38 |
+
break
|
39 |
+
|
40 |
+
content += addition
|
41 |
+
|
42 |
+
prompt += content
|
43 |
+
llm = ChatOpenAI(temperature=0.0)
|
44 |
+
message = [
|
45 |
+
SystemMessage(content=prompt),
|
46 |
+
HumanMessage(content=question)
|
47 |
+
]
|
48 |
+
|
49 |
+
response = llm(message)
|
50 |
+
print(response.content)
|
51 |
+
print("response length: ", len(response.content))
|
52 |
+
|
53 |
+
answer_found_prompt = "Please check if the following response found the answer. If yes, return 1 and if no, return 0. \n"
|
54 |
+
message = [
|
55 |
+
SystemMessage(content=answer_found_prompt),
|
56 |
+
HumanMessage(content=response.content)
|
57 |
+
]
|
58 |
+
print(llm(message).content)
|
59 |
+
if llm(message).content == "0":
|
60 |
+
return "I could not find the answer.", [], [], []
|
61 |
+
|
62 |
+
# sources = "\n Sources: \n"
|
63 |
+
# for i in range(len(uuids)):
|
64 |
+
# link = "https://tobaccowatcher.globaltobaccocontrol.org/articles/" + uuids[i] + "/" + "\n"
|
65 |
+
# sources += link
|
66 |
+
# response.content += sources
|
67 |
+
|
68 |
+
lowercase_response = response.content.lower()
|
69 |
+
# remove parentheses
|
70 |
+
lowercase_response = re.sub('[()]', '', lowercase_response)
|
71 |
+
lowercase_split = lowercase_response.split()
|
72 |
+
used_article_num = []
|
73 |
+
for i in range(len(lowercase_split)):
|
74 |
+
if lowercase_split[i] == "article":
|
75 |
+
next_word = lowercase_split[i + 1]
|
76 |
+
# get rid of non-numenric characters
|
77 |
+
next_word = ''.join(c for c in next_word if c.isdigit())
|
78 |
+
print("Article number: ", next_word)
|
79 |
+
# append only if it is not present in the list
|
80 |
+
if next_word not in used_article_num:
|
81 |
+
used_article_num.append(next_word)
|
82 |
+
|
83 |
+
# if empty
|
84 |
+
print("Used article num: ", used_article_num)
|
85 |
+
if not used_article_num:
|
86 |
+
print("I could not find the answer. Reached")
|
87 |
+
return "I could not find the answer.", [], [], []
|
88 |
+
|
89 |
+
used_article_num = [int(num) - 1 for num in used_article_num]
|
90 |
+
|
91 |
+
links = [f"https://tobaccowatcher.globaltobaccocontrol.org/articles/{uuid}/" for uuid in uuids]
|
92 |
+
titles = [titles for score, contents, uuids, titles, domains in information]
|
93 |
+
|
94 |
+
links = [links[i] for i in used_article_num]
|
95 |
+
titles = [titles[i] for i in used_article_num]
|
96 |
+
domains = [domains[i] for i in used_article_num]
|
97 |
+
|
98 |
+
# get rid of substring that starts with (Article and ends with )
|
99 |
+
response_without_source = re.sub("""\(Article.*\)""", "", response.content)
|
100 |
+
|
101 |
+
return response_without_source, links, titles, domains
|
feed_to_llm_v2.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_openai import ChatOpenAI
|
2 |
+
|
3 |
+
from langchain.schema import (
|
4 |
+
HumanMessage,
|
5 |
+
SystemMessage
|
6 |
+
)
|
7 |
+
import tiktoken
|
8 |
+
import re
|
9 |
+
|
10 |
+
from get_articles import save_solr_articles_full
|
11 |
+
from rerank import crossencoder_rerank_answer
|
12 |
+
|
13 |
+
|
14 |
+
def num_tokens_from_string(string: str, encoder) -> int:
|
15 |
+
num_tokens = len(encoder.encode(string))
|
16 |
+
return num_tokens
|
17 |
+
|
18 |
+
|
19 |
+
def feed_articles_to_gpt_with_links(information, question):
|
20 |
+
prompt = """
|
21 |
+
You are a Question Answering system specializing in tobacco-related topics. You have access to several curated articles, each numbered (e.g., Article 1, Article 2). These articles cover various aspects of tobacco use, health effects, legislation, and quitting resources.
|
22 |
+
|
23 |
+
When formulating your response, adhere to the following guidelines:
|
24 |
+
|
25 |
+
1. Use information from the provided articles to directly answer the question. Explicitly reference the article(s) used in your response by stating the article number(s) (e.g., "According to Article 1, ..." or "Articles 2 and 3 mention that...").
|
26 |
+
2. If the answer is not covered by any of the articles, clearly state that the information is unavailable. Do not guess or fabricate information.
|
27 |
+
3. Avoid using ambiguous time references like 'recently' or 'last year.' Instead, use absolute terms based on the article's content (e.g., 'In 2021' or 'As per Article 2, published in 2020').
|
28 |
+
4. Keep responses concise, accurate, and helpful while maintaining a professional tone.
|
29 |
+
|
30 |
+
Below is a list of articles you can reference. Each article is identified by its number and content:
|
31 |
+
"""
|
32 |
+
end_prompt = "\n----------------\n"
|
33 |
+
prompt += end_prompt
|
34 |
+
|
35 |
+
content = ""
|
36 |
+
separator = "<<<<>>>>"
|
37 |
+
token_count = 0
|
38 |
+
|
39 |
+
# Encoder setup for token count tracking
|
40 |
+
encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
|
41 |
+
token_count += num_tokens_from_string(prompt, encoder)
|
42 |
+
|
43 |
+
# Add articles to the prompt
|
44 |
+
articles = [contents for score, contents, uuids, titles, domains in information]
|
45 |
+
uuids = [uuids for score, contents, uuids, titles, domains in information]
|
46 |
+
titles_list = [titles for score, contents, uuids, titles, domains in information]
|
47 |
+
domains_list = [domains for score, contents, uuids, titles, domains in information]
|
48 |
+
|
49 |
+
for i in range(len(articles)):
|
50 |
+
addition = f"Article {i + 1}: {articles[i]} {separator}"
|
51 |
+
token_count += num_tokens_from_string(addition, encoder)
|
52 |
+
if token_count > 3500:
|
53 |
+
break
|
54 |
+
content += addition
|
55 |
+
|
56 |
+
prompt += content
|
57 |
+
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
|
58 |
+
message = [
|
59 |
+
SystemMessage(content=prompt),
|
60 |
+
HumanMessage(content=question)
|
61 |
+
]
|
62 |
+
|
63 |
+
response = llm.invoke(message)
|
64 |
+
response_content = response.content # Access the content of the AIMessage
|
65 |
+
print("LLM Response Content:", response_content)
|
66 |
+
|
67 |
+
# Extract sources from the response content
|
68 |
+
matches = re.findall(r'\((Article \d+)\)', response_content)
|
69 |
+
if not matches:
|
70 |
+
print("No sources found in the response.")
|
71 |
+
return response_content, [], [], []
|
72 |
+
|
73 |
+
unique_matches = list(set(matches))
|
74 |
+
used_article_nums = [int(re.findall(r'\d+', match)[0]) - 1 for match in unique_matches]
|
75 |
+
|
76 |
+
# Create citation list
|
77 |
+
citations = []
|
78 |
+
for idx, num in enumerate(used_article_nums, start=1):
|
79 |
+
citation = f"{idx}. {titles_list[num]} ({domains_list[num]})"
|
80 |
+
citations.append(citation)
|
81 |
+
|
82 |
+
# Replace article numbers with citation numbers in response
|
83 |
+
for i, match in enumerate(unique_matches, start=1):
|
84 |
+
response_content = response_content.replace(match, f"[{i}]")
|
85 |
+
|
86 |
+
# Append citations to the response
|
87 |
+
response_with_citations = f"{response_content}\n\nReferences:\n" + "\n".join(citations)
|
88 |
+
|
89 |
+
# Prepare links with titles and domains
|
90 |
+
links = [f"https://tobaccowatcher.globaltobaccocontrol.org/articles/{uuid}/" for uuid in uuids]
|
91 |
+
hyperlinks = [f"<a href='{link}' target='_blank'>{titles_list[i]}</a> ({domains_list[i]})" for i, link in enumerate(links)]
|
92 |
+
|
93 |
+
return response_with_citations, hyperlinks, titles_list, domains_list
|
94 |
+
|
95 |
+
|
96 |
+
if __name__ == "__main__":
|
97 |
+
question = "How is United States fighting against tobacco addiction?"
|
98 |
+
rerank_type = "crossencoder"
|
99 |
+
llm_type = "chat"
|
100 |
+
csv_path = save_solr_articles_full(question, keyword_type="rake")
|
101 |
+
reranked_out = crossencoder_rerank_answer(csv_path, question)
|
102 |
+
feed_articles_to_gpt_with_links(reranked_out, question)
|
full_chain.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
from get_keywords import get_keywords
|
4 |
+
from get_articles import save_solr_articles_full
|
5 |
+
from rerank import langchain_rerank_answer, langchain_with_sources, crossencoder_rerank_answer, \
|
6 |
+
crossencoder_rerank_sentencewise, crossencoder_rerank_sentencewise_articles, no_rerank
|
7 |
+
#from feed_to_llm import feed_articles_to_gpt_with_links
|
8 |
+
from feed_to_llm_v2 import feed_articles_to_gpt_with_links
|
9 |
+
|
10 |
+
|
11 |
+
def get_response(question, rerank_type="crossencoder", llm_type="chat"):
|
12 |
+
|
13 |
+
try:
|
14 |
+
|
15 |
+
csv_path = save_solr_articles_full(question, keyword_type="rake")
|
16 |
+
|
17 |
+
reranked_out = crossencoder_rerank_answer(csv_path, question)
|
18 |
+
|
19 |
+
|
20 |
+
# Prepare source metadata for citations
|
21 |
+
citations = [
|
22 |
+
{"title": article["title"], "url": article["url"], "source": article["source"]}
|
23 |
+
for article in reranked_out
|
24 |
+
]
|
25 |
+
|
26 |
+
|
27 |
+
result = feed_articles_to_gpt_with_links(reranked_out, question, citations)
|
28 |
+
|
29 |
+
|
30 |
+
return result
|
31 |
+
except Exception as e:
|
32 |
+
return "", [], [], []
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == "__main__":
|
39 |
+
question = "How is United States fighting against tobacco addiction?"
|
40 |
+
rerank_type = "crossencoder"
|
41 |
+
llm_type = "chat"
|
42 |
+
response, links, titles, domains = get_response(question, rerank_type, llm_type)
|
43 |
+
print(response)
|
44 |
+
print(links)
|
45 |
+
print(titles)
|
46 |
+
print(domains)
|
get_articles.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pysolr import Solr
|
2 |
+
import os
|
3 |
+
import csv
|
4 |
+
from sentence_transformers import SentenceTransformer, util
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from get_keywords import get_keywords
|
8 |
+
import os
|
9 |
+
|
10 |
+
"""
|
11 |
+
This function creates top 15 articles from Solr and saves them in a csv file
|
12 |
+
Input:
|
13 |
+
query: str
|
14 |
+
num_articles: int
|
15 |
+
keyword_type: str (openai, rake, or na)
|
16 |
+
Output: path to csv file
|
17 |
+
"""
|
18 |
+
def save_solr_articles_full(query: str, num_articles=15, keyword_type="openai") -> str:
|
19 |
+
keywords = get_keywords(query, keyword_type)
|
20 |
+
if keyword_type == "na":
|
21 |
+
keywords = query
|
22 |
+
return save_solr_articles(keywords, num_articles)
|
23 |
+
|
24 |
+
|
25 |
+
"""
|
26 |
+
Removes spaces and newlines from text
|
27 |
+
Input: text: str
|
28 |
+
Output: text: str
|
29 |
+
"""
|
30 |
+
def remove_spaces_newlines(text: str) -> str:
|
31 |
+
text = text.replace('\n', ' ')
|
32 |
+
text = text.replace(' ', ' ')
|
33 |
+
return text
|
34 |
+
|
35 |
+
|
36 |
+
# truncates long articles to 1500 words
|
37 |
+
def truncate_article(text: str) -> str:
|
38 |
+
split = text.split()
|
39 |
+
if len(split) > 1500:
|
40 |
+
split = split[:1500]
|
41 |
+
text = ' '.join(split)
|
42 |
+
return text
|
43 |
+
|
44 |
+
|
45 |
+
"""
|
46 |
+
Searches Solr for articles based on keywords and saves them in a csv file
|
47 |
+
Input:
|
48 |
+
keywords: str
|
49 |
+
num_articles: int
|
50 |
+
Output: path to csv file
|
51 |
+
Minor details:
|
52 |
+
Removes duplicate articles to start with.
|
53 |
+
Articles with dead urls are removed since those articles are often wierd.
|
54 |
+
Articles with titles that start with five starting words are removed. they are usually duplicates with minor changes.
|
55 |
+
If one of title, uuid, cleaned_content, url are missing the article is skipped.
|
56 |
+
"""
|
57 |
+
def save_solr_articles(keywords: str, num_articles=15) -> str:
|
58 |
+
solr_key = os.getenv("SOLR_KEY")
|
59 |
+
SOLR_ARTICLES_URL = f"https://website:{solr_key}@solr.machines.globalhealthwatcher.org:8080/solr/articles/"
|
60 |
+
solr = Solr(SOLR_ARTICLES_URL, verify=False)
|
61 |
+
|
62 |
+
# No duplicates
|
63 |
+
fq = ['-dups:0']
|
64 |
+
|
65 |
+
query = f'text:({keywords})' + " AND " + "dead_url:(false)"
|
66 |
+
|
67 |
+
# Get top 2*num_articles articles and then remove misformed or duplicate articles
|
68 |
+
outputs = solr.search(query, fq=fq, sort="score desc", rows=num_articles * 2)
|
69 |
+
|
70 |
+
article_count = 0
|
71 |
+
|
72 |
+
save_path = os.path.join("data", "articles.csv")
|
73 |
+
if not os.path.exists(os.path.dirname(save_path)):
|
74 |
+
os.makedirs(os.path.dirname(save_path))
|
75 |
+
|
76 |
+
with open(save_path, 'w', newline='') as csvfile:
|
77 |
+
fieldnames = ['title', 'uuid', 'content', 'url', 'domain']
|
78 |
+
writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
|
79 |
+
writer.writeheader()
|
80 |
+
|
81 |
+
title_five_words = set()
|
82 |
+
|
83 |
+
for d in outputs.docs:
|
84 |
+
if article_count == num_articles:
|
85 |
+
break
|
86 |
+
|
87 |
+
# skip if title returns a keyerror
|
88 |
+
if 'title' not in d or 'uuid' not in d or 'cleaned_content' not in d or 'url' not in d:
|
89 |
+
continue
|
90 |
+
|
91 |
+
title_cleaned = remove_spaces_newlines(d['title'])
|
92 |
+
|
93 |
+
split = title_cleaned.split()
|
94 |
+
# skip if title is a duplicate
|
95 |
+
if not len(split) < 5:
|
96 |
+
five_words = title_cleaned.split()[:5]
|
97 |
+
five_words = ' '.join(five_words)
|
98 |
+
if five_words in title_five_words:
|
99 |
+
continue
|
100 |
+
title_five_words.add(five_words)
|
101 |
+
|
102 |
+
article_count += 1
|
103 |
+
|
104 |
+
cleaned_content = remove_spaces_newlines(d['cleaned_content'])
|
105 |
+
cleaned_content = truncate_article(cleaned_content)
|
106 |
+
|
107 |
+
domain = ""
|
108 |
+
if 'domain' not in d:
|
109 |
+
domain = "Not Specified"
|
110 |
+
else:
|
111 |
+
domain = d['domain']
|
112 |
+
print(domain)
|
113 |
+
|
114 |
+
writer.writerow({'title': title_cleaned, 'uuid': d['uuid'], 'content': cleaned_content, 'url': d['url'],
|
115 |
+
'domain': domain})
|
116 |
+
return save_path
|
117 |
+
|
118 |
+
|
119 |
+
def save_embedding_base_articles(query, article_embeddings, titles, contents, uuids, urls, num_articles=15):
|
120 |
+
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
|
121 |
+
query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
122 |
+
hits = util.semantic_search(query_embedding, article_embeddings, top_k=15)
|
123 |
+
hits = hits[0]
|
124 |
+
corpus_ids = [item['corpus_id'] for item in hits]
|
125 |
+
r_contents = [contents[idx] for idx in corpus_ids]
|
126 |
+
r_titles = [titles[idx] for idx in corpus_ids]
|
127 |
+
r_uuids = [uuids[idx] for idx in corpus_ids]
|
128 |
+
r_urls = [urls[idx] for idx in corpus_ids]
|
129 |
+
|
130 |
+
save_path = os.path.join("data", "articles.csv")
|
131 |
+
if not os.path.exists(os.path.dirname(save_path)):
|
132 |
+
os.makedirs(os.path.dirname(save_path))
|
133 |
+
|
134 |
+
with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
|
135 |
+
fieldNames = ['title', 'uuid', 'content', 'url']
|
136 |
+
writer = csv.DictWriter(csvfile, fieldnames=fieldNames, quoting=csv.QUOTE_NONNUMERIC)
|
137 |
+
writer.writeheader()
|
138 |
+
for i in range(num_articles):
|
139 |
+
writer.writerow({'title': r_titles[i], 'uuid': r_uuids[i], 'content': r_contents[i], 'url': r_urls[i]})
|
140 |
+
return save_path
|
get_keywords.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_openai import ChatOpenAI
|
2 |
+
from langchain_core.messages import (
|
3 |
+
HumanMessage,
|
4 |
+
SystemMessage
|
5 |
+
)
|
6 |
+
|
7 |
+
from rake_nltk import Rake
|
8 |
+
import nltk
|
9 |
+
nltk.download('stopwords')
|
10 |
+
nltk.download('punkt')
|
11 |
+
"""
|
12 |
+
This function takes in user query and returns keywords
|
13 |
+
Input:
|
14 |
+
user_query: str
|
15 |
+
keyword_type: str (openai, rake, or na)
|
16 |
+
If the keyword type is na, then user query is returned.
|
17 |
+
Output: keywords: str
|
18 |
+
"""
|
19 |
+
def get_keywords(user_query: str, keyword_type: str) -> str:
|
20 |
+
if keyword_type == "openai":
|
21 |
+
return get_keywords_openai(user_query)
|
22 |
+
if keyword_type == "rake":
|
23 |
+
return get_keywords_rake(user_query)
|
24 |
+
else:
|
25 |
+
return user_query
|
26 |
+
|
27 |
+
|
28 |
+
"""
|
29 |
+
This function takes user query and returns keywords using rake_nltk
|
30 |
+
rake_nltk actually returns keyphrases, not keywords. Since using keyphrases did not show improvement, we are using keywords
|
31 |
+
to match the output type of the other keyword functions.
|
32 |
+
Input:
|
33 |
+
user_query: str
|
34 |
+
Output: keywords: str
|
35 |
+
"""
|
36 |
+
def get_keywords_rake(user_query: str) -> str:
|
37 |
+
r = Rake()
|
38 |
+
r.extract_keywords_from_text(user_query)
|
39 |
+
keyphrases = r.get_ranked_phrases()
|
40 |
+
|
41 |
+
# If we want to get keyphrases, return keyphrases but should do keywords
|
42 |
+
out = ""
|
43 |
+
for phrase in keyphrases:
|
44 |
+
out += phrase + " "
|
45 |
+
return out
|
46 |
+
|
47 |
+
|
48 |
+
"""
|
49 |
+
This function takes user query and returns keywords using openai
|
50 |
+
Input:
|
51 |
+
user_query: str
|
52 |
+
Output: keywords: str
|
53 |
+
"""
|
54 |
+
def get_keywords_openai(user_query: str) -> str:
|
55 |
+
llm = ChatOpenAI(temperature=0.0)
|
56 |
+
command = "return the keywords of the following query. response should be words separated by commas. "
|
57 |
+
message = [
|
58 |
+
SystemMessage(content=command),
|
59 |
+
HumanMessage(content=user_query)
|
60 |
+
]
|
61 |
+
response = llm(message)
|
62 |
+
res = response.content.replace(",", "")
|
63 |
+
return res
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==4.25.0
|
2 |
+
langchain==0.1.14
|
3 |
+
langchain-core==0.1.40
|
4 |
+
langchain-openai==0.1.1
|
5 |
+
nltk==3.8.1
|
6 |
+
openai==1.16.2
|
7 |
+
pandas==2.2.1
|
8 |
+
pysolr==3.9.0
|
9 |
+
rake-nltk==1.0.6
|
10 |
+
sentence-transformers==2.2.2
|
11 |
+
tiktoken==0.5.2
|
12 |
+
torch==2.1.2
|
13 |
+
huggingface-hub==0.20.2
|
14 |
+
python-dotenv==1.0.1
|
15 |
+
docarray==0.40.0
|
rerank.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# reranks the top articles from a given csv file
|
2 |
+
from langchain_openai import ChatOpenAI
|
3 |
+
from langchain.chains import RetrievalQA
|
4 |
+
from langchain_community.document_loaders.csv_loader import CSVLoader
|
5 |
+
from langchain_community.vectorstores import DocArrayInMemorySearch
|
6 |
+
from sentence_transformers import CrossEncoder
|
7 |
+
import pandas as pd
|
8 |
+
import time
|
9 |
+
|
10 |
+
"""
|
11 |
+
This function rerank top articles (15 -> 4) from a given csv, then sends to LLM
|
12 |
+
Input:
|
13 |
+
csv_path: str
|
14 |
+
question: str
|
15 |
+
top_n: int
|
16 |
+
Output:
|
17 |
+
response: str
|
18 |
+
links: list of str
|
19 |
+
titles: list of str
|
20 |
+
|
21 |
+
Other functions in this file does not send articles to LLM. This is an exception.
|
22 |
+
Created using langchain RAG functions. Deprecated.
|
23 |
+
Update: Use langchain_RAG instead.
|
24 |
+
"""
|
25 |
+
|
26 |
+
|
27 |
+
def langchain_rerank_answer(csv_path, question, source='url', top_n=4):
|
28 |
+
llm = ChatOpenAI(temperature=0.0)
|
29 |
+
loader = CSVLoader(csv_path, source_column="url")
|
30 |
+
|
31 |
+
index = VectorstoreIndexCreator(
|
32 |
+
vectorstore_cls=DocArrayInMemorySearch,
|
33 |
+
).from_loaders([loader])
|
34 |
+
|
35 |
+
# prompt_template = """You are an a chatbot that answers tobacco related questions with source. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
36 |
+
# {context}
|
37 |
+
# Question: {question}"""
|
38 |
+
# PROMPT = PromptTemplate(
|
39 |
+
# template=prompt_template, input_variables=["context", "question"]
|
40 |
+
# )
|
41 |
+
# chain_type_kwargs = {"prompt": PROMPT}
|
42 |
+
|
43 |
+
qa = RetrievalQA.from_chain_type(
|
44 |
+
llm=llm,
|
45 |
+
chain_type="stuff",
|
46 |
+
retriever=index.vectorstore.as_retriever(),
|
47 |
+
verbose=False,
|
48 |
+
return_source_documents=True,
|
49 |
+
# chain_type_kwargs=chain_type_kwargs,
|
50 |
+
# chain_type_kwargs = {
|
51 |
+
# "document_separator": "<<<<>>>>>"
|
52 |
+
# },
|
53 |
+
)
|
54 |
+
|
55 |
+
answer = qa({"query": question})
|
56 |
+
sources = answer['source_documents']
|
57 |
+
sources_out = [source.metadata['source'] for source in sources]
|
58 |
+
|
59 |
+
return answer['result'], sources_out
|
60 |
+
|
61 |
+
|
62 |
+
"""
|
63 |
+
Langchain with sources.
|
64 |
+
This function is deprecated. Use langchain_RAG instead.
|
65 |
+
"""
|
66 |
+
|
67 |
+
|
68 |
+
def langchain_with_sources(csv_path, question, top_n=4):
|
69 |
+
llm = ChatOpenAI(temperature=0.0)
|
70 |
+
loader = CSVLoader(csv_path, source_column="uuid")
|
71 |
+
index = VectorstoreIndexCreator(
|
72 |
+
vectorstore_cls=DocArrayInMemorySearch,
|
73 |
+
).from_loaders([loader])
|
74 |
+
|
75 |
+
qa = RetrievalQAWithSourcesChain.from_chain_type(
|
76 |
+
llm=llm,
|
77 |
+
chain_type="stuff",
|
78 |
+
retriever=index.vectorstore.as_retriever(),
|
79 |
+
)
|
80 |
+
output = qa({"question": question}, return_only_outputs=True)
|
81 |
+
return output['answer'], output['sources']
|
82 |
+
|
83 |
+
|
84 |
+
"""
|
85 |
+
Reranks the top articles using crossencoder.
|
86 |
+
Uses cross-encoder/ms-marco-MiniLM-L-6-v2 for embedding / reranking.
|
87 |
+
Input:
|
88 |
+
csv_path: str
|
89 |
+
question: str
|
90 |
+
top_n: int
|
91 |
+
Output:
|
92 |
+
out_values: list of [content, uuid, title]
|
93 |
+
"""
|
94 |
+
|
95 |
+
|
96 |
+
# returns list of top n similar articles using crossencoder
|
97 |
+
def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
|
98 |
+
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
99 |
+
articles = pd.read_csv(csv_path)
|
100 |
+
contents = articles['content'].tolist()
|
101 |
+
uuids = articles['uuid'].tolist()
|
102 |
+
titles = articles['title'].tolist()
|
103 |
+
|
104 |
+
# biencoder retrieval does not have domain
|
105 |
+
if 'domain' not in articles:
|
106 |
+
domain = [""] * len(contents)
|
107 |
+
else:
|
108 |
+
domain = articles['domain'].tolist()
|
109 |
+
|
110 |
+
cross_inp = [[question, content] for content in contents]
|
111 |
+
cross_scores = cross_encoder.predict(cross_inp)
|
112 |
+
scores_sentences = list(zip(cross_scores, contents, uuids, titles, domain))
|
113 |
+
scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
|
114 |
+
|
115 |
+
out_values = scores_sentences[:top_n]
|
116 |
+
|
117 |
+
# if score is less than 0, truncate
|
118 |
+
for idx in range(len(out_values)):
|
119 |
+
if out_values[idx][0] < 0:
|
120 |
+
out_values = out_values[:idx]
|
121 |
+
if len(out_values) == 0:
|
122 |
+
out_values = scores_sentences[:1]
|
123 |
+
|
124 |
+
break
|
125 |
+
# print(out_values)
|
126 |
+
return out_values
|
127 |
+
|
128 |
+
|
129 |
+
def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) -> list:
|
130 |
+
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
131 |
+
articles = pd.read_csv(csv_path)
|
132 |
+
contents = articles['content'].tolist()
|
133 |
+
uuids = articles['uuid'].tolist()
|
134 |
+
titles = articles['title'].tolist()
|
135 |
+
|
136 |
+
if 'domain' not in articles:
|
137 |
+
domain = [""] * len(contents)
|
138 |
+
else:
|
139 |
+
domain = articles['domain'].tolist()
|
140 |
+
|
141 |
+
sentences = []
|
142 |
+
new_uuids = []
|
143 |
+
new_titles = []
|
144 |
+
new_domains = []
|
145 |
+
for idx in range(len(contents)):
|
146 |
+
sents = sent_tokenize(contents[idx])
|
147 |
+
sentences.extend(sents)
|
148 |
+
new_uuids.extend([uuids[idx]] * len(sents))
|
149 |
+
new_titles.extend([titles[idx]] * len(sents))
|
150 |
+
new_domains.extend([domain[idx]] * len(sents))
|
151 |
+
|
152 |
+
cross_inp = [[question, sent] for sent in sentences]
|
153 |
+
cross_scores = cross_encoder.predict(cross_inp)
|
154 |
+
scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains))
|
155 |
+
scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
|
156 |
+
|
157 |
+
out_values = scores_sentences[:top_n]
|
158 |
+
|
159 |
+
# if score is less than 0, truncate
|
160 |
+
for idx in range(len(out_values)):
|
161 |
+
if out_values[idx][0] < 0:
|
162 |
+
out_values = out_values[:idx]
|
163 |
+
if len(out_values) == 0:
|
164 |
+
out_values = scores_sentences[:1]
|
165 |
+
|
166 |
+
break
|
167 |
+
|
168 |
+
return out_values
|
169 |
+
|
170 |
+
|
171 |
+
def crossencoder_rerank_sentencewise_sentence_chunks(csv_path, question, top_n=10, chunk_size=2):
|
172 |
+
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
173 |
+
articles = pd.read_csv(csv_path)
|
174 |
+
contents = articles['content'].tolist()
|
175 |
+
uuids = articles['uuid'].tolist()
|
176 |
+
titles = articles['title'].tolist()
|
177 |
+
|
178 |
+
# embeddings do not have domain as column
|
179 |
+
if 'domain' not in articles:
|
180 |
+
domain = [""] * len(contents)
|
181 |
+
else:
|
182 |
+
domain = articles['domain'].tolist()
|
183 |
+
|
184 |
+
sentences = []
|
185 |
+
new_uuids = []
|
186 |
+
new_titles = []
|
187 |
+
new_domains = []
|
188 |
+
|
189 |
+
for idx in range(len(contents)):
|
190 |
+
sents = sent_tokenize(contents[idx])
|
191 |
+
sents_merged = []
|
192 |
+
|
193 |
+
# if the number of sentences is less than chunk size, merge and join
|
194 |
+
if len(sents) < chunk_size:
|
195 |
+
sents_merged.append(' '.join(sents))
|
196 |
+
else:
|
197 |
+
for i in range(0, len(sents) - chunk_size + 1):
|
198 |
+
sents_merged.append(' '.join(sents[i:i + chunk_size]))
|
199 |
+
|
200 |
+
sentences.extend(sents_merged)
|
201 |
+
new_uuids.extend([uuids[idx]] * len(sents_merged))
|
202 |
+
new_titles.extend([titles[idx]] * len(sents_merged))
|
203 |
+
new_domains.extend([domain[idx]] * len(sents_merged))
|
204 |
+
|
205 |
+
cross_inp = [[question, sent] for sent in sentences]
|
206 |
+
cross_scores = cross_encoder.predict(cross_inp)
|
207 |
+
scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains))
|
208 |
+
scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
|
209 |
+
|
210 |
+
out_values = scores_sentences[:top_n]
|
211 |
+
|
212 |
+
for idx in range(len(out_values)):
|
213 |
+
if out_values[idx][0] < 0:
|
214 |
+
out_values = out_values[:idx]
|
215 |
+
if len(out_values) == 0:
|
216 |
+
out_values = scores_sentences[:1]
|
217 |
+
|
218 |
+
break
|
219 |
+
|
220 |
+
return out_values
|
221 |
+
|
222 |
+
|
223 |
+
def crossencoder_rerank_sentencewise_articles(csv_path, question, top_n=4):
|
224 |
+
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
225 |
+
contents, uuids, titles, domain = load_articles(csv_path)
|
226 |
+
|
227 |
+
sentences = []
|
228 |
+
contents_elongated = []
|
229 |
+
new_uuids = []
|
230 |
+
new_titles = []
|
231 |
+
new_domains = []
|
232 |
+
|
233 |
+
for idx in range(len(contents)):
|
234 |
+
sents = sent_tokenize(contents[idx])
|
235 |
+
sentences.extend(sents)
|
236 |
+
new_uuids.extend([uuids[idx]] * len(sents))
|
237 |
+
contents_elongated.extend([contents[idx]] * len(sents))
|
238 |
+
new_titles.extend([titles[idx]] * len(sents))
|
239 |
+
new_domains.extend([domain[idx]] * len(sents))
|
240 |
+
|
241 |
+
cross_inp = [[question, sent] for sent in sentences]
|
242 |
+
cross_scores = cross_encoder.predict(cross_inp)
|
243 |
+
scores_sentences = list(zip(cross_scores, contents_elongated, new_uuids, new_titles, new_domains))
|
244 |
+
scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
|
245 |
+
|
246 |
+
score_sentences_compressed = []
|
247 |
+
for item in scores_sentences:
|
248 |
+
if not score_sentences_compressed:
|
249 |
+
score_sentences_compressed.append(item)
|
250 |
+
else:
|
251 |
+
if item[2] not in [x[2] for x in score_sentences_compressed]:
|
252 |
+
score_sentences_compressed.append(item)
|
253 |
+
|
254 |
+
scores_sentences = score_sentences_compressed
|
255 |
+
return scores_sentences[:top_n]
|
256 |
+
|
257 |
+
|
258 |
+
def no_rerank(csv_path, question, top_n=4):
|
259 |
+
contents, uuids, titles, domains = load_articles(csv_path)
|
260 |
+
return list(zip(contents, uuids, titles, domains))[:top_n]
|
261 |
+
|
262 |
+
|
263 |
+
def load_articles(csv_path:str):
|
264 |
+
articles = pd.read_csv(csv_path)
|
265 |
+
contents = articles['content'].tolist()
|
266 |
+
uuids = articles['uuid'].tolist()
|
267 |
+
titles = articles['title'].tolist()
|
268 |
+
if 'domain' not in articles:
|
269 |
+
domain = [""] * len(contents)
|
270 |
+
else:
|
271 |
+
domain = articles['domain'].tolist()
|
272 |
+
return contents, uuids, titles, domain
|
273 |
+
|