# Jerry0203's picture
# Update app.py
# 7bd2349
import gradio as gr
from tqdm.autonotebook import tqdm
import ast
import nltk
from sentence_transformers import SentenceTransformer, util
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image
import glob
import torch
import pickle
import zipfile
from scipy.sparse.csgraph import connected_components
from scipy.special import softmax
import logging
import re
logger = logging.getLogger(__name__)
def degree_centrality_scores(
    similarity_matrix,
    threshold=None,
    increase_power=True,
):
    """Score every node of a similarity graph by its centrality.

    Args:
        similarity_matrix: Square matrix of pairwise similarities.
        threshold: ``None`` to build the Markov matrix from the raw
            similarities, or a float in ``[0, 1)`` to first binarize the
            matrix (entries ``>= threshold`` become 1, the rest 0).
        increase_power: Forwarded unchanged to ``stationary_distribution``.

    Returns:
        The stationary distribution of the resulting Markov matrix,
        requested with ``normalized=False``.

    Raises:
        ValueError: If ``threshold`` is neither ``None`` nor a float in
            ``[0, 1)`` (note that integers such as ``0`` are rejected).
    """
    if threshold is None:
        markov_matrix = create_markov_matrix(similarity_matrix)
    elif isinstance(threshold, float) and 0 <= threshold < 1:
        markov_matrix = create_markov_matrix_discrete(
            similarity_matrix,
            threshold,
        )
    else:
        raise ValueError(
            '\'threshold\' should be a floating-point number '
            'from the interval [0, 1) or None',
        )

    return stationary_distribution(
        markov_matrix,
        increase_power=increase_power,
        normalized=False,
    )
def _power_method(transition_matrix, increase_power=True, max_iter=10000):
    """Iterate a transition matrix on the all-ones vector until it settles.

    Starting from a vector of ones, the vector is repeatedly multiplied by
    the transposed transition matrix. Iteration stops as soon as two
    successive vectors agree according to ``np.allclose``.

    Args:
        transition_matrix: Square 2-D array; ``len()`` of it gives the
            number of nodes.
        increase_power: When true, the working matrix is squared after every
            step that has not yet converged.
        max_iter: Upper bound on the number of multiplication steps.

    Returns:
        A 1-D array with one entry per node. A single-node matrix yields
        ``[1.0]`` immediately. If ``max_iter`` steps pass without
        convergence (including ``max_iter <= 0``), a warning is logged and
        the most recent vector is returned.
    """
    eigenvector = np.ones(len(transition_matrix))

    if len(eigenvector) == 1:
        return eigenvector

    transition = transition_matrix.transpose()

    for _ in range(max_iter):
        eigenvector_next = np.dot(transition, eigenvector)

        if np.allclose(eigenvector_next, eigenvector):
            return eigenvector_next

        eigenvector = eigenvector_next

        if increase_power:
            transition = np.dot(transition, transition)

    logger.warning("Maximum number of iterations for power method exceeded without convergence!")
    # `eigenvector` is the latest iterate here; unlike `eigenvector_next` it is
    # also defined when the loop body never ran (max_iter <= 0).
    return eigenvector
def connected_nodes(matrix):
    """Partition node indices by connected component.

    Args:
        matrix: Graph matrix passed straight to
            ``scipy.sparse.csgraph.connected_components``.

    Returns:
        A list of 1-D index arrays, one per component, in ascending order
        of the component label.
    """
    component_of = connected_components(matrix)[1]
    return [
        np.flatnonzero(component_of == label)
        for label in np.unique(component_of)
    ]
def create_markov_matrix(weights_matrix):
    """Turn a square weight matrix into a row-stochastic matrix.

    Args:
        weights_matrix: Square 2-D array of edge weights.

    Returns:
        An array of the same shape whose rows each sum to 1.

    Raises:
        ValueError: If ``weights_matrix`` is not square.
    """
    rows, cols = weights_matrix.shape
    if rows != cols:
        raise ValueError('\'weights_matrix\' should be square')

    # Any zero or negative weight switches to a row-wise softmax, since
    # plain division by the row sum is only used for strictly positive weights.
    if np.min(weights_matrix) <= 0:
        return softmax(weights_matrix, axis=1)

    return weights_matrix / weights_matrix.sum(axis=1, keepdims=True)
