vtiyyal1 commited on
Commit
12cca3e
1 Parent(s): fb70715

Upload 10 files

Browse files

Latest code to verify the citations fix

Files changed (10) hide show
  1. .gitattributes +35 -35
  2. README.md +13 -13
  3. app.py +80 -0
  4. feed_to_llm.py +101 -0
  5. feed_to_llm_v2.py +102 -0
  6. full_chain.py +46 -0
  7. get_articles.py +140 -0
  8. get_keywords.py +63 -0
  9. requirements.txt +15 -0
  10. rerank.py +273 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
1
- ---
2
- title: Tobacco Watcher Chat With Citations
3
- emoji: 👁
4
- colorFrom: yellow
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.6.0
8
- app_file: app.py
9
- pinned: false
10
- short_description: https://tobaccowatcher.globaltobaccocontrol.org/
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Tobacco Watcher Chat
3
+ emoji: 🐨
4
+ colorFrom: indigo
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.25.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+ import gradio as gr
3
+ from full_chain import get_response
4
+ import os
5
+ import logging
6
+
7
+ # Configure logging
8
+ logging.basicConfig(
9
+ level=logging.INFO,
10
+ format='%(asctime)s - %(levelname)s - %(message)s',
11
+ handlers=[
12
+ logging.FileHandler('app.log'),
13
+ logging.StreamHandler()
14
+ ]
15
+ )
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Initialize OpenAI client
19
+ try:
20
+ api_key = os.getenv("OPENAI_API_KEY")
21
+ if not api_key:
22
+ raise ValueError("OPENAI_API_KEY environment variable not set")
23
+ client = openai.OpenAI(api_key=api_key)
24
+ logger.info("OpenAI client initialized successfully")
25
+ except Exception as e:
26
+ logger.error(f"Failed to initialize OpenAI client: {str(e)}")
27
+ raise
28
+
29
def create_hyperlink(url, title, domain):
    """Return an HTML anchor for *url* labelled with *title*, followed by *domain*.

    The title and domain come from article metadata and may contain quotes or
    angle brackets; they are HTML-escaped so they cannot break the markup that
    Gradio renders. The URL is escaped with quote=True since it sits inside a
    quoted attribute.
    """
    import html  # local import keeps the module prelude unchanged
    safe_url = html.escape(url, quote=True)
    safe_title = html.escape(title)
    safe_domain = html.escape(domain)
    return f"<a href='{safe_url}'>{safe_title}</a> ({safe_domain})"
32
+
33
def predict(message, history):
    """Answer one chat message and append hyperlinked sources.

    ``history`` is supplied by gr.ChatInterface but is not consulted; every
    query is answered statelessly. On failure, a user-facing error string is
    returned instead of raising.
    """
    try:
        logger.info(f"Processing new query: {message}")

        # Run the retrieval + rerank + LLM pipeline.
        answer, urls, names, sites = get_response(message, rerank_type="crossencoder")
        logger.info(f"Received response with {len(urls)} sources")

        # Render one anchor per retrieved source.
        anchors = []
        for url, name, site in zip(urls, names, sites):
            anchors.append(create_hyperlink(url, name, site))

        result = answer + "\n" + "\n".join(anchors)

        logger.info("Response generated successfully")
        return result

    except Exception as e:
        error_msg = f"Error processing query: {str(e)}"
        logger.error(error_msg)
        return f"An error occurred while processing your request: {str(e)}"
56
+
57
# Canned questions shown under the chat box.
EXAMPLE_QUERIES = [
    "How many Americans Smoke?",
    "What are some measures taken by the Indian Government to reduce the smoking population?",
    "Does smoking negatively affect my health?"
]


def main():
    """Build the Gradio chat UI and serve it; re-raise on launch failure."""
    try:
        chat_ui = gr.ChatInterface(
            predict,
            examples=EXAMPLE_QUERIES,
            title="Tobacco Information Assistant",
            description="Ask questions about tobacco-related topics and get answers with reliable sources."
        )
        logger.info("Starting Gradio interface")
        chat_ui.launch()
    except Exception as e:
        logger.error(f"Failed to launch Gradio interface: {str(e)}")
        raise


if __name__ == "__main__":
    main()
