|
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain.chains import RetrievalQA
|
|
from langchain_community.document_loaders.csv_loader import CSVLoader
|
|
from langchain_community.vectorstores import DocArrayInMemorySearch
|
|
from sentence_transformers import CrossEncoder
|
|
import pandas as pd
|
|
import time
|
|
|
|
def langchain_rerank_answer(csv_path, question, source='url', top_n=4):
    """Answer ``question`` with an LLM over the rows of a CSV and return the sources used.

    Deprecated: created with LangChain's RAG helpers; use ``langchain_RAG`` instead.

    Despite its name, this function does not use a cross-encoder. Every CSV row
    is loaded as a document and indexed in an in-memory ``DocArrayInMemorySearch``
    store via ``VectorstoreIndexCreator``. The ``top_n`` most similar rows are
    retrieved and passed ("stuff" chain) to ``ChatOpenAI`` at temperature 0.
    Unlike the other functions in this file, it sends article text to an LLM.

    Args:
        csv_path: Path of the CSV to index; every row becomes one document.
        question: Question to answer.
        source: Name of the CSV column stored as each document's ``source``
            metadata and reported back in the source list (default ``"url"``).
        top_n: Number of rows retrieved and handed to the LLM (default 4,
            which is also the retriever's own default).

    Returns:
        ``(answer, sources)``: the LLM's answer text and the list of
        ``source`` values of the retrieved rows.
    """
    # This name was never imported at module level, so calling the function
    # used to fail with NameError; import it locally to keep module import cheap.
    from langchain.indexes import VectorstoreIndexCreator

    llm = ChatOpenAI(temperature=0.0)
    # Previously hard-coded to "url", silently ignoring the `source` argument.
    loader = CSVLoader(csv_path, source_column=source)
    index = VectorstoreIndexCreator(
        vectorstore_cls=DocArrayInMemorySearch,
    ).from_loaders([loader])

    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        # Honour top_n; previously the retriever's default k was always used.
        retriever=index.vectorstore.as_retriever(search_kwargs={"k": top_n}),
        verbose=False,
        return_source_documents=True,
    )

    # `Chain.invoke` replaces the deprecated `Chain.__call__`.
    answer = qa.invoke({"query": question})
    # Use `doc`, not `source`, so the loop variable does not shadow the parameter.
    sources_out = [doc.metadata['source'] for doc in answer['source_documents']]
    return answer['result'], sources_out
|
|
|
|
|
|
def langchain_with_sources(csv_path, question, top_n=4):
    """Answer ``question`` over a CSV with LangChain's ``RetrievalQAWithSourcesChain``.

    Deprecated: use ``langchain_RAG`` instead.

    Each CSV row is loaded as a document whose ``source`` metadata is taken
    from its ``uuid`` column. The rows are indexed in an in-memory
    ``DocArrayInMemorySearch`` store, and the ``top_n`` most similar rows are
    passed ("stuff" chain) to ``ChatOpenAI`` at temperature 0.

    Args:
        csv_path: Path of the CSV to index; must contain a ``uuid`` column.
        question: Question to answer.
        top_n: Number of rows retrieved and handed to the LLM (default 4,
            which is also the retriever's own default).

    Returns:
        ``(answer, sources)`` taken from the chain's ``"answer"`` and
        ``"sources"`` outputs. NOTE(review): ``sources`` is presumably the
        string of uuids cited by the LLM, not a list; confirm before relying on it.
    """
    # Neither name was imported at module level (calling this raised NameError).
    from langchain.chains import RetrievalQAWithSourcesChain
    from langchain.indexes import VectorstoreIndexCreator

    llm = ChatOpenAI(temperature=0.0)
    loader = CSVLoader(csv_path, source_column="uuid")
    index = VectorstoreIndexCreator(
        vectorstore_cls=DocArrayInMemorySearch,
    ).from_loaders([loader])

    qa = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        # Honour top_n; previously it was accepted but never used.
        retriever=index.vectorstore.as_retriever(search_kwargs={"k": top_n}),
    )

    # `Chain.invoke` replaces the deprecated `Chain.__call__`; only the
    # "answer" and "sources" keys are read, so the extra input keys are harmless.
    output = qa.invoke({"question": question})
    return output['answer'], output['sources']
