# app.py — from abnerguzman's Hugging Face Space.
# Last commit: 84b9cf4 ("Update app.py", verified); file size 5.78 kB.
import os
import pandas as pd
import json
import pickle
import pprint
import textwrap
import time
from tqdm.autonotebook import tqdm
from pinecone import Pinecone, ServerlessSpec
# Pinecone client and the vector index that get_matches_reranked queries for
# candidate chunks. The API key is read from the environment.
pc = Pinecone(api_key=os.environ.get('PINECONE_API_KEY'))
index_name = "prorata-postman-ds-128-v2"
index = pc.Index(name=index_name)

from openai import OpenAI

# OpenAI client used by get_embedding; key also taken from the environment.
openai_client = OpenAI(api_key=os.environ.get('OPENAI_API_KEY'))
def get_embedding(text, model="text-embedding-3-small"):
    """Return the OpenAI embedding vector for ``text``.

    Newlines in ``text`` are replaced by single spaces before the request.

    Args:
        text: The string to embed.
        model: Name of the OpenAI embedding model.

    Returns:
        The ``embedding`` of the first (and only) item in the API response.
    """
    single_line = " ".join(text.split("\n"))
    response = openai_client.embeddings.create(input=[single_line], model=model)
    return response.data[0].embedding
from transformers import AutoTokenizer, AutoModel
# Load the ColBERTv2 tokenizer and encoder; get_matches_reranked uses them to
# encode the query text at request time.
tokenizer = AutoTokenizer.from_pretrained("colbert-ir/colbertv2.0")
model = AutoModel.from_pretrained("colbert-ir/colbertv2.0")
# Precomputed per-chunk ColBERTv2 embeddings, later looked up by Pinecone match id.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file; this
# must only read a trusted artifact. The path is relative to the working directory.
with open('colbertv2_pc_128_d.pkl', 'rb') as f:
    colbertv2_pc_128_d = pickle.load(f)
# version_notes is not referenced anywhere else in the visible file.
version_notes = colbertv2_pc_128_d['version_notes']
# Maps chunk id -> embedding tensor passed to maxsim as document_embedding.
chunkid_to_colbertv2 = colbertv2_pc_128_d['chunkid_to_colbertv2']
import torch
def maxsim(query_embedding, document_embedding):
    """Late-interaction (MaxSim) score between query and document vectors.

    For each query vector, take its best cosine similarity against any
    document vector, then average those maxima over the query vectors.

    The broadcasting is laid out for:
        query_embedding:    [batch_size, query_length, embedding_size]
        document_embedding: [batch_size, doc_length, embedding_size]
    Other shapes are broadcast by torch as usual.

    Returns:
        A tensor of shape [batch_size] in the layout above.
    """
    # [B, Q, 1, D] against [B, 1, Dl, D] -> pairwise cosine grid [B, Q, Dl]
    pairwise = torch.nn.functional.cosine_similarity(
        query_embedding.unsqueeze(2),
        document_embedding.unsqueeze(1),
        dim=-1,
    )
    best_per_query_vector = pairwise.max(dim=2).values
    return best_per_query_vector.mean(dim=1)
def get_matches_reranked(q, k=20):
    """Retrieve ``k`` candidates from Pinecone and rerank them with ColBERTv2.

    Candidates come from a dense search on the OpenAI embedding of ``q``.
    Each candidate is scored with ``maxsim`` against its precomputed ColBERTv2
    embedding in ``chunkid_to_colbertv2``. The score is stored on the match
    under ``'colbertv2_score'``.

    Args:
        q: Query text.
        k: Number of candidates to fetch from the index.

    Returns:
        The candidate matches sorted by ``'colbertv2_score'``, highest first.
        The match objects are mutated in place to carry the score.
    """
    matches = index.query(vector=get_embedding(q), top_k=k, include_metadata=True)['matches']

    # Inference only: disabling autograd avoids building and retaining a
    # gradient graph through the encoder on every request. Scores are unchanged.
    with torch.no_grad():
        q_encoding = tokenizer(q, return_tensors='pt')
        # NOTE(review): mean-pooling collapses the query to a single vector, so
        # maxsim below is not token-level ColBERT MaxSim. Confirm this matches
        # how the stored chunk embeddings were produced before changing it.
        q_embedding = model(**q_encoding).last_hidden_state.mean(dim=1)
        # Loop-invariant: add the leading batch axis once for all candidates.
        q_batched = q_embedding.unsqueeze(0)
        for match in matches:
            score = maxsim(q_batched, chunkid_to_colbertv2[match['id']])
            match['colbertv2_score'] = score.item()

    return sorted(matches, key=lambda m: m['colbertv2_score'], reverse=True)