feed_to_llm.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chat_models import ChatOpenAI
2
+
3
+ from langchain.schema import (
4
+ HumanMessage,
5
+ SystemMessage
6
+ )
7
+ import tiktoken
8
+ import re
9
+
10
+
11
def num_tokens_from_string(string: str, encoder) -> int:
    """Return how many tokens *encoder* produces for *string*."""
    return len(encoder.encode(string))
14
+
15
+
16
def feed_articles_to_gpt_with_links(information, question):
    """Answer *question* with GPT using reranked article snippets.

    Parameters
    ----------
    information : list of (score, content, uuid, title, domain) tuples, as
        produced by the rerank helpers.
    question : str

    Returns
    -------
    (response_text, links, titles, domains) — the lists are empty when the
    model could not find an answer in the provided articles.
    """
    prompt = "The following pieces of information includes relevant articles. \nUse the following sentences to answer question. \nIf you don't know the answer, just say that you don't know, don't try to make up an answer. "
    prompt += "Please state the number of the article used to answer the question after your response\n"
    end_prompt = "\n----------------\n"
    prompt += end_prompt
    content = ""
    seperator = "<<<<>>>>"

    token_count = 0
    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
    token_count += num_tokens_from_string(prompt, encoder)

    articles = [contents for score, contents, uuids, titles, domains in information]
    uuids = [uuids for score, contents, uuids, titles, domains in information]
    titles = [titles for score, contents, uuids, titles, domains in information]
    domains = [domains for score, contents, uuids, titles, domains in information]

    for i in range(len(articles)):
        # BUG FIX: the article body was previously concatenated twice per
        # addition, doubling its token cost and halving the budget.
        addition = "Article " + str(i + 1) + ": " + articles[i] + seperator
        token_count += num_tokens_from_string(addition, encoder)
        if token_count > 3500:
            # Token budget exhausted; drop the remaining articles.
            print(i)
            break

        content += addition

    prompt += content
    llm = ChatOpenAI(temperature=0.0)
    message = [
        SystemMessage(content=prompt),
        HumanMessage(content=question)
    ]

    response = llm(message)
    print(response.content)
    print("response length: ", len(response.content))

    # Ask the model once whether an answer was found; the previous code made
    # two separate calls whose results could disagree.
    answer_found_prompt = "Please check if the following response found the answer. If yes, return 1 and if no, return 0. \n"
    check_message = [
        SystemMessage(content=answer_found_prompt),
        HumanMessage(content=response.content)
    ]
    verdict = llm(check_message).content
    print(verdict)
    if verdict == "0":
        return "I could not find the answer.", [], [], []

    # Parse "article N" references out of the response text.
    lowercase_response = response.content.lower()
    # remove parentheses
    lowercase_response = re.sub('[()]', '', lowercase_response)
    lowercase_split = lowercase_response.split()
    used_article_num = []
    # Stop one word early so "article" as the final word cannot index past
    # the end of the list.
    for i in range(len(lowercase_split) - 1):
        if lowercase_split[i] == "article":
            # keep only digits; skip mentions not followed by a number
            digits = ''.join(c for c in lowercase_split[i + 1] if c.isdigit())
            print("Article number: ", digits)
            if digits and digits not in used_article_num:
                used_article_num.append(digits)

    print("Used article num: ", used_article_num)
    if not used_article_num:
        print("I could not find the answer. Reached")
        return "I could not find the answer.", [], [], []

    indices = [int(num) - 1 for num in used_article_num]
    # Guard against hallucinated article numbers outside the provided range,
    # which previously raised IndexError.
    indices = [idx for idx in indices if 0 <= idx < len(uuids)]
    if not indices:
        return "I could not find the answer.", [], [], []

    all_links = [f"https://tobaccowatcher.globaltobaccocontrol.org/articles/{uuid}/" for uuid in uuids]

    links = [all_links[idx] for idx in indices]
    out_titles = [titles[idx] for idx in indices]
    out_domains = [domains[idx] for idx in indices]

    # get rid of substring that starts with (Article and ends with )
    response_without_source = re.sub(r"\(Article.*\)", "", response.content)

    return response_without_source, links, out_titles, out_domains