|
|
|
|
|
|
def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
    """Rerank whole articles from a CSV against a question with a cross-encoder.

    Each article's ``content`` is scored against ``question`` with
    ``cross-encoder/ms-marco-MiniLM-L-6-v2``. Articles are sorted by score,
    highest first, and at most ``top_n`` are kept. The result is cut just
    before the first negative score; if even the best article scores
    negative, that single best article is still returned.

    Args:
        csv_path: CSV with ``content``, ``uuid`` and ``title`` columns and an
            optional ``domain`` column.
        question: Query text the articles are scored against.
        top_n: Maximum number of articles to return.

    Returns:
        List of ``(score, content, uuid, title, domain)`` tuples sorted by
        descending score; empty if the CSV has no rows.
    """
    # Reuse the shared loader instead of duplicating its column handling here.
    contents, uuids, titles, domains = load_articles(csv_path)
    # Nothing to rank, and CrossEncoder.predict may fail on an empty batch,
    # so return before loading the model at all.
    if not contents:
        return []

    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    cross_scores = cross_encoder.predict([[question, content] for content in contents])
    scored = sorted(
        zip(cross_scores, contents, uuids, titles, domains),
        key=lambda item: item[0],
        reverse=True,
    )

    top = scored[:top_n]
    # The list is sorted by descending score: stop at the first negative score,
    # but never return an empty list when there was at least one candidate.
    for idx, item in enumerate(top):
        if item[0] < 0:
            return top[:idx] or scored[:1]
    return top
|
|
|
|
|
|
def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) -> list:
    """Rank individual sentences from all articles in a CSV against a question.

    Each article's ``content`` is split into sentences, and every sentence is
    scored against ``question`` with ``cross-encoder/ms-marco-MiniLM-L-6-v2``.
    Sentences are sorted by score, highest first, and at most ``top_n`` are
    kept. The result is cut just before the first negative score; if even the
    best sentence scores negative, that single best sentence is still returned.

    Args:
        csv_path: CSV with ``content``, ``uuid`` and ``title`` columns and an
            optional ``domain`` column.
        question: Query text the sentences are scored against.
        top_n: Maximum number of sentences to return.

    Returns:
        List of ``(score, sentence, uuid, title, domain)`` tuples. The last
        three fields describe the article the sentence came from. The list is
        empty if no sentences were found.
    """
    # sent_tokenize was never imported in this module (NameError); it comes from NLTK.
    from nltk.tokenize import sent_tokenize

    contents, uuids, titles, domains = load_articles(csv_path)

    # One candidate per sentence, carrying its source article's metadata.
    candidates = [
        (sentence, uid, title, domain)
        for content, uid, title, domain in zip(contents, uuids, titles, domains)
        for sentence in sent_tokenize(content)
    ]
    # Nothing to rank, and CrossEncoder.predict may fail on an empty batch.
    if not candidates:
        return []

    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    cross_scores = cross_encoder.predict([[question, cand[0]] for cand in candidates])
    scored = sorted(
        ((score, *cand) for score, cand in zip(cross_scores, candidates)),
        key=lambda item: item[0],
        reverse=True,
    )

    top = scored[:top_n]
    # Sorted by descending score: stop at the first negative score,
    # but keep at least the single best sentence.
    for idx, item in enumerate(top):
        if item[0] < 0:
            return top[:idx] or scored[:1]
    return top
|
|
|
|
|
|
def crossencoder_rerank_sentencewise_sentence_chunks(csv_path, question, top_n=10, chunk_size=2):
    """Rank overlapping windows of consecutive sentences against a question.

    Each article's ``content`` is split into sentences. Every run of
    ``chunk_size`` consecutive sentences (stride 1) becomes one chunk; an
    article with fewer than ``chunk_size`` sentences contributes one chunk
    made of all its sentences. Chunks are scored against ``question`` with
    ``cross-encoder/ms-marco-MiniLM-L-6-v2`` and sorted by score, highest
    first, and at most ``top_n`` are kept. The result is cut just before the
    first negative score; if even the best chunk scores negative, that single
    best chunk is still returned.

    Args:
        csv_path: CSV with ``content``, ``uuid`` and ``title`` columns and an
            optional ``domain`` column.
        question: Query text the chunks are scored against.
        top_n: Maximum number of chunks to return.
        chunk_size: Number of consecutive sentences per chunk; must be >= 1.

    Returns:
        List of ``(score, chunk_text, uuid, title, domain)`` tuples; empty if
        no sentences were found.

    Raises:
        ValueError: If ``chunk_size`` is less than 1 (such windows would be empty).
    """
    # sent_tokenize was never imported in this module (NameError); it comes from NLTK.
    from nltk.tokenize import sent_tokenize

    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    contents, uuids, titles, domains = load_articles(csv_path)

    candidates = []
    for content, uid, title, domain in zip(contents, uuids, titles, domains):
        sents = sent_tokenize(content)
        if not sents:
            # Previously an empty string was scored as a chunk for such articles.
            continue
        if len(sents) < chunk_size:
            chunks = [' '.join(sents)]
        else:
            chunks = [
                ' '.join(sents[start:start + chunk_size])
                for start in range(len(sents) - chunk_size + 1)
            ]
        candidates.extend((chunk, uid, title, domain) for chunk in chunks)

    # Nothing to rank, and CrossEncoder.predict may fail on an empty batch.
    if not candidates:
        return []

    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    cross_scores = cross_encoder.predict([[question, cand[0]] for cand in candidates])
    scored = sorted(
        ((score, *cand) for score, cand in zip(cross_scores, candidates)),
        key=lambda item: item[0],
        reverse=True,
    )

    top = scored[:top_n]
    # Sorted by descending score: stop at the first negative score,
    # but keep at least the single best chunk.
    for idx, item in enumerate(top):
        if item[0] < 0:
            return top[:idx] or scored[:1]
    return top