def create_markov_matrix_discrete(weights_matrix, threshold):
    """Binarize a weight matrix at ``threshold`` and make it row-stochastic.

    Args:
        weights_matrix: Array of edge weights.
        threshold: Entries ``>= threshold`` become 1.0, all others 0.0.

    Returns:
        The result of ``create_markov_matrix`` applied to the 0/1 matrix.
    """
    adjacency = np.where(weights_matrix >= threshold, 1.0, 0.0)
    return create_markov_matrix(adjacency)
def stationary_distribution(
    transition_matrix,
    increase_power=True,
    normalized=True,
):
    """Compute per-node scores, one connected component at a time.

    The nodes are split into connected components; ``_power_method`` is run
    on each component's sub-matrix and its result is written back to the
    positions of that component's nodes.

    Args:
        transition_matrix: Square 2-D transition matrix.
        increase_power: Forwarded to ``_power_method``.
        normalized: When true, every score is divided by the number of
            nodes before returning.

    Returns:
        A 1-D array with one score per node.

    Raises:
        ValueError: If ``transition_matrix`` is not square.
    """
    size, width = transition_matrix.shape
    if size != width:
        raise ValueError('\'transition_matrix\' should be square')

    scores = np.zeros(size)
    for members in connected_nodes(transition_matrix):
        sub_matrix = transition_matrix[np.ix_(members, members)]
        scores[members] = _power_method(sub_matrix, increase_power=increase_power)

    return scores / size if normalized else scores
def cut_sent(para):
    """Split a paragraph into sentences at sentence-ending punctuation.

    A break is inserted after:
      * a full-width full stop, exclamation mark or question mark
        (U+3002, U+FF01, U+FF1F) or an ASCII ``?``;
      * six ASCII dots (``......``);
      * two horizontal-ellipsis characters (U+2026 twice);
    unless the next character is a closing curly quote (U+201D / U+2019).
    When a terminator is directly followed by such a closing quote, the
    break goes after the quote instead, provided the character after the
    quote is not itself a comma or terminator.

    The punctuation is written with explicit ``\\uXXXX`` escapes so the
    patterns cannot be corrupted by a wrong file encoding.

    Args:
        para: The text to split.

    Returns:
        A list of sentence strings. Trailing whitespace of the paragraph is
        removed; an empty paragraph yields ``['']``.
    """
    # Single-character terminator not followed by a closing quote.
    para = re.sub(r'([\u3002\uff01\uff1f?])([^\u201d\u2019])', r'\1\n\2', para)
    # ASCII ellipsis written as six dots.
    para = re.sub(r'(\.{6})([^\u201d\u2019])', r'\1\n\2', para)
    # CJK ellipsis written as two U+2026 characters.
    para = re.sub(r'(\u2026{2})([^\u201d\u2019])', r'\1\n\2', para)
    # Terminator + closing quote: break after the quote.
    para = re.sub(
        r'([\u3002\uff01\uff1f?][\u201d\u2019])([^\uff0c\u3002\uff01\uff1f?])',
        r'\1\n\2',
        para,
    )
    para = para.rstrip()
    return para.split("\n")
def embed(document):
    """Return the single most central sentence of ``document``.

    The document is split with ``cut_sent``, every sentence is encoded with
    the multilingual MiniLM sentence-transformer, and the sentence with the
    highest centrality score over the cosine-similarity matrix is returned.

    Note: the model is instantiated on every call.

    Args:
        document: The text to summarize.

    Returns:
        The highest-scoring sentence, as produced by ``cut_sent``.
    """
    encoder = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
    candidates = cut_sent(document)
    vectors = encoder.encode(candidates, convert_to_tensor=True)

    # Pairwise cosine similarity between all sentences, as a NumPy array.
    similarity = util.pytorch_cos_sim(vectors, vectors).cpu().numpy()

    # One centrality score per sentence.
    scores = degree_centrality_scores(similarity, threshold=None)

    # Sorting the negated scores puts the best sentence first.
    ranking = np.argsort(-scores)
    return candidates[ranking[0]]
