# iwatats's picture
# Update app.py
# ad10c76 verified
import os
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import gradio as gr
from openpyxl import load_workbook
from openpyxl.formatting.rule import ColorScaleRule
from openpyxl.utils import get_column_letter
from transformers import BertTokenizer, BertModel
import torch
from sentence_transformers import SentenceTransformer
# Every encoder is loaded once, at import time, so all button clicks reuse the
# same tokenizer/model instances referenced by compute_cosine_similarity.
_BERT_CHECKPOINT = 'bert-base-uncased'
_SBERT_CHECKPOINT = 'paraphrase-MiniLM-L6-v2'  # Example SBERT model
_LEGAL_BERT_CHECKPOINT = 'nlpaueb/legal-bert-base-uncased'

# Plain BERT: tokenizer + encoder, mean-pooled later for document vectors.
tokenizer = BertTokenizer.from_pretrained(_BERT_CHECKPOINT)
model = BertModel.from_pretrained(_BERT_CHECKPOINT)

# Sentence-Transformers model: produces sentence embeddings directly via .encode().
sbert_model = SentenceTransformer(_SBERT_CHECKPOINT)

# Legal-domain BERT: same tokenizer/encoder classes, different checkpoint.
legal_bert_tokenizer = BertTokenizer.from_pretrained(_LEGAL_BERT_CHECKPOINT)
legal_bert_model = BertModel.from_pretrained(_LEGAL_BERT_CHECKPOINT)
def _read_corpus(files):
    """Return ``(names, texts)``: a display name and the full text for each upload."""
    names = []
    corpus = []
    for file in files:
        # Uploads arrive either as path strings or as objects whose ``.name``
        # is the on-disk path; accept both.
        path = getattr(file, "name", file)
        file_name = os.path.splitext(os.path.basename(path))[0]
        names.append(file_name.lower().title())
        # Read the document text itself -- the similarity must be computed on
        # the content, not on the path/file object. Undecodable bytes are
        # replaced rather than aborting the whole comparison.
        with open(path, encoding="utf-8", errors="replace") as fh:
            corpus.append(fh.read())
    return names, corpus


def _mean_pooled_embeddings(texts, tok, mdl):
    """Embed each text as the mean of ``mdl``'s last hidden state over its tokens.

    ``truncation=True`` cuts each input to the tokenizer's maximum length, so
    long documents are represented by their leading tokens only.
    """
    embeddings = []
    for text in texts:
        inputs = tok(text, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():  # inference only; no autograd graph needed
            outputs = mdl(**inputs)
        # (1, seq_len, hidden) -> mean over seq_len -> (hidden,)
        embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().numpy())
    return embeddings


def _similarity_matrix(corpus, method):
    """Return the square cosine-similarity matrix of ``corpus`` under ``method``.

    Raises:
        ValueError: if ``method`` is not a supported method name.
    """
    if method == "TF-IDF":
        vectors = TfidfVectorizer(analyzer="word").fit_transform(corpus)
    elif method == "BERT":
        vectors = _mean_pooled_embeddings(corpus, tokenizer, model)
    elif method == "SBERT":
        vectors = sbert_model.encode(corpus)
    elif method == "Legal-BERT":
        vectors = _mean_pooled_embeddings(corpus, legal_bert_tokenizer, legal_bert_model)
    else:
        raise ValueError(
            f"Unknown similarity method: {method!r}; "
            "expected 'TF-IDF', 'BERT', 'SBERT' or 'Legal-BERT'"
        )
    return cosine_similarity(vectors, vectors)


