# Hugging Face Spaces app (build status at capture time: Runtime error)
# Standard library
import ast
import glob
import logging
import os
import pickle
import re
import zipfile

# Third-party
import gradio as gr
import matplotlib.pyplot as plt
import nltk
import numpy as np
import torch
from PIL import Image
from scipy.sparse.csgraph import connected_components
from scipy.special import softmax
from sentence_transformers import SentenceTransformer, util
from tqdm.autonotebook import tqdm

# Module-level logger, per standard logging convention.
logger = logging.getLogger(__name__)
def degree_centrality_scores(
    similarity_matrix,
    threshold=None,
    increase_power=True,
):
    """Compute a LexRank-style centrality score for every node of the
    graph described by ``similarity_matrix``.

    threshold: ``None`` for continuous LexRank, or a float in [0, 1)
        used to binarize edges before building the Markov matrix.
    increase_power: forwarded to the power-method solver.

    Raises ValueError when ``threshold`` is neither None nor a float
    in the half-open interval [0, 1).
    """
    threshold_ok = threshold is None or (
        isinstance(threshold, float) and 0 <= threshold < 1
    )
    if not threshold_ok:
        raise ValueError(
            '\'threshold\' should be a floating-point number '
            'from the interval [0, 1) or None',
        )

    if threshold is None:
        markov_matrix = create_markov_matrix(similarity_matrix)
    else:
        markov_matrix = create_markov_matrix_discrete(
            similarity_matrix,
            threshold,
        )

    return stationary_distribution(
        markov_matrix,
        increase_power=increase_power,
        normalized=False,
    )
def _power_method(transition_matrix, increase_power=True, max_iter=10000): | |
eigenvector = np.ones(len(transition_matrix)) | |
if len(eigenvector) == 1: | |
return eigenvector | |
transition = transition_matrix.transpose() | |
for _ in range(max_iter): | |
eigenvector_next = np.dot(transition, eigenvector) | |
if np.allclose(eigenvector_next, eigenvector): | |
return eigenvector_next | |
eigenvector = eigenvector_next | |
if increase_power: | |
transition = np.dot(transition, transition) | |
logger.warning("Maximum number of iterations for power method exceeded without convergence!") | |
return eigenvector_next | |
def connected_nodes(matrix):
    """Return one index array per connected component of ``matrix``,
    grouping nodes that belong to the same component."""
    _, labels = connected_components(matrix)
    return [np.where(labels == label)[0] for label in np.unique(labels)]
def create_markov_matrix(weights_matrix):
    """Row-normalize a square weights matrix into transition probabilities.

    When any entry is <= 0, plain division by the row sum would not yield a
    valid probability distribution, so a row-wise softmax is used instead.

    Raises ValueError when the matrix is not square.
    """
    rows, cols = weights_matrix.shape
    if rows != cols:
        raise ValueError('\'weights_matrix\' should be square')

    if np.min(weights_matrix) <= 0:
        return softmax(weights_matrix, axis=1)

    return weights_matrix / weights_matrix.sum(axis=1, keepdims=True)
def create_markov_matrix_discrete(weights_matrix, threshold):
    """Binarize ``weights_matrix`` at ``threshold`` (edges >= threshold
    become 1, the rest 0), then row-normalize the result."""
    adjacency = np.zeros(weights_matrix.shape)
    adjacency[np.where(weights_matrix >= threshold)] = 1
    return create_markov_matrix(adjacency)
def stationary_distribution(
    transition_matrix,
    increase_power=True,
    normalized=True,
):
    """Compute the stationary distribution of the Markov chain given by
    ``transition_matrix``, solving each connected component separately
    with the power method.

    Raises ValueError when the matrix is not square.
    """
    n_rows, n_cols = transition_matrix.shape
    if n_rows != n_cols:
        raise ValueError('\'transition_matrix\' should be square')

    distribution = np.zeros(n_rows)
    for group in connected_nodes(transition_matrix):
        # Restrict the chain to this component and solve it in isolation.
        sub_matrix = transition_matrix[np.ix_(group, group)]
        distribution[group] = _power_method(
            sub_matrix, increase_power=increase_power
        )

    if normalized:
        # Rescale by the matrix size, as in the original implementation.
        distribution /= n_rows

    return distribution
