"""Streamlit app: multi-label intent classification with dejanseo/Intent-XS."""
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Hugging Face model hub identifier of the classifier.
model_name = "dejanseo/Intent-XS"


@st.cache_resource
def _load_model(name):
    """Load and return ``(tokenizer, model)`` for ``name``, cached per process.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the tokenizer and model would be re-instantiated on
    each rerun.
    """
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForSequenceClassification.from_pretrained(name)
    mdl.eval()  # Set the model to evaluation mode
    return tok, mdl


# Load tokenizer and model from Hugging Face model hub
tokenizer, model = _load_model(model_name)

# Human-readable labels, keyed by 1-based output index (output i -> key i + 1).
label_map = {
    1: 'Commercial',
    2: 'Non-Commercial',
    3: 'Branded',
    4: 'Non-Branded',
    5: 'Informational',
    6: 'Navigational',
    7: 'Transactional',
    8: 'Commercial Investigation',
    9: 'Local',
    10: 'Entertainment',
}


def get_predictions(text, threshold=0.5):
    """Classify ``text`` and return per-label probabilities and 0/1 predictions.

    Args:
        text: The query string to classify.
        threshold: A label is predicted when its sigmoid probability is
            strictly greater than this value (default 0.5).

    Returns:
        ``(probabilities, predictions)`` as NumPy arrays with one entry per
        model output label; ``predictions`` holds 1 for predicted labels and
        0 otherwise.
    """
    # A single query needs no padding; padding to a fixed 512 tokens would
    # only add compute. Truncation still caps over-long input at 512 tokens.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Multi-label: an independent sigmoid per label. Take the single batch row
    # explicitly; squeeze() would also drop a size-1 label axis.
    probabilities = torch.sigmoid(logits)[0]
    predictions = (probabilities > threshold).int()
    return probabilities.numpy(), predictions.numpy()


# Streamlit user interface
st.title('Multi-label Classification with Intent-XS')

query = st.text_input("Enter your query:")

if st.button('Submit'):
    # Whitespace-only input is treated the same as empty input.
    if query.strip():
        probabilities, predictions = get_predictions(query)
        # Fall back to a generic name if the model emits more labels than
        # label_map covers, instead of raising KeyError.
        result = {
            label_map.get(i + 1, f"Label {i + 1}"): f"Probability: {prob:.2f}"
            for i, (prob, pred) in enumerate(zip(probabilities, predictions))
            if pred == 1
        }
        if result:
            st.write("Predicted Categories:")
            for label, prob in result.items():
                st.write(f"{label}: {prob}")
        else:
            st.write("No relevant categories predicted.")
    else:
        st.write("Please enter a query to get predictions.")