feed_to_llm_v2.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai import ChatOpenAI
2
+
3
+ from langchain.schema import (
4
+ HumanMessage,
5
+ SystemMessage
6
+ )
7
+ import tiktoken
8
+ import re
9
+
10
+ from get_articles import save_solr_articles_full
11
+ from rerank import crossencoder_rerank_answer
12
+
13
+
14
def num_tokens_from_string(string: str, encoder) -> int:
    """Count the tokens *encoder* yields for *string*."""
    tokens = encoder.encode(string)
    return len(tokens)
17
+
18
+
19
def feed_articles_to_gpt_with_links(information, question):
    """Answer *question* with GPT-4o-mini and attach numbered citations.

    Parameters
    ----------
    information : list of (score, content, uuid, title, domain) tuples, as
        produced by the rerank helpers.
    question : str

    Returns
    -------
    (response_with_citations, hyperlinks, titles, domains). The hyperlink /
    title / domain lists cover all provided articles; the citation list
    appended to the response covers only the articles the model cited.
    """
    prompt = """
    You are a Question Answering system specializing in tobacco-related topics. You have access to several curated articles, each numbered (e.g., Article 1, Article 2). These articles cover various aspects of tobacco use, health effects, legislation, and quitting resources.

    When formulating your response, adhere to the following guidelines:

    1. Use information from the provided articles to directly answer the question. Explicitly reference the article(s) used in your response by stating the article number(s) (e.g., "According to Article 1, ..." or "Articles 2 and 3 mention that...").
    2. If the answer is not covered by any of the articles, clearly state that the information is unavailable. Do not guess or fabricate information.
    3. Avoid using ambiguous time references like 'recently' or 'last year.' Instead, use absolute terms based on the article's content (e.g., 'In 2021' or 'As per Article 2, published in 2020').
    4. Keep responses concise, accurate, and helpful while maintaining a professional tone.

    Below is a list of articles you can reference. Each article is identified by its number and content:
    """
    end_prompt = "\n----------------\n"
    prompt += end_prompt

    content = ""
    separator = "<<<<>>>>"
    token_count = 0

    # Encoder setup for token count tracking
    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
    token_count += num_tokens_from_string(prompt, encoder)

    # Add articles to the prompt until the token budget is exhausted.
    articles = [contents for score, contents, uuids, titles, domains in information]
    uuids = [uuids for score, contents, uuids, titles, domains in information]
    titles_list = [titles for score, contents, uuids, titles, domains in information]
    domains_list = [domains for score, contents, uuids, titles, domains in information]

    for i in range(len(articles)):
        addition = f"Article {i + 1}: {articles[i]} {separator}"
        token_count += num_tokens_from_string(addition, encoder)
        if token_count > 3500:
            break
        content += addition

    prompt += content
    llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
    message = [
        SystemMessage(content=prompt),
        HumanMessage(content=question)
    ]

    response = llm.invoke(message)
    response_content = response.content  # Access the content of the AIMessage
    print("LLM Response Content:", response_content)

    # BUG FIX: the system prompt instructs citations like "According to
    # Article 1, ..." (no parentheses), but the old regex only matched
    # "(Article N)", so most citations were silently dropped. Match both.
    matches = re.findall(r'\(?(Article \d+)\)?', response_content)
    if not matches:
        print("No sources found in the response.")
        return response_content, [], [], []

    # Deterministic first-mention order (set() previously produced arbitrary,
    # run-to-run-varying citation numbering). Also drop hallucinated article
    # numbers outside the provided range, which raised IndexError before.
    ordered_unique = list(dict.fromkeys(matches))
    valid = []
    for m in ordered_unique:
        art_idx = int(m.split()[1]) - 1
        if 0 <= art_idx < len(titles_list):
            valid.append((m, art_idx))
    if not valid:
        print("No sources found in the response.")
        return response_content, [], [], []

    # Create citation list
    citations = [f"{pos}. {titles_list[art_idx]} ({domains_list[art_idx]})"
                 for pos, (m, art_idx) in enumerate(valid, start=1)]

    # Replace article mentions with citation numbers. The negative lookahead
    # stops "Article 1" from clobbering part of "Article 12" (the old
    # str.replace corrupted higher-numbered citations).
    for pos, (m, art_idx) in enumerate(valid, start=1):
        response_content = re.sub(r'\(?' + re.escape(m) + r'\)?(?!\d)',
                                  f"[{pos}]", response_content)

    # Append citations to the response
    response_with_citations = f"{response_content}\n\nReferences:\n" + "\n".join(citations)

    # Prepare links with titles and domains for all provided articles.
    links = [f"https://tobaccowatcher.globaltobaccocontrol.org/articles/{uuid}/" for uuid in uuids]
    hyperlinks = [f"<a href='{link}' target='_blank'>{titles_list[i]}</a> ({domains_list[i]})" for i, link in enumerate(links)]

    return response_with_citations, hyperlinks, titles_list, domains_list