|
|
|
|
|
|
def crossencoder_rerank_sentencewise_articles(csv_path, question, top_n=4):
    """Rank whole articles by their single best-matching sentence.

    Every sentence of every article is scored against ``question`` with
    ``cross-encoder/ms-marco-MiniLM-L-6-v2``. Each article (identified by its
    ``uuid``) is then represented by its highest-scoring sentence, and the
    ``top_n`` articles with the highest such scores are returned. Unlike the
    other rerankers, no negative-score cut-off is applied.

    Args:
        csv_path: CSV with ``content``, ``uuid`` and ``title`` columns and an
            optional ``domain`` column.
        question: Query text the sentences are scored against.
        top_n: Maximum number of articles to return.

    Returns:
        List of ``(best_sentence_score, content, uuid, title, domain)`` tuples
        sorted by descending score, where ``content`` is the full article
        text. The list is empty if no sentences were found.
    """
    # sent_tokenize was never imported in this module (NameError); it comes from NLTK.
    from nltk.tokenize import sent_tokenize

    contents, uuids, titles, domains = load_articles(csv_path)

    sentences = []
    article_rows = []  # (content, uuid, title, domain) of the article each sentence came from
    for article in zip(contents, uuids, titles, domains):
        sents = sent_tokenize(article[0])
        sentences.extend(sents)
        article_rows.extend([article] * len(sents))

    # Nothing to rank, and CrossEncoder.predict may fail on an empty batch.
    if not sentences:
        return []

    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    cross_scores = cross_encoder.predict([[question, sent] for sent in sentences])
    scored = sorted(
        ((score, *row) for score, row in zip(cross_scores, article_rows)),
        key=lambda item: item[0],
        reverse=True,
    )

    # Keep only the first (highest-scoring) entry per uuid. A set makes this
    # O(n); the previous version rebuilt a list of seen uuids for every item.
    seen_uuids = set()
    best_per_article = []
    for item in scored:
        if item[2] not in seen_uuids:
            seen_uuids.add(item[2])
            best_per_article.append(item)

    return best_per_article[:top_n]
|
|
|
|
|
|
def no_rerank(csv_path, question, top_n=4):
    """Return the first ``top_n`` articles of the CSV in file order, without scoring.

    ``question`` is accepted only so the signature matches the rerank
    functions; it is not used.

    Returns:
        List of ``(content, uuid, title, domain)`` tuples in CSV row order.
    """
    # All columns have the same length, so truncating each column first
    # yields exactly the same rows as truncating the zipped result.
    truncated_columns = [column[:top_n] for column in load_articles(csv_path)]
    return list(zip(*truncated_columns))
|
|
|
|
|
|
def load_articles(csv_path: str) -> tuple:
    """Read the article columns from a CSV file.

    Args:
        csv_path: Path to a CSV with ``content``, ``uuid`` and ``title``
            columns and an optional ``domain`` column.

    Returns:
        ``(contents, uuids, titles, domains)``: four lists with one entry per
        CSV row. Missing ``content`` cells are returned as ``""``; pandas
        would otherwise yield float NaN, which breaks sentence tokenisation
        and cross-encoder scoring downstream. ``domains`` is all ``""`` when
        the CSV has no ``domain`` column.
    """
    articles = pd.read_csv(csv_path)

    contents = articles['content'].fillna("").tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()

    # `in` on a DataFrame checks column labels.
    if 'domain' in articles:
        domains = articles['domain'].tolist()
    else:
        domains = [""] * len(contents)

    return contents, uuids, titles, domains
|
|
|
|
|