|
import functools
import os
import shutil
import threading
import time
import uuid
import zipfile

import gradio
import numpy as np
import openai
import torch
import torch.nn.functional as F
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance, OpenAI
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
from transformers import AutoTokenizer, AutoModel
from umap import UMAP
|
|
|
|
|
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring padded positions.

    Args:
        model_output: Output of a transformer forward pass; element 0 is taken as
            the per-token hidden states (presumably ``(batch, seq, hidden)`` —
            the mask broadcast below requires that layout).
        attention_mask: ``(batch, seq)`` mask with 1 for real tokens, 0 for padding.

    Returns:
        A ``(batch, hidden)`` float tensor of masked means. Sequences whose mask
        is all zeros yield zero vectors (the divisor is clamped to 1e-9).
    """
    hidden_states = model_output[0]
    # Broadcast the mask over the hidden dimension so it can weight every feature.
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    summed = (hidden_states * weights).sum(dim=1)
    token_counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / token_counts
|
|
|
# Hugging Face id of the sentence encoder. It is used both for the document
# embeddings computed here and as BERTopic's internal embedding backend.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# Prompt for the GPT topic-labelling representation; BERTopic substitutes the
# [DOCUMENTS] and [KEYWORDS] placeholders.
_LABEL_PROMPT = """
I have a topic that contains the following documents:
[DOCUMENTS]
The topic is described by the following keywords: [KEYWORDS]

Based on the information above, extract a short but highly descriptive topic label of at most 5 words. Make sure it is in the following format:
topic: <topic label>
"""


@functools.lru_cache(maxsize=1)
def _load_embedding_model():
    """Load the tokenizer and transformer once per process and reuse them across calls."""
    tokenizer = AutoTokenizer.from_pretrained(_EMBEDDING_MODEL_NAME)
    model = AutoModel.from_pretrained(_EMBEDDING_MODEL_NAME)
    return tokenizer, model


def _embed_sentences(sentences):
    """Return L2-normalised, mean-pooled sentence embeddings as a numpy array (one row per sentence)."""
    tokenizer, model = _load_embedding_model()
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        model_output = model(**encoded_input)
    embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
    embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu().numpy()


def _build_topic_model(n_neighbors, n_components, min_cluster_size):
    """Assemble an unfitted BERTopic pipeline with KeyBERT, GPT and MMR representations."""
    umap_model = UMAP(
        n_neighbors=n_neighbors,
        n_components=n_components,
        min_dist=0.0,
        metric="cosine",
        random_state=42,
    )
    vectorizer_model = CountVectorizer(stop_words="english", min_df=2, ngram_range=(1, 2))
    hdbscan_model = HDBSCAN(
        min_cluster_size=min_cluster_size,
        metric="euclidean",
        cluster_selection_method="eom",
        prediction_data=True,
    )
    client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
    representation_model = {
        "KeyBERT": KeyBERTInspired(),
        "OpenAI": OpenAI(
            client,
            model="gpt-3.5-turbo",
            exponential_backoff=True,
            chat=True,
            prompt=_LABEL_PROMPT,
        ),
        "MMR": MaximalMarginalRelevance(diversity=0.3),
    }
    return BERTopic(
        # Documents are embedded up front and passed to fit_transform. This
        # setting only selects the backend BERTopic uses internally (e.g. for
        # KeyBERT keyword embeddings). A raw transformers AutoModel is not a
        # supported backend, so pass the same model by name to keep both
        # embedding spaces identical.
        embedding_model=_EMBEDDING_MODEL_NAME,
        umap_model=umap_model,
        hdbscan_model=hdbscan_model,
        vectorizer_model=vectorizer_model,
        representation_model=representation_model,
        low_memory=True,
        top_n_words=10,
        verbose=True,
    )


def _apply_openai_labels(topic_model):
    """Install the GPT ("OpenAI") aspect as each topic's custom label and return the label map."""
    labels = {
        # Keys must stay integer topic ids: set_topic_labels looks labels up by
        # int id, so string keys would be silently ignored. Empty padding
        # entries in the aspect are dropped so labels don't end in "|  | ".
        topic: " | ".join(label for label, _ in values if label)
        for topic, values in topic_model.topic_aspects_["OpenAI"].items()
    }
    labels[-1] = "Outlier Topic"
    topic_model.set_topic_labels(labels)
    return labels


def BETA_run_topic_classification(list_of_sentences: list, n_neighbors=10, n_components=3, min_cluster_size=5):
    """Cluster sentences into topics with BERTopic and label the topics via GPT.

    Args:
        list_of_sentences: Sentences (documents) to cluster.
        n_neighbors: UMAP neighbourhood size.
        n_components: UMAP output dimensionality.
        min_cluster_size: Smallest HDBSCAN cluster that counts as a topic.

    Returns:
        The topic id assigned to each sentence (``-1`` marks outliers). On any
        failure the error is printed and ``(str(error), None)`` is returned
        instead; that legacy shape is kept for existing callers.
    """
    try:
        sentence_embeddings = _embed_sentences(list_of_sentences)
        print(f"Embeddings shape: {sentence_embeddings.shape}")

        topic_model = _build_topic_model(n_neighbors, n_components, min_cluster_size)
        # Fit exactly once. A second fit would re-query OpenAI and could
        # renumber topics, invalidating the labels installed next.
        topics, probs = topic_model.fit_transform(list_of_sentences, sentence_embeddings)
        labels = _apply_openai_labels(topic_model)

        print(f"Topic labels: {labels}")
        print(f"Topics: {topics}")
        print(f"Probs: {probs}")
        return topics

    except Exception as e:
        print(f"An error occurred: {e}")
        # NOTE(review): this tuple differs from the success return type (a list
        # of topic ids); kept unchanged for backward compatibility.
        return str(e), None
|
|
|
def my_inference_function(sentences):
    """Gradio handler: split comma-separated text into sentences and topic-model them.

    Args:
        sentences: Raw text from the Gradio textbox; sentences are separated by commas.

    Returns:
        A human-readable message when the input is not a string or contains no
        non-blank sentence. Otherwise, whatever ``BETA_run_topic_classification``
        returns for the cleaned sentence list.
    """
    if not isinstance(sentences, str):
        return "Input should be a string of sentences separated by commas."

    # Trim each comma-separated chunk and discard the blank ones.
    stripped_chunks = (chunk.strip() for chunk in sentences.split(","))
    cleaned = [chunk for chunk in stripped_chunks if chunk]

    if not cleaned:
        return "No valid sentences provided."

    return BETA_run_topic_classification(cleaned)
|
|
|
# Text-in / text-out web UI around the comma-separated topic classifier.
gradio_interface = gradio.Interface(
    fn=my_inference_function,
    inputs="text",
    outputs="text",
)

# Only start the (blocking) web server when run as a script, so importing this
# module (tests, tooling) doesn't launch the app.
if __name__ == "__main__":
    gradio_interface.launch()