94
+
95
+
96
if __name__ == "__main__":
    # Manual end-to-end smoke test of retrieve -> rerank -> answer.
    question = "How is United States fighting against tobacco addiction?"
    rerank_type = "crossencoder"
    llm_type = "chat"
    articles_csv = save_solr_articles_full(question, keyword_type="rake")
    reranked = crossencoder_rerank_answer(articles_csv, question)
    feed_articles_to_gpt_with_links(reranked, question)
full_chain.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from get_keywords import get_keywords
4
+ from get_articles import save_solr_articles_full
5
+ from rerank import langchain_rerank_answer, langchain_with_sources, crossencoder_rerank_answer, \
6
+ crossencoder_rerank_sentencewise, crossencoder_rerank_sentencewise_articles, no_rerank
7
+ #from feed_to_llm import feed_articles_to_gpt_with_links
8
+ from feed_to_llm_v2 import feed_articles_to_gpt_with_links
9
+
10
+
11
def get_response(question, rerank_type="crossencoder", llm_type="chat"):
    """Run the full retrieval -> rerank -> LLM pipeline for *question*.

    Parameters
    ----------
    question : str
        The user's natural-language question.
    rerank_type, llm_type : str
        Kept for interface compatibility; the cross-encoder reranker and
        chat model are currently always used.

    Returns
    -------
    (response, links, titles, domains); all empty on failure.
    """
    try:
        # Retrieve candidate articles from Solr into a csv file.
        csv_path = save_solr_articles_full(question, keyword_type="rake")

        # Rerank to the top few articles; each entry is a
        # (score, content, uuid, title, domain) tuple.
        reranked_out = crossencoder_rerank_answer(csv_path, question)

        # BUG FIX: feed_articles_to_gpt_with_links takes (information,
        # question). The old code passed a third "citations" argument built
        # with dict-style access on tuples, so every call raised TypeError
        # and fell through to the silent except below.
        return feed_articles_to_gpt_with_links(reranked_out, question)
    except Exception as e:
        # Surface the failure instead of swallowing it silently.
        print(f"get_response failed: {e}")
        return "", [], [], []
33
+
34
+
35
+
36
+
37
+
38
if __name__ == "__main__":
    # Manual smoke test of the full pipeline.
    question = "How is United States fighting against tobacco addiction?"
    rerank_type = "crossencoder"
    llm_type = "chat"
    response, links, titles, domains = get_response(question, rerank_type, llm_type)
    for part in (response, links, titles, domains):
        print(part)
get_articles.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pysolr import Solr
2
+ import os
3
+ import csv
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import torch
6
+
7
+ from get_keywords import get_keywords
8
+ import os
9
+
10
def save_solr_articles_full(query: str, num_articles=15, keyword_type="openai") -> str:
    """Fetch the top articles from Solr for *query* and save them to a csv.

    Input:
        query: str — the user's question.
        num_articles: int — maximum number of articles to keep.
        keyword_type: str — "openai", "rake", or "na" (use the raw query).
    Output: path to the written csv file.
    """
    if keyword_type == "na":
        search_terms = query
    else:
        search_terms = get_keywords(query, keyword_type)
    return save_solr_articles(search_terms, num_articles)
23
+
24
+
25
def remove_spaces_newlines(text: str) -> str:
    """Flatten newlines to spaces and collapse double spaces (single pass).

    Input: text: str
    Output: text: str
    """
    return text.replace('\n', ' ').replace('  ', ' ')
