long1104 commited on
Commit
e35a63c
1 Parent(s): 283d652

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -18
app.py CHANGED
@@ -46,26 +46,46 @@
46
  # main()
47
 
48
  import streamlit as st
 
 
49
 
50
- # Function to categorize input sentences
51
- def categorize_sentence(sentence):
52
- # Replace this function with your own logic to categorize sentences
53
- categories = ['Restaurants', 'Food', 'Travel', 'New York City']
54
- return categories
55
 
56
- # Configure Streamlit layout
57
- st.set_page_config(page_title='Sentence Categorizer', layout='wide')
 
 
 
 
58
 
59
- # Add title and description
60
- st.title('Welcome to Sentence Categorizer')
61
- st.write('Enter a sentence and discover relevant categories!')
 
62
 
63
- # Create input box
64
- sentence = st.text_input('Enter a sentence')
 
 
65
 
66
- # Create button to trigger categorization
67
- if st.button('Categorize'):
68
- st.write('Categories:')
69
- categories = categorize_sentence(sentence)
70
- for category in categories:
71
- st.success(category)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # main()
47
 
48
  import streamlit as st
49
+ from sentence_transformers import SentenceTransformer
50
+ from sklearn.metrics.pairwise import cosine_similarity
51
 
52
+ # Load pre-trained Sentence Transformer model
53
+ model = SentenceTransformer('bert-base-nli-mean-tokens')
 
 
 
54
 
55
+ # Define your dataset here (example categories)
56
+ categories = {
57
+ 'sports': ['football', 'basketball', 'tennis'],
58
+ 'politics': ['election', 'government', 'policy'],
59
+ 'technology': ['AI', 'machine learning', 'data science']
60
+ }
61
 
62
+ # Function to get relevant categories based on user query
63
+ def get_relevant_categories(query):
64
+ query_embedding = model.encode([query])
65
+ category_scores = {}
66
 
67
+ for category, keywords in categories.items():
68
+ keyword_embeddings = model.encode(keywords)
69
+ similarity_scores = cosine_similarity(query_embedding, keyword_embeddings)
70
+ category_scores[category] = sum(similarity_scores)[0]
71
 
72
+ relevant_categories = [category for category, score in sorted(category_scores.items(), key=lambda x: x[1], reverse=True) if score > 0]
73
+ return relevant_categories
74
+
75
+ # Streamlit app layout and UI
76
+ def main():
77
+ st.title("Sentence Categorization App")
78
+ st.write("Enter a sentence to categorize:")
79
+ user_input = st.text_input('', value='', max_chars=None, key=None, type='default')
80
+
81
+ if st.button('Categorize'):
82
+ if user_input:
83
+ relevant_categories = get_relevant_categories(user_input)
84
+ st.write("Relevant Categories:")
85
+ for category in relevant_categories:
86
+ st.write(f"- {category}")
87
+ else:
88
+ st.write("Please enter a sentence for categorization.")
89
+
90
+ if __name__ == "__main__":
91
+ main()