def filter_matches(matches_colbertv2, score_thr=0.0):
    """Keep only matches whose ColBERTv2 score is strictly above ``score_thr``.

    Args:
        matches_colbertv2: Matches carrying a ``'colbertv2_score'`` entry, as
            produced by ``get_matches_reranked``.
        score_thr: Exclusive lower bound on the score. Defaults to 0.0.

    Returns:
        A new list holding the surviving matches in their input order.
    """
    return [m for m in matches_colbertv2 if m['colbertv2_score'] > score_thr]
# CSS emitted by output_chunks_reranked as part of its HTML output.
# Classes used there: .doc-title (favicon image, title, and a <score> element
# floated right), .doc-url (source link) and .doc-text (chunk text).
style_str = """
<style>
.doc-title {
/* font-family: cursive, sans-serif; */
font-family: Optima, sans-serif;
width: 100%;
display: inline-block;
font-size: 2em;
font-weight: bolder;
padding-top: 20px;
/* font-style: italic; */
}
.doc-url {
/* font-family: cursive, sans-serif; */
font-size: 1em;
padding-left: 40px;
padding-bottom: 10px;
/* font-weight: bolder; */
/* font-style: italic; */
}
.doc-text {
/* font-family: cursive, sans-serif; */
font-family: Optima, sans-serif;
font-size: 1.5em;
padding-left: 40px;
padding-bottom: 20px;
/* font-weight: bolder; */
/* font-style: italic; */
}
.doc-title > img {
width: 22px;
height: 22px;
border-radius: 50%;
overflow: hidden;
background-color: transparent;
display: inline-block;
vertical-align: middle;
}
.doc-title > score {
font-family: Optima, sans-serif;
font-weight: normal;
float: right;
}
</style>
"""
import gradio as gr
from io import StringIO
from urllib.parse import urlparse
def output_chunks_reranked(msg):
    """Search for ``msg`` and render the reranked, score-filtered sources as HTML.

    Fetches 20 candidates with ``get_matches_reranked`` and keeps those whose
    ColBERTv2 score is above 0.55 via ``filter_matches``. It then renders one
    card per surviving match (favicon, title, score, link, chunk text). If the
    query is blank or nothing passes the filter, a single "no sources" message
    is rendered instead.

    Args:
        msg: Query text from the "Target" textbox.

    Returns:
        An HTML string beginning with ``style_str`` (emitted once), intended for
        a ``gr.HTML`` output component.
    """
    # Skip the embedding and vector-search round trip for blank input.
    if not msg or not msg.strip():
        return _render_no_matches()

    matches_colbertv2 = get_matches_reranked(msg, k=20)
    matches_colbertv2 = filter_matches(matches_colbertv2, score_thr=0.55)
    if not matches_colbertv2:
        return _render_no_matches()

    # The stylesheet is emitted once for the whole page, not once per card.
    cards = [_render_match(match) for match in matches_colbertv2]
    return "\n".join([style_str, *cards]) + "\n"


def _render_no_matches():
    """Return the HTML page shown when there are no results to display."""
    return "\n".join([
        style_str,
        "<div>",
        '<div class="doc-title">No sources relevant to this target were found.</div>',
        "</div>",
    ]) + "\n"


def _render_match(match):
    """Return the HTML card for one scored match.

    Every metadata value is HTML-escaped before interpolation, so markup
    characters in titles, chunk text or URLs cannot break the page or inject
    elements into it.
    """
    from html import escape

    metadata = match['metadata']
    url = str(metadata['url'])
    safe_url = escape(url)
    domain = escape(urlparse(url).netloc)
    title = escape(str(metadata['title']))
    text = escape(str(metadata['text']))
    score = match['colbertv2_score']

    # "&amp;" is the correctly escaped "&" inside an HTML attribute value.
    favicon = (
        '<img src="https://www.google.com/s2/favicons?sz=128&amp;domain='
        f'{domain}"/>'
    )
    return "\n".join([
        "<div>",
        f'<div class="doc-title">{favicon}&nbsp;&nbsp;{title}<score>{score:.2f}</score></div>',
        f'<div class="doc-url"><a href="{safe_url}" target="_blank">{safe_url}</a></div>',
        f'<div class="doc-text">{text}</div>',
        "</div>",
    ])
# Single-page UI: type a target into the textbox and press Enter. The matching
# sources are rendered as HTML below it.
with gr.Blocks() as demo:
    msg = gr.Textbox(label='Target')
    results_box = gr.HTML(label='Matches')
    msg.submit(
        fn=output_chunks_reranked,
        inputs=msg,
        outputs=results_box,
        queue=False,  # this event bypasses the request queue
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the web server.
    demo.queue().launch()