34
+
35
+
36
# truncates long articles to 1500 words
def truncate_article(text: str) -> str:
    """Cap *text* at 1500 whitespace-delimited words."""
    words = text.split()
    if len(words) > 1500:
        text = ' '.join(words[:1500])
    return text
43
+
44
+
45
"""
Searches Solr for articles based on keywords and saves them in a csv file
Input:
    keywords: str
    num_articles: int
Output: path to csv file
Minor details:
    Removes duplicate articles to start with.
    Articles with dead urls are removed since those articles are often weird.
    Articles whose titles share the same five starting words are removed; they are usually duplicates with minor changes.
    If one of title, uuid, cleaned_content, url are missing the article is skipped.
"""
def save_solr_articles(keywords: str, num_articles=15) -> str:
    # Credentials are embedded in the Solr URL; SOLR_KEY comes from the env.
    solr_key = os.getenv("SOLR_KEY")
    SOLR_ARTICLES_URL = f"https://website:{solr_key}@solr.machines.globalhealthwatcher.org:8080/solr/articles/"
    solr = Solr(SOLR_ARTICLES_URL, verify=False)

    # No duplicates
    fq = ['-dups:0']

    query = f'text:({keywords})' + " AND " + "dead_url:(false)"

    # Get top 2*num_articles articles and then remove misformed or duplicate articles
    outputs = solr.search(query, fq=fq, sort="score desc", rows=num_articles * 2)

    article_count = 0

    save_path = os.path.join("data", "articles.csv")
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    with open(save_path, 'w', newline='') as csvfile:
        fieldnames = ['title', 'uuid', 'content', 'url', 'domain']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()

        # First five words of every accepted title, used to detect near-duplicates.
        title_five_words = set()

        for d in outputs.docs:
            if article_count == num_articles:
                break

            # skip if any required field is missing (indexing would raise KeyError)
            if 'title' not in d or 'uuid' not in d or 'cleaned_content' not in d or 'url' not in d:
                continue

            title_cleaned = remove_spaces_newlines(d['title'])

            split = title_cleaned.split()
            # skip if title is a near-duplicate (same first five words as an
            # already accepted article); titles shorter than 5 words are never
            # checked or recorded.
            if not len(split) < 5:
                five_words = title_cleaned.split()[:5]
                five_words = ' '.join(five_words)
                if five_words in title_five_words:
                    continue
                title_five_words.add(five_words)

            article_count += 1

            cleaned_content = remove_spaces_newlines(d['cleaned_content'])
            cleaned_content = truncate_article(cleaned_content)

            # 'domain' is optional in the Solr schema.
            domain = ""
            if 'domain' not in d:
                domain = "Not Specified"
            else:
                domain = d['domain']
            print(domain)

            writer.writerow({'title': title_cleaned, 'uuid': d['uuid'], 'content': cleaned_content, 'url': d['url'],
                             'domain': domain})
    return save_path
117
+
118
+
119
def save_embedding_base_articles(query, article_embeddings, titles, contents, uuids, urls, num_articles=15):
    """Bi-encoder retrieval: rank pre-computed article embeddings against
    *query* and write the top ``num_articles`` rows to data/articles.csv.

    Note: unlike save_solr_articles, the written csv has no 'domain' column;
    downstream rerankers substitute empty domains in that case.
    """
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # top_k is fixed at 15; semantic_search returns one hit list per query.
    hits = util.semantic_search(query_embedding, article_embeddings, top_k=15)
    hits = hits[0]
    corpus_ids = [item['corpus_id'] for item in hits]
    r_contents = [contents[idx] for idx in corpus_ids]
    r_titles = [titles[idx] for idx in corpus_ids]
    r_uuids = [uuids[idx] for idx in corpus_ids]
    r_urls = [urls[idx] for idx in corpus_ids]

    save_path = os.path.join("data", "articles.csv")
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))

    with open(save_path, 'w', newline='', encoding="utf-8") as csvfile:
        fieldNames = ['title', 'uuid', 'content', 'url']
        writer = csv.DictWriter(csvfile, fieldnames=fieldNames, quoting=csv.QUOTE_NONNUMERIC)
        writer.writeheader()
        for i in range(num_articles):
            writer.writerow({'title': r_titles[i], 'uuid': r_uuids[i], 'content': r_contents[i], 'url': r_urls[i]})
    return save_path