def cut_sent(para): | |
para = re.sub('([γοΌοΌ\?])([^ββ])', r"\1\n\2", para) | |
para = re.sub('(\.{6})([^ββ])', r"\1\n\2", para) | |
para = re.sub('(\β¦{2})([^ββ])', r"\1\n\2", para) | |
para = re.sub('([γοΌοΌ\?][ββ])([^οΌγοΌοΌ\?])', r'\1\n\2', para) | |
para = para.rstrip() | |
return para.split("\n") | |
def embed(document):
    """Return the single most central sentence of ``document`` — a
    LexRank-style one-sentence extractive summary.

    The original reloaded the SentenceTransformer model on every call;
    here it is loaded lazily once and cached on the function object.
    """
    if not hasattr(embed, "_model"):
        embed._model = SentenceTransformer(
            'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
    model = embed._model

    sentences = cut_sent(document)
    embeddings = model.encode(sentences, convert_to_tensor=True)

    # Pair-wise cosine similarities between all sentences.
    cos_scores = util.pytorch_cos_sim(embeddings, embeddings).cpu().numpy()

    # Centrality of each sentence in the similarity graph.
    centrality_scores = degree_centrality_scores(cos_scores, threshold=None)

    # Argsort on the negated scores puts the highest-scoring sentence first.
    most_central_sentence_indices = np.argsort(-centrality_scores)
    return sentences[most_central_sentence_indices[0]]
def search(query):
    """Return the Unsplash photo best matching ``query`` as an image array
    (or None when no hit is found).

    Downloads the 25k-photo archive and its precomputed embeddings on
    first use. The CLIP text encoder is loaded once and cached instead of
    being reloaded on every request (the original reloaded it per call).
    """
    if not hasattr(search, "_model"):
        search._model = SentenceTransformer('clip-ViT-B-32-multilingual-v1')
    model = search._model

    img_folder = 'photos/'
    if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
        os.makedirs(img_folder, exist_ok=True)
        photo_filename = 'unsplash-25k-photos.zip'
        if not os.path.exists(photo_filename):
            # Download the photo archive only if it is not cached locally.
            util.http_get('http://sbert.net/datasets/' + photo_filename, photo_filename)
        # Extract all images into the photo folder.
        with zipfile.ZipFile(photo_filename, 'r') as zf:
            for member in tqdm(zf.infolist(), desc='Extracting'):
                zf.extract(member, img_folder)

    emb_filename = 'unsplash-25k-photos-embeddings.pkl'
    if not os.path.exists(emb_filename):
        # Download the precomputed image embeddings if not cached.
        util.http_get('http://sbert.net/datasets/' + emb_filename, emb_filename)
    with open(emb_filename, 'rb') as fIn:
        # NOTE(security): unpickling downloaded data can execute arbitrary
        # code; this is tolerated only because the file comes from the
        # trusted sbert.net dataset host.
        img_names, img_emb = pickle.load(fIn)

    query_emb = model.encode([query], convert_to_tensor=True, show_progress_bar=False)
    hits = util.semantic_search(query_emb, img_emb, top_k=1)[0]
    if not hits:
        # Defensive: mirrors the original's implicit None when the hit
        # list was empty (its for-loop body never executed).
        return None
    return plt.imread(os.path.join(img_folder, img_names[hits[0]['corpus_id']]))
def sentence_embedding(sentences):
    """Encode ``sentences`` with the multilingual MiniLM model and return
    the resulting embedding array.

    The model is loaded lazily once and cached on the function object
    instead of being reloaded on every call (the original reloaded it
    per request, which dominated the latency).
    """
    if not hasattr(sentence_embedding, "_model"):
        sentence_embedding._model = SentenceTransformer(
            'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
    return sentence_embedding._model.encode(sentences)
def sentence_sim(sentence1, sentence2):
    """Return the cosine similarity between the embeddings of two sentences.

    Improvements over the original: the model is loaded once and cached
    (not reloaded per call), and both sentences are encoded in a single
    batch — one forward pass instead of two.
    """
    if not hasattr(sentence_sim, "_model"):
        sentence_sim._model = SentenceTransformer(
            'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
    embeddings = sentence_sim._model.encode([sentence1, sentence2])
    cos_scores = util.pytorch_cos_sim(embeddings[0], embeddings[1]).cpu().numpy()
    return cos_scores[0][0]
# --- Gradio UI: one tab per demo task ---------------------------------------
with gr.Blocks() as demo:
    with gr.Tab("Text Summarization"):
        gr.Markdown("Give a long document, find the sentence that give a good and short summary of the content.")
        text_input = gr.Textbox(label="document")
        # Fixed label typo: "summatization" -> "summarization".
        text_output = gr.Textbox(label="summarization")
        text_button = gr.Button("Summarize")
    with gr.Tab("Image Search"):
        gr.Markdown("Image search given a user query.")
        with gr.Row():
            image_input = gr.Textbox(label="query")
            image_output = gr.Image(label="image")
        image_button = gr.Button("Search")
    with gr.Tab("Sentence Embedding"):
        gr.Markdown("Embed the given sentence.")
        embed_input = gr.Textbox(label="sentence")
        embed_output = gr.Textbox(label="embedding")
        embed_button = gr.Button("Embed")
    with gr.Tab("Sentence Similarity"):
        gr.Markdown("Calculate the similarity of two sentences.")
        sim_input1 = gr.Textbox(label="sentence_1")
        sim_input2 = gr.Textbox(label="sentence_2")
        sim_output = gr.Textbox(label="similarity")
        sim_button = gr.Button("Calculate")

    # Wire each tab's button to its handler.
    text_button.click(embed, inputs=text_input, outputs=text_output)
    image_button.click(search, inputs=image_input, outputs=image_output)
    embed_button.click(sentence_embedding, inputs=embed_input, outputs=embed_output)
    sim_button.click(sentence_sim, inputs=[sim_input1, sim_input2], outputs=sim_output)

demo.launch()