def search(query):
    """Return the image that best matches a text query.

    On first use the Unsplash photo archive is downloaded and extracted into
    ``photos/``; the precomputed image embeddings are downloaded if missing.
    The query is encoded with the multilingual CLIP text model and compared
    against the image embeddings.

    Note: the model and the embedding file are loaded on every call.

    Args:
        query: Free-text search query.

    Returns:
        The top hit's image as read by ``plt.imread``, or ``None`` when the
        search yields no hit.
    """
    encoder = SentenceTransformer('clip-ViT-B-32-multilingual-v1')

    img_folder = 'photos/'
    # Fetch and unpack the photos only when the folder is missing or empty.
    if not (os.path.exists(img_folder) and os.listdir(img_folder)):
        os.makedirs(img_folder, exist_ok=True)

        photo_filename = 'unsplash-25k-photos.zip'
        if not os.path.exists(photo_filename):  # Download dataset if does not exist
            util.http_get('http://sbert.net/datasets/' + photo_filename, photo_filename)

        # Extract all images
        with zipfile.ZipFile(photo_filename, 'r') as archive:
            for entry in tqdm(archive.infolist(), desc='Extracting'):
                archive.extract(entry, img_folder)

    emb_filename = 'unsplash-25k-photos-embeddings.pkl'
    if not os.path.exists(emb_filename):  # Download dataset if does not exist
        util.http_get('http://sbert.net/datasets/' + emb_filename, emb_filename)

    # NOTE(review): this unpickles a file fetched over plain HTTP; unpickling
    # can execute arbitrary code if the download is tampered with.
    with open(emb_filename, 'rb') as emb_file:
        img_names, img_emb = pickle.load(emb_file)

    query_emb = encoder.encode([query], convert_to_tensor=True, show_progress_bar=False)
    hits = util.semantic_search(query_emb, img_emb, top_k=1)[0]

    best = next(iter(hits), None)
    if best is None:
        return None
    return plt.imread(os.path.join(img_folder, img_names[best['corpus_id']]))
def sentence_embedding(sentences):
    """Encode ``sentences`` with the multilingual MiniLM sentence-transformer.

    Note: the model is instantiated on every call.

    Args:
        sentences: Input passed directly to ``SentenceTransformer.encode``.

    Returns:
        Whatever ``encode`` returns for the given input.
    """
    encoder = SentenceTransformer(
        'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
    )
    return encoder.encode(sentences)
def sentence_sim(sentence1, sentence2):
    """Return the cosine similarity between two sentences.

    Both inputs are encoded with the multilingual MiniLM
    sentence-transformer (instantiated on every call).

    Args:
        sentence1: First sentence.
        sentence2: Second sentence.

    Returns:
        The ``[0][0]`` entry of the cosine-similarity matrix of the two
        embeddings.
    """
    encoder = SentenceTransformer(
        'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
    )
    first_vec = encoder.encode(sentence1)
    second_vec = encoder.encode(sentence2)
    similarity = util.pytorch_cos_sim(first_vec, second_vec)
    return similarity.cpu().numpy()[0][0]
# Gradio UI: four tabs, each with its own inputs, output and trigger button.
# Components appear in the order they are created, so statement order matters.
with gr.Blocks() as demo:
    # Tab 1: extractive summarization (handled by `embed`).
    with gr.Tab("Text Summarization"):
        gr.Markdown("Give a long document, find the sentence that give a good and short summary of the content.")
        text_input = gr.Textbox(label="document")
        # NOTE(review): "summatization" looks like a typo for "summarization";
        # left unchanged because the label is user-visible runtime text.
        text_output = gr.Textbox(label="summatization")
        text_button = gr.Button("Summarize")
    # Tab 2: text-to-image search (handled by `search`).
    with gr.Tab("Image Search"):
        gr.Markdown("Image search given a user query.")
        # Query box and result image share one row.
        with gr.Row():
            image_input = gr.Textbox(label="query")
            image_output = gr.Image(label="image")
        image_button = gr.Button("Search")
    # Tab 3: raw sentence embedding (handled by `sentence_embedding`).
    with gr.Tab("Sentence Embedding"):
        gr.Markdown("Embed the given sentence.")
        embed_input = gr.Textbox(label="sentence")
        embed_output = gr.Textbox(label="embedding")
        embed_button = gr.Button("Embed")
    # Tab 4: cosine similarity of two sentences (handled by `sentence_sim`).
    with gr.Tab("Sentence Similarity"):
        gr.Markdown("Calculate the similarity of two sentences.")
        sim_input1 = gr.Textbox(label="sentence_1")
        sim_input2 = gr.Textbox(label="sentence_2")
        sim_output = gr.Textbox(label="similarity")
        sim_button = gr.Button("Calculate")
    # Wire each button to its handler function.
    text_button.click(embed, inputs=text_input, outputs=text_output)
    image_button.click(search, inputs=image_input, outputs=image_output)
    embed_button.click(sentence_embedding, inputs=embed_input, outputs=embed_output)
    sim_button.click(sentence_sim, inputs=[sim_input1, sim_input2], outputs=sim_output)
# Start the web server as soon as the module is executed.
demo.launch()