get_keywords.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_openai import ChatOpenAI
2
+ from langchain_core.messages import (
3
+ HumanMessage,
4
+ SystemMessage
5
+ )
6
+
7
+ from rake_nltk import Rake
8
+ import nltk
9
+ nltk.download('stopwords')
10
+ nltk.download('punkt')
11
def get_keywords(user_query: str, keyword_type: str) -> str:
    """Extract search keywords from *user_query*.

    Input:
        user_query: str
        keyword_type: str — "openai", "rake", or anything else (e.g. "na")
        to return the query unchanged.
    Output: keywords: str
    """
    if keyword_type == "openai":
        return get_keywords_openai(user_query)
    if keyword_type == "rake":
        return get_keywords_rake(user_query)
    return user_query
26
+
27
+
28
def get_keywords_rake(user_query: str) -> str:
    """Extract key phrases from *user_query* with rake_nltk.

    rake_nltk actually returns ranked keyphrases rather than single keywords;
    they are concatenated (each followed by a space) so the output type
    matches the other keyword backends.

    Input:
        user_query: str
    Output: keywords: str
    """
    extractor = Rake()
    extractor.extract_keywords_from_text(user_query)
    ranked_phrases = extractor.get_ranked_phrases()

    # Note: a trailing space after the final phrase is intentional, preserved
    # from the original implementation.
    return ''.join(phrase + ' ' for phrase in ranked_phrases)
46
+
47
+
48
def get_keywords_openai(user_query: str) -> str:
    """Ask the chat model to extract keywords from *user_query*.

    The model is asked for a comma-separated list; the commas are stripped so
    the result is a space-separated keyword string, matching the other
    keyword backends.

    Input:
        user_query: str
    Output: keywords: str
    """
    llm = ChatOpenAI(temperature=0.0)
    command = "return the keywords of the following query. response should be words separated by commas. "
    message = [
        SystemMessage(content=command),
        HumanMessage(content=user_query)
    ]
    # .invoke() is the supported Runnable entry point; calling the model
    # object directly (llm(message)) is deprecated in current langchain and
    # inconsistent with feed_to_llm_v2.
    response = llm.invoke(message)
    res = response.content.replace(",", "")
    return res
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.25.0
2
+ langchain==0.1.14
3
+ langchain-core==0.1.40
4
+ langchain-openai==0.1.1
5
+ nltk==3.8.1
6
+ openai==1.16.2
7
+ pandas==2.2.1
8
+ pysolr==3.9.0
9
+ rake-nltk==1.0.6
10
+ sentence-transformers==2.2.2
11
+ tiktoken==0.5.2
12
+ torch==2.1.2
13
+ huggingface-hub==0.20.2
14
+ python-dotenv==1.0.1
15
+ docarray==0.40.0
rerank.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# reranks the top articles from a given csv file
import time

import pandas as pd
from langchain.chains import RetrievalQA, RetrievalQAWithSourcesChain
from langchain.indexes import VectorstoreIndexCreator
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.vectorstores import DocArrayInMemorySearch
from langchain_openai import ChatOpenAI
from nltk.tokenize import sent_tokenize
from sentence_transformers import CrossEncoder
9
+
10
"""
This function reranks top articles (15 -> 4) from a given csv, then sends them to the LLM.
Input:
    csv_path: str
    question: str
    top_n: int
Output:
    response: str
    links: list of str
    titles: list of str

Other functions in this file do not send articles to the LLM. This is an exception.
Created using langchain RAG functions. Deprecated.
Update: Use langchain_RAG instead.
"""


