import concurrent.futures
from collections import defaultdict
import pandas as pd
import numpy as np
import json
import os
import pickle
import pprint
from io import StringIO
import textwrap
import time
import re

from openai import OpenAI
openai_client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

import octoai
octoai_client = octoai.client.Client(token=os.getenv('OCTOML_KEY'))

from pinecone import Pinecone, ServerlessSpec
pc = Pinecone(api_key=os.getenv('PINECONE_API_KEY'))
pc_256 = pc.Index('prorata-postman-ds-256-v2')
pc_128 = pc.Index('prorata-postman-ds-128-v2')

from langchain.text_splitter import RecursiveCharacterTextSplitter
sentence_splitter = RecursiveCharacterTextSplitter(
    chunk_size=128,
    chunk_overlap=0,
    separators=["\n\n", "\n", "."],
    keep_separator=False
)
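
# Note: RecursiveCharacterTextSplitter measures chunk_size in characters by
# default, so this yields chunks of at most 128 characters, splitting first on
# paragraph breaks, then newlines, then periods. Hypothetical illustration:
# sentence_splitter.split_text(long_text)  # -> list of <=128-char strings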

from functools import cache

@cache  # memoize so repeated atoms don't re-hit the embeddings API (hashable args only)
def get_embedding(text, model="text-embedding-3-small"):
    text = text.replace("\n", " ")
    return openai_client.embeddings.create(input=[text], model=model).data[0].embedding

def get_embedding_l(text_l, model="text-embedding-3-small"):
    text_l = [text.replace("\n", " ") for text in text_l]
    res = openai_client.embeddings.create(input=text_l, model=model)
    embeds = [record.embedding for record in res.data]
    return embeds

def do_character_replacements(text):
    # TODO: double quotes need proper escaping; for now replace them, since
    # they interfere with parsing of JSON responses
    return text.translate(str.maketrans({'“': '\'\'', '”': '\'\'', '"': '\'\'', "’": "'"}))
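
# Quick illustration of the replacement behavior (hypothetical input, not from
# the original): curly and straight double quotes become two single quotes, so
# they can no longer terminate a JSON string field prematurely.
# do_character_replacements('He said “hello” and left.')
# returns "He said ''hello'' and left."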

def parse_json_string(content):
    fixed_content = content
    for _ in range(20):
        try:
            result = json.loads(fixed_content)
            break
        except Exception as e:
            print(e)
            if "Expecting ',' delimiter" in str(e):
                # "Expecting ',' delimiter: line x column y (char d)"
                idx = int(re.findall(r'\(char (\d+)\)', str(e))[0])
                fixed_content = fixed_content[:idx] + ',' + fixed_content[idx:]
                print(fixed_content)
                print()
            elif "Expecting property name enclosed in double quotes" in str(e):
                # "Expecting property name enclosed in double quotes: line x column y (char d)"
                idx = int(re.findall(r'\(char (\d+)\)', str(e))[0])
                fixed_content = fixed_content[:idx-1] + '}' + fixed_content[idx:]
                print(fixed_content)
                print()
            else:
                raise ValueError(str(e))
    else:
        # all 20 repair attempts failed; without this, `result` would be unbound
        raise ValueError("could not repair JSON string after 20 attempts")
    return result
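
# Sketch of the repair loop on a malformed response (hypothetical string): a
# missing ',' between array elements raises "Expecting ',' delimiter ... (char 15)",
# a comma is spliced in at that offset, and the parse is retried.
# parse_json_string('[{"fact": "a"} {"fact": "b"}]')
# returns [{'fact': 'a'}, {'fact': 'b'}] (the repair steps are also printed).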

# prompt_af_template_llama3 = "Please break down the following paragraph into independent and atomic facts. Format your response as a single JSON object, a list of facts:\n\n{}"
prompt_af_template_llama3 = "Please break down the following paragraph into independent and atomic facts. Format your response in JSON as a list of 'fact' objects:\n\n{}"

# prompt_tf_template = "Given the context below, answer the question that follows. Please format your answer in JSON with a yes or no determination and rationale for the determination. \n\nContext: {}\n\nQuestion: {} Is this claim true or false?"
prompt_tf_template = (
    "Given the context below, answer the question that follows. Please format your answer in JSON with a yes or no determination"
    " and rationale for the determination. \n\nContext: ```{}```\n\nQuestion: <{}>"
    " According to the context, is the previous claim (in between <> brackets) true or false?"
)
# prompt_tf_template = (
#     "Given the context below, answer the question that follows. Please format your answer in JSON with a yes or no determination"
#     " and rationale for the determination. \n\nContext: ```{}```\n\nQuestion: <{}>"
#     " Does the context explicitly support the previous claim (in between <> brackets), true or false?"
# )
# prompt_tf_template = (
#     "Given the context below, answer the question that follows. Please format your answer in JSON with a yes or no determination"
#     " and rationale for the determination. \n\nContext: ```{}```\n\nQuestion: <{}>"
#     " Does the context explicitly support or strongly suggest the previous claim (in between <> brackets), yes or no?"
# )

def get_atoms_list(answer, file=None):
    prompt_af = prompt_af_template_llama3.format(answer)
    response, atoms_l = None, []
    content, idx1, idx2 = '', -1, -1  # predefine so the except branch can always print them
    for _ in range(5):
        try:
            response = octoai_client.chat.completions.create(
                model="meta-llama-3-70b-instruct",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt_af}
                ],
                # response_format={"type": "json_object"},
                max_tokens=512,
                presence_penalty=0,
                temperature=0.1,
                top_p=0.9,
            )
            content = response.choices[0].message.content
            # the model wraps its JSON answer in a ``` fenced block; extract the body
            idx1 = content.find('```')
            idx2 = idx1 + 3 + content[idx1 + 3:].find('```')
            # atoms_l = json.loads(content[idx1+3:idx2])
            atoms_l = parse_json_string(content[idx1 + 3:idx2])
            atoms_l = [a['fact'] for a in atoms_l]
            break
        except Exception as error:
            print(error, file=file)
            print(response, file=file)
            print(content[idx1 + 3:idx2], file=file)
            time.sleep(2)
    return atoms_l
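
# Hedged usage sketch (the answer text is made up; a live OctoAI key and model
# access are assumed):
# atoms_l = get_atoms_list("The Eiffel Tower is in Paris. It was completed in 1889.")
# -> e.g. ["The Eiffel Tower is in Paris.",
#          "The Eiffel Tower was completed in 1889."]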

def get_topk_matches(atom, k=5, pc_index=pc_256):
    embed_atom = get_embedding(atom)
    res = pc_index.query(vector=embed_atom, top_k=k, include_metadata=True)
    return res['matches']
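
# For reference, each returned match is consumed downstream as a dict-like
# record of roughly this shape (field names as used below; values illustrative):
# {'id': 'chunk-00042', 'score': 0.83, 'values': [...],
#  'metadata': {'url': 'https://...', 'chunk': 3, 'title': '...',
#               'text': 'the chunk text'}}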

def get_match_atom_entailment_determination(_match, atom, file=None, DEBUG=0):
    if 'chunk_text_l' in _match:
        chunk_text = '\n\n'.join(_match['chunk_text_l'])
    else:
        chunk_text = _match['metadata']['text']
    print(f"Determining entailment for url={_match['metadata']['url']} and atom {atom}...")
    chunk_text = do_character_replacements(chunk_text)
    prompt_tf = prompt_tf_template.format(chunk_text, atom)
    if DEBUG > 0:
        print(prompt_tf)
    response = None
    chunk_determination = {}
    chunk_determination['chunk_id'] = _match['id']
    chunk_determination['true'] = False
    chunk_determination['rationale'] = ''
    for _ in range(5):
        try:
            response = octoai_client.chat.completions.create(
                model="meta-llama-3-70b-instruct",
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt_tf}
                ],
                # response_format={"type": "json_object"},
                max_tokens=512,
                # presence_penalty=0,
                temperature=0.1,
                # top_p=0.9,
            )
            content = response.choices[0].message.content
            # extract the JSON object from the response; rfind ensures a '}'
            # inside the rationale text doesn't truncate the object early
            idx1 = content.find('{')
            idx2 = content.rfind('}')
            chunk_determination.update(json.loads(content[idx1:idx2+1]))
            _det_lower = chunk_determination['determination'].lower()
            chunk_determination['true'] = "true" in _det_lower or "yes" in _det_lower
            break
        except Exception as error:
            print(error, file=file)
            print(prompt_tf, file=file)
            print(response, file=file)
            time.sleep(2)
    print(f"Finished entailment for url={_match['metadata']['url']} and atom {atom}.")
    return chunk_determination

def get_atom_support(atom, file=None):
    topk_matches = get_topk_matches(atom)
    atom_support = {}
    for _match in topk_matches:
        # re-query a URL only if it has no determination yet or the previous one was negative
        chunk_determination = atom_support.get(_match['metadata']['url'], {})
        if not chunk_determination or not chunk_determination['true']:
            atom_support[_match['metadata']['url']] = get_match_atom_entailment_determination(_match, atom, file=file)
    return atom_support

def get_atom_support_list(atoms_l, file=None):
    return [get_atom_support(a, file=file) for a in atoms_l]

def credit_atom_support_list(atom_support_l):
    num_atoms = len(atom_support_l)
    credit_d = defaultdict(float)
    if num_atoms == 0:
        return credit_d
    for atom_support in atom_support_l:
        # each supported atom carries one unit of credit, split evenly across
        # the URLs whose chunks entail it
        atom_support_size = 0.0
        for url_determination_d in atom_support.values():
            if url_determination_d['true']:
                atom_support_size += 1.0
        for url, url_determination_d in atom_support.items():
            if url_determination_d['true']:
                credit_d[url] += 1.0 / atom_support_size
    # normalize so the credit distribution sums to at most 1 over all URLs
    for url in credit_d.keys():
        credit_d[url] = credit_d[url] / num_atoms
    return credit_d
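
# Worked example (stub inputs carrying only the 'true' field the function
# reads): atom 1 is supported by urlA alone, atom 2 by both urlA and urlB.
# credit_atom_support_list([
#     {'urlA': {'true': True}},
#     {'urlA': {'true': True}, 'urlB': {'true': True}},
# ])
# atom 1 gives urlA 1/1; atom 2 gives urlA and urlB 1/2 each; dividing the
# totals by the 2 atoms yields {'urlA': 0.75, 'urlB': 0.25}.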

def print_atom_support(atom_support, prefix='', print_chunks=False, file=None):
    for url, aggmatch_determination in atom_support.items():
        print(f"{prefix}{url}:", file=file)
        print(f"{prefix} Determination: {'YES' if aggmatch_determination['true'] else 'NO'}", file=file)
# print(f"{prefix} Rationale: {aggmatch_determination['rationale']}", file=file) | |
        print(textwrap.fill(f"{prefix} Rationale: {aggmatch_determination['rationale']}", initial_indent='', subsequent_indent=f'{prefix} ', width=100), file=file)
        if print_chunks:
            if 'chunk_text_l' in aggmatch_determination:
                for cid, ctext in zip(aggmatch_determination['id_l'], aggmatch_determination['chunk_text_l']):
                    print(textwrap.fill(f"{prefix} Chunk {cid}: {ctext}\n", initial_indent='', subsequent_indent=f'{prefix} ', width=100), file=file)

def print_credit_dist(credit_dist, prefix='', url_to_id=None, file=None):
    credit_l = sorted(credit_dist.items(), key=lambda x: x[1], reverse=True)
    for url, w in credit_l:
        if url_to_id is None:
            print(f"{prefix}{url}: {100*w:.2f}%", file=file)
        else:
            print(f"{prefix}{url_to_id[url]} {url}: {100*w:.2f}%", file=file)

# concurrent LLM calls
def get_atom_topk_matches_l_concurrent(atoms_l, max_workers=4):
    atom_topkmatches_l = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(get_topk_matches, atom) for atom in atoms_l]
        # futures are consumed in submission order, so results stay aligned with atoms_l
        for f in futures:
            atom_topkmatches_l.append(f.result())
    return atom_topkmatches_l

def aggregate_atom_topkmatches_l(atom_topkmatches_l):
    atom_url_to_aggmatch_maps_l = []
    for atom_topkmatches in atom_topkmatches_l:
        atom_url_to_aggmatch_map = {}
        atom_url_to_aggmatch_maps_l.append(atom_url_to_aggmatch_map)
        for _match in atom_topkmatches:
            if _match['metadata']['url'] not in atom_url_to_aggmatch_map:
                match_copy = {}
                match_copy['id'] = _match['id']
                match_copy['id_l'] = [_match['id']]
                match_copy['score'] = _match['score']
                match_copy['values'] = _match['values']
                # TODO: change to list of chunks and then append at query time
                match_copy['metadata'] = {}
                match_copy['metadata']['url'] = _match['metadata']['url']
                match_copy['metadata']['chunk'] = _match['metadata']['chunk']
                match_copy['chunk_text_l'] = [_match['metadata']['text']]
                match_copy['metadata']['title'] = _match['metadata']['title']
                atom_url_to_aggmatch_map[_match['metadata']['url']] = match_copy
            else:
                # append to the existing aggregate for this URL (not match_copy,
                # which would still point at a previous iteration's dict)
                prev_match = atom_url_to_aggmatch_map[_match['metadata']['url']]
                prev_match['id_l'].append(_match['id'])
                prev_match['chunk_text_l'].append(_match['metadata']['text'])
    atomidx_w_single_url_aggmatch_l = []
    for idx, atom_url_to_aggmatch_map in enumerate(atom_url_to_aggmatch_maps_l):
        for agg_match in atom_url_to_aggmatch_map.values():
            atomidx_w_single_url_aggmatch_l.append((idx, agg_match))
    return atomidx_w_single_url_aggmatch_l

def get_atom_support_l_from_atomidx_w_single_url_aggmatch_l_concurrent(atoms_l, atomidx_w_single_url_aggmatch_l, max_workers=4):
    atom_support_l = [{} for _ in atoms_l]
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                get_match_atom_entailment_determination,
                atomidx_w_single_url_aggmatch[1],
                atoms_l[atomidx_w_single_url_aggmatch[0]],
            )
            for atomidx_w_single_url_aggmatch in atomidx_w_single_url_aggmatch_l
        ]
        for f, (atom_idx, aggmatch) in zip(futures, atomidx_w_single_url_aggmatch_l):
            atom_support = atom_support_l[atom_idx]
            aggmatch_determination = f.result()
            aggmatch_determination['id_l'] = aggmatch['id_l']
            aggmatch_determination['chunk_text_l'] = aggmatch['chunk_text_l']
            atom_support[aggmatch['metadata']['url']] = aggmatch_determination
    return atom_support_l
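
# End-to-end sketch of the concurrent attribution pipeline (hedged: assumes
# live OpenAI, OctoAI, and Pinecone credentials plus an indexed corpus; the
# answer string is illustrative):
# answer = "The Eiffel Tower is in Paris."
# atoms_l = get_atoms_list(answer)                                   # 1. atomic facts
# atom_topkmatches_l = get_atom_topk_matches_l_concurrent(atoms_l)   # 2. retrieval
# atomidx_w_single_url_aggmatch_l = aggregate_atom_topkmatches_l(atom_topkmatches_l)
# atom_support_l = get_atom_support_l_from_atomidx_w_single_url_aggmatch_l_concurrent(
#     atoms_l, atomidx_w_single_url_aggmatch_l)                      # 3. entailment per URL
# credit_d = credit_atom_support_list(atom_support_l)                # 4. credit per source
# print_credit_dist(credit_d)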
style_str = """ | |
<style> | |
.section-title { | |
/* font-family: cursive, sans-serif; */ | |
font-family: Optima, sans-serif; | |
width: 100%; | |
font-size: 2.5em; | |
font-weight: bolder; | |
padding-bottom: 20px; | |
padding-top: 20px; | |
/* font-style: italic; */ | |
} | |
.claim-header { | |
/* font-family: cursive, sans-serif; */ | |
font-family: Optima, sans-serif; | |
width: 100%; | |
font-size: 1.5em; | |
font-weight: normal; | |
padding-bottom: 10px; | |
padding-top: 10px; | |
/* font-style: italic; */ | |
} | |
.claim-doc-title { | |
/* font-family: cursive, sans-serif; */ | |
font-family: Optima, sans-serif; | |
width: 100%; | |
font-size: 1.25em; | |
font-weight: normal; | |
padding-left: 20px; | |
padding-bottom: 5px; | |
padding-top: 10px; | |
/* font-style: italic; */ | |
} | |
.claim-doc-url { | |
/* font-family: cursive, sans-serif; */ | |
font-size: 0.75em; | |
padding-left: 20px; | |
padding-bottom: 10px; | |
padding-top: 0px; | |
/* font-weight: bolder; */ | |
/* font-style: italic; */ | |
} | |
.claim-determination { | |
/* font-family: cursive, sans-serif; */ | |
font-family: Optima, sans-serif; | |
width: 100%; | |
font-size: 1em; | |
font-weight: normal; | |
padding-left: 60px; | |
padding-bottom: 10px; | |
/* font-style: italic; */ | |
} | |
.claim-text { | |
/* font-family: cursive, sans-serif; */ | |
font-family: Optima, sans-serif; | |
font-size: 1em; | |
white-space: pre-wrap; | |
padding-left: 80px; | |
text-indent: -20px; | |
padding-bottom: 20px; | |
/* font-weight: bolder; */ | |
/* font-style: italic; */ | |
} | |
.doc-title { | |
/* font-family: cursive, sans-serif; */ | |
font-family: Optima, sans-serif; | |
width: 100%; | |
display: inline-block; | |
font-size: 2em; | |
font-weight: bolder; | |
padding-top: 20px; | |
/* font-style: italic; */ | |
} | |
.doc-url { | |
/* font-family: cursive, sans-serif; */ | |
font-size: 1em; | |
padding-left: 40px; | |
padding-bottom: 10px; | |
/* font-weight: bolder; */ | |
/* font-style: italic; */ | |
} | |
.doc-text { | |
/* font-family: cursive, sans-serif; */ | |
font-family: Optima, sans-serif; | |
font-size: 1.5em; | |
white-space: pre-wrap; | |
padding-left: 40px; | |
padding-bottom: 20px; | |
/* font-weight: bolder; */ | |
/* font-style: italic; */ | |
} | |
.doc-text .chunk-separator { | |
/* font-style: italic; */ | |
color: #0000FF; | |
} | |
.doc-title > img { | |
width: 22px; | |
height: 22px; | |
border-radius: 50%; | |
overflow: hidden; | |
background-color: transparent; | |
display: inline-block; | |
vertical-align: middle; | |
} | |
.doc-title > score { | |
font-family: Optima, sans-serif; | |
font-weight: normal; | |
float: right; | |
} | |
</style> | |
""" | |