def _write_excel(names, cosine_sim, method):
    """Write the matrix and pairwise scores to an .xlsx workbook; return its path."""
    from itertools import combinations

    df_cosine = pd.DataFrame(data=cosine_sim, columns=names)
    # Blank-headed first column holds the row labels so the sheet reads as a labelled matrix.
    df_cosine.insert(loc=0, column="", value=names)

    pair_columns = ['Institution 1', 'Institution 2', 'Similarity Score']
    pairs = [
        {
            'Institution 1': names[i],
            'Institution 2': names[j],
            'Similarity Score': cosine_sim[i, j],
        }
        # Each unordered pair once (upper triangle, diagonal excluded).
        for i, j in combinations(range(len(names)), 2)
    ]
    # Explicit columns keep the header row even when there are no pairs (single file).
    df_pairs = pd.DataFrame(pairs, columns=pair_columns)

    output_file = f"cosine_similarity_matrix_{method.lower()}.xlsx"
    with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
        df_cosine.to_excel(writer, index=False, sheet_name='Cosine Similarity Matrix')
        df_pairs.to_excel(writer, index=False, sheet_name='Pairwise Similarity Scores')

    # Re-open the saved workbook to add a colour scale over the score cells
    # (column B onward, row 2 onward -- i.e. everything but the label row/column).
    wb = load_workbook(output_file)
    ws = wb['Cosine Similarity Matrix']
    last_col = get_column_letter(len(names) + 1)
    color_scale_rule = ColorScaleRule(
        start_type="percentile", start_value=0, start_color="00FF00",  # Green for lower values
        mid_type="percentile", mid_value=50, mid_color="FFFF00",  # Yellow for mid values
        end_type="percentile", end_value=100, end_color="FF0000",  # Red for higher values
    )
    ws.conditional_formatting.add(f'B2:{last_col}{len(names) + 1}', color_scale_rule)
    wb.save(output_file)
    return output_file


def compute_cosine_similarity(files, method):
    """Compare uploaded text files pairwise and save the result as an Excel workbook.

    Args:
        files: The uploads from ``gr.File(file_count="multiple")``. Each item is
            either a path string or an object whose ``.name`` is the file path.
        method: One of ``"TF-IDF"``, ``"BERT"``, ``"SBERT"`` or ``"Legal-BERT"``.

    Returns:
        Path of ``cosine_similarity_matrix_<method>.xlsx`` (written to the current
        directory). The workbook has a "Cosine Similarity Matrix" sheet and a
        "Pairwise Similarity Scores" sheet.

    Raises:
        gr.Error: if no files were uploaded.
        ValueError: if ``method`` is not a supported method name.
    """
    if not files:
        # Gradio passes None when nothing was uploaded; surface a readable message.
        raise gr.Error("Please upload at least one .txt file.")
    names, corpus = _read_corpus(files)
    cosine_sim = _similarity_matrix(corpus, method)
    return _write_excel(names, cosine_sim, method)
def tfidf_handler(files):
    """Click handler for the TF-IDF button: return the workbook path for ``files``."""
    return compute_cosine_similarity(files, method="TF-IDF")
def bert_handler(files):
    """Click handler for the BERT button: return the workbook path for ``files``."""
    return compute_cosine_similarity(files, method="BERT")
def sbert_handler(files):
    """Click handler for the SBERT button: return the workbook path for ``files``."""
    return compute_cosine_similarity(files, method="SBERT")
def legal_bert_handler(files):
    """Click handler for the LEGAL-BERT button: return the workbook path for ``files``."""
    return compute_cosine_similarity(files, method="Legal-BERT")
# Gradio UI: file upload on the left, one button per similarity method on the
# right; each button writes an Excel workbook and exposes it for download.
with gr.Blocks() as interface:
    # Title and short usage explanation at the top of the page.
    gr.Markdown("# Mandate Overlaps")
    gr.Markdown("""
Upload multiple `.txt` files on the left and choose a method from TF-IDF, BERT, SBERT, or LEGAL-BERT on the right to compute and download an Excel file with both the cosine similarity matrix and pairwise similarity scores for all pairs.
""")
    with gr.Row():  # Use a row to align content side-by-side
        with gr.Column(scale=1):  # Left column (for file input)
            file_input = gr.File(file_types=["text"], file_count="multiple", label="Upload .txt files")
        with gr.Column(scale=1):  # Right column (for buttons and output)
            tfidf_button = gr.Button("TF-IDF")
            bert_button = gr.Button("BERT")
            sbert_button = gr.Button("SBERT")
            legal_bert_button = gr.Button("LEGAL-BERT")
            output_file = gr.File(label="Download Excel")

    # Every button reads the same uploads and writes to the same download slot.
    tfidf_button.click(tfidf_handler, inputs=file_input, outputs=output_file)
    bert_button.click(bert_handler, inputs=file_input, outputs=output_file)
    sbert_button.click(sbert_handler, inputs=file_input, outputs=output_file)
    legal_bert_button.click(legal_bert_handler, inputs=file_input, outputs=output_file)

interface.launch(share=True)