def langchain_rerank_answer(csv_path, question, source='url', top_n=4):
    # Build an in-memory vector index over the csv rows; each row's 'url'
    # column becomes the document "source" used for citations.
    llm = ChatOpenAI(temperature=0.0)
    loader = CSVLoader(csv_path, source_column="url")

    index = VectorstoreIndexCreator(
        vectorstore_cls=DocArrayInMemorySearch,
    ).from_loaders([loader])

    # prompt_template = """You are an a chatbot that answers tobacco related questions with source. Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
    # {context}
    # Question: {question}"""
    # PROMPT = PromptTemplate(
    #     template=prompt_template, input_variables=["context", "question"]
    # )
    # chain_type_kwargs = {"prompt": PROMPT}

    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=index.vectorstore.as_retriever(),
        verbose=False,
        return_source_documents=True,
        # chain_type_kwargs=chain_type_kwargs,
        # chain_type_kwargs = {
        #     "document_separator": "<<<<>>>>>"
        # },
    )

    # Run the chain and pull the source urls out of the returned documents.
    # NOTE(review): the loop variable below shadows the unused 'source'
    # parameter — confirm the parameter can be dropped.
    answer = qa({"query": question})
    sources = answer['source_documents']
    sources_out = [source.metadata['source'] for source in sources]

    return answer['result'], sources_out
60
+
61
+
62
"""
Langchain with sources.
This function is deprecated. Use langchain_RAG instead.
"""


def langchain_with_sources(csv_path, question, top_n=4):
    # Index the csv rows with each row's 'uuid' as the citation source.
    llm = ChatOpenAI(temperature=0.0)
    loader = CSVLoader(csv_path, source_column="uuid")
    index = VectorstoreIndexCreator(
        vectorstore_cls=DocArrayInMemorySearch,
    ).from_loaders([loader])

    qa = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=index.vectorstore.as_retriever(),
    )
    # Returns the generated answer plus the uuid sources it cited.
    output = qa({"question": question}, return_only_outputs=True)
    return output['answer'], output['sources']
82
+
83
+
84
# returns list of top n similar articles using crossencoder
def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
    """Rerank whole articles from *csv_path* against *question*.

    Uses cross-encoder/ms-marco-MiniLM-L-6-v2 for scoring.
    Input:
        csv_path: str
        question: str
        top_n: int
    Output:
        list of (score, content, uuid, title, domain) tuples, best first.
        Entries from the first negative score onward are dropped, but at
        least one result is always returned.
    """
    scorer = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    frame = pd.read_csv(csv_path)
    texts = frame['content'].tolist()
    ids = frame['uuid'].tolist()
    headlines = frame['title'].tolist()

    # biencoder retrieval csvs do not carry a domain column
    if 'domain' not in frame:
        sites = [""] * len(texts)
    else:
        sites = frame['domain'].tolist()

    pairs = [[question, text] for text in texts]
    scores = scorer.predict(pairs)
    ranked = sorted(zip(scores, texts, ids, headlines, sites),
                    key=lambda row: row[0], reverse=True)

    top = ranked[:top_n]

    # Drop everything from the first negative score onward; if nothing
    # remains, fall back to the single best article.
    for pos in range(len(top)):
        if top[pos][0] < 0:
            top = top[:pos]
            if len(top) == 0:
                top = ranked[:1]

            break
    return top
127
+
128
+
129
def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) -> list:
    """Rerank individual sentences (not whole articles) against *question*.

    Returns up to *top_n* (score, sentence, uuid, title, domain) tuples,
    best first; entries from the first negative score onward are dropped,
    keeping at least one result.
    """
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    articles = pd.read_csv(csv_path)
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()

    # biencoder retrieval csvs do not carry a domain column
    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()

    # Explode each article into sentences, repeating its metadata per sentence.
    sentences = []
    new_uuids = []
    new_titles = []
    new_domains = []
    for idx in range(len(contents)):
        sents = sent_tokenize(contents[idx])
        sentences.extend(sents)
        new_uuids.extend([uuids[idx]] * len(sents))
        new_titles.extend([titles[idx]] * len(sents))
        new_domains.extend([domain[idx]] * len(sents))

    cross_inp = [[question, sent] for sent in sentences]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)

    out_values = scores_sentences[:top_n]

    # if score is less than 0, truncate; keep at least the single best hit
    for idx in range(len(out_values)):
        if out_values[idx][0] < 0:
            out_values = out_values[:idx]
            if len(out_values) == 0:
                out_values = scores_sentences[:1]

            break

    return out_values
