File size: 3,174 Bytes
283d652
 
f18dc10
283d652
 
 
f18dc10
283d652
 
 
f18dc10
283d652
 
f18dc10
283d652
f18dc10
283d652
 
 
 
 
f18dc10
283d652
 
 
 
 
 
 
f18dc10
283d652
 
 
 
 
 
 
 
f18dc10
283d652
 
f18dc10
283d652
 
 
 
 
 
e35a63c
 
283d652
e35a63c
 
283d652
e35a63c
 
 
 
 
 
283d652
e35a63c
 
 
 
283d652
e35a63c
 
 
 
f18dc10
e35a63c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# import streamlit as st
# import transformers

# # Load the pre-trained language model
# model_name = "bert-base-uncased"
# model = transformers.pipeline("text-classification", model=model_name)

# # Streamlit App
# def main():
#     st.title("Sentence Category Classifier")
    
#     # Input search sentence
#     search_query = st.text_input("Enter a sentence:")
    
#     result = ""
    
#     # Process the search sentence when the user clicks the Search button
#     if st.button("Search"):
#         if search_query:
#             # Classify the sentence using the pre-trained model
#             categories = classify_sentence(search_query)
            
#             # Display the categories as output
#             if categories:
#                 result = f"The sentence belongs to the following categories:\n\n"
#                 for category in categories:
#                     result += f"• {category}\n"
#             else:
#                 result = "No categories found for the sentence."
    
#     # Display the result
#     st.text(result)

# # Function to classify the sentence using the pre-trained language model
# @st.cache(allow_output_mutation=True)
# def classify_sentence(query):
#     # Classify the sentence using the pre-trained model
#     categories = model(query)
    
#     # Extract the category labels from the model's output
#     category_labels = [category['label'] for category in categories]
    
#     return category_labels

# if __name__ == "__main__":
#     main()

import streamlit as st
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Load pre-trained Sentence Transformer model.
# Streamlit re-executes this whole script on every widget interaction, so the
# loader is wrapped in st.cache_resource: the model is built once per server
# process instead of reloading the weights on every click.
# NOTE(review): sentence-transformers lists 'bert-base-nli-mean-tokens' among
# its deprecated checkpoints (low-quality embeddings); consider a newer model
# such as 'all-MiniLM-L6-v2' — confirm before switching.
@st.cache_resource
def _load_model(model_name):
    """Return a SentenceTransformer for *model_name*, cached across reruns."""
    return SentenceTransformer(model_name)


model = _load_model('bert-base-nli-mean-tokens')

# Define your dataset here (example categories).
# Maps a category name to the keywords that describe it; user queries are
# scored against these keyword lists.
categories = {
    'sports': ['football', 'basketball', 'tennis'],
    'politics': ['election', 'government', 'policy'],
    'technology': ['AI', 'machine learning', 'data science']
}

# Function to get relevant categories based on user query
def get_relevant_categories(query, *, threshold=0.0):
    """Return category names ranked by semantic similarity to *query*.

    A category's score is the sum of the cosine similarities between the
    query embedding and the embedding of every keyword in that category.
    Categories with no keywords are skipped.

    Args:
        query: The sentence to categorize.
        threshold: A category is returned only if its score is strictly
            greater than this value. Defaults to 0.0, the original cut-off.

    Returns:
        A list of names from the module-level ``categories`` dict, ordered
        from highest to lowest score.
    """
    query_embedding = model.encode([query])  # 2-D: one row for the query
    category_scores = {}

    for category, keywords in categories.items():
        if not keywords:
            # Nothing to compare against; an empty keyword batch would make
            # cosine_similarity fail.
            continue
        keyword_embeddings = model.encode(keywords)
        # Result has shape (1, len(keywords)). Sum every entry so all
        # keywords contribute. Builtin sum() here would iterate over the
        # single row, and indexing [0] would then keep only the first
        # keyword's similarity.
        similarity_scores = cosine_similarity(query_embedding, keyword_embeddings)
        category_scores[category] = float(similarity_scores.sum())

    ranked = sorted(category_scores.items(), key=lambda item: item[1], reverse=True)
    return [category for category, score in ranked if score > threshold]

# Streamlit app layout and UI
def main():
    """Render the categorization UI and list the categories for the input."""
    st.title("Sentence Categorization App")
    # Use the prompt as the widget's label. Streamlit warns on empty labels
    # for accessibility reasons, and this keeps the same on-screen text.
    user_input = st.text_input("Enter a sentence to categorize:")

    if not st.button('Categorize'):
        return

    sentence = user_input.strip()
    if not sentence:
        # Whitespace-only input is treated the same as no input.
        st.write("Please enter a sentence for categorization.")
        return

    relevant_categories = get_relevant_categories(sentence)
    st.write("Relevant Categories:")
    if not relevant_categories:
        st.write("No relevant categories found.")
        return
    for category in relevant_categories:
        st.write(f"- {category}")


if __name__ == "__main__":
    main()