# ICE3 / app.py
# Author: long1104
# Last commit: e35a63c ("Update app.py")
# import streamlit as st
# import transformers
# # Load the pre-trained language model
# model_name = "bert-base-uncased"
# model = transformers.pipeline("text-classification", model=model_name)
# # Streamlit App
# def main():
# st.title("Sentence Category Classifier")
# # Input search sentence
# search_query = st.text_input("Enter a sentence:")
# result = ""
# # Process the search sentence when the user clicks the Search button
# if st.button("Search"):
# if search_query:
# # Classify the sentence using the pre-trained model
# categories = classify_sentence(search_query)
# # Display the categories as output
# if categories:
# result = f"The sentence belongs to the following categories:\n\n"
# for category in categories:
# result += f"• {category}\n"
# else:
# result = "No categories found for the sentence."
# # Display the result
# st.text(result)
# # Function to classify the sentence using the pre-trained language model
# @st.cache(allow_output_mutation=True)
# def classify_sentence(query):
# # Classify the sentence using the pre-trained model
# categories = model(query)
# # Extract the category labels from the model's output
# category_labels = [category['label'] for category in categories]
# return category_labels
# if __name__ == "__main__":
# main()
import streamlit as st
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
@st.cache_resource
def _load_model():
    """Load the pre-trained Sentence Transformer once per server process.

    Streamlit re-executes this whole script on every widget interaction; a bare
    module-level constructor call would reload the model weights on each click.
    ``st.cache_resource`` keeps a single shared instance across reruns.
    """
    # NOTE(review): sentence-transformers lists 'bert-base-nli-mean-tokens' as a
    # deprecated model with low-quality embeddings; consider a newer model
    # (e.g. 'all-MiniLM-L6-v2') — this would change the scores, so confirm first.
    return SentenceTransformer('bert-base-nli-mean-tokens')


# Pre-trained Sentence Transformer model used to embed queries and keywords.
model = _load_model()

# Define your dataset here (example categories):
# category name -> representative keywords used to score a query.
categories = {
    'sports': ['football', 'basketball', 'tennis'],
    'politics': ['election', 'government', 'policy'],
    'technology': ['AI', 'machine learning', 'data science']
}
def get_relevant_categories(query, threshold=0.0):
    """Rank the entries of ``categories`` by semantic similarity to *query*.

    A category's score is the sum of the cosine similarities between the
    query embedding and the embedding of each of its keywords.

    Args:
        query: The sentence to categorize.
        threshold: Only categories whose score is strictly greater than this
            value are returned. Defaults to 0.0, the original cut-off.

    Returns:
        A list of category names, highest-scoring first.
    """
    query_embedding = model.encode([query])
    category_scores = {}
    for category, keywords in categories.items():
        if not keywords:
            # cosine_similarity rejects an empty second argument; a category
            # with no keywords cannot be scored, so leave it out.
            continue
        keyword_embeddings = model.encode(keywords)
        # Shape (1, len(keywords)): one row of query-vs-keyword similarities.
        similarity_scores = cosine_similarity(query_embedding, keyword_embeddings)
        # Sum the entire row so every keyword contributes. The builtin
        # ``sum(similarity_scores)[0]`` only yielded the first keyword's score.
        category_scores[category] = float(similarity_scores.sum())
    ranked = sorted(category_scores.items(), key=lambda item: item[1], reverse=True)
    return [category for category, score in ranked if score > threshold]
# Streamlit app layout and UI
def main():
    """Render the categorization UI and list matching categories on demand.

    Streamlit re-executes the script on each interaction. ``st.button`` is
    True only on the rerun triggered by the click, so categorization runs
    only then.
    """
    st.title("Sentence Categorization App")
    st.write("Enter a sentence to categorize:")
    # Streamlit warns on an empty label (accessibility). Give it a real label
    # and hide it, because the prompt is already written just above.
    user_input = st.text_input("Sentence", label_visibility="collapsed")
    if st.button('Categorize'):
        sentence = user_input.strip()
        # Whitespace-only input is treated the same as no input.
        if not sentence:
            st.write("Please enter a sentence for categorization.")
            return
        relevant_categories = get_relevant_categories(sentence)
        st.write("Relevant Categories:")
        if not relevant_categories:
            st.write("No relevant categories found.")
        for category in relevant_categories:
            st.write(f"- {category}")


if __name__ == "__main__":
    main()