169
+
170
+
171
def crossencoder_rerank_sentencewise_sentence_chunks(csv_path, question, top_n=10, chunk_size=2):
    """Rerank overlapping chunks of *chunk_size* consecutive sentences.

    Returns up to *top_n* (score, chunk_text, uuid, title, domain) tuples,
    best first; entries from the first negative score onward are dropped,
    keeping at least one result.
    """
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    articles = pd.read_csv(csv_path)
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()

    # embeddings do not have domain as column
    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()

    sentences = []
    new_uuids = []
    new_titles = []
    new_domains = []

    for idx in range(len(contents)):
        sents = sent_tokenize(contents[idx])
        sents_merged = []

        # if the number of sentences is less than chunk size, merge and join
        if len(sents) < chunk_size:
            sents_merged.append(' '.join(sents))
        else:
            # sliding window of chunk_size sentences with stride 1
            for i in range(0, len(sents) - chunk_size + 1):
                sents_merged.append(' '.join(sents[i:i + chunk_size]))

        sentences.extend(sents_merged)
        new_uuids.extend([uuids[idx]] * len(sents_merged))
        new_titles.extend([titles[idx]] * len(sents_merged))
        new_domains.extend([domain[idx]] * len(sents_merged))

    cross_inp = [[question, sent] for sent in sentences]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)

    out_values = scores_sentences[:top_n]

    # truncate at the first negative score; keep at least the single best hit
    for idx in range(len(out_values)):
        if out_values[idx][0] < 0:
            out_values = out_values[:idx]
            if len(out_values) == 0:
                out_values = scores_sentences[:1]

            break

    return out_values
221
+
222
+
223
def crossencoder_rerank_sentencewise_articles(csv_path, question, top_n=4):
    """Score sentences against *question* but return whole articles.

    Each sentence is scored individually; an article ranks by its best
    sentence. Returns up to *top_n*
    (score, full_article_content, uuid, title, domain) tuples, best first.
    """
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    contents, uuids, titles, domain = load_articles(csv_path)

    # Explode articles into sentences, keeping the full article text attached
    # to each sentence so whole articles can be returned.
    sentences = []
    contents_elongated = []
    new_uuids = []
    new_titles = []
    new_domains = []

    for idx in range(len(contents)):
        sents = sent_tokenize(contents[idx])
        sentences.extend(sents)
        new_uuids.extend([uuids[idx]] * len(sents))
        contents_elongated.extend([contents[idx]] * len(sents))
        new_titles.extend([titles[idx]] * len(sents))
        new_domains.extend([domain[idx]] * len(sents))

    cross_inp = [[question, sent] for sent in sentences]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, contents_elongated, new_uuids, new_titles, new_domains))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)

    # Keep only the best-scoring sentence per article (dedupe by uuid).
    # A seen-set gives O(1) membership instead of rescanning the output list
    # for every item (the previous code was O(n^2) in sentence count).
    seen_uuids = set()
    score_sentences_compressed = []
    for item in scores_sentences:
        if item[2] not in seen_uuids:
            seen_uuids.add(item[2])
            score_sentences_compressed.append(item)

    return score_sentences_compressed[:top_n]
256
+
257
+
258
def no_rerank(csv_path, question, top_n=4):
    """Return the first *top_n* articles without any reranking.

    Emits (score, content, uuid, title, domain) tuples with a neutral score
    of 0.0 so the output shape matches the other rerank helpers — downstream
    consumers (feed_to_llm*) unpack 5-tuples, and the previous 4-tuple
    output broke that contract.
    """
    contents, uuids, titles, domains = load_articles(csv_path)
    scores = [0.0] * len(contents)
    return list(zip(scores, contents, uuids, titles, domains))[:top_n]
261
+
262
+
263
def load_articles(csv_path: str):
    """Read articles.csv and return (contents, uuids, titles, domains) lists.

    Bi-encoder retrieval csvs lack a 'domain' column; empty strings are
    substituted in that case.
    """
    frame = pd.read_csv(csv_path)
    texts = frame['content'].tolist()
    ids = frame['uuid'].tolist()
    names = frame['title'].tolist()
    if 'domain' in frame:
        sites = frame['domain'].tolist()
    else:
        sites = [""] * len(texts)
    return texts, ids, names, sites
273
+