# Intent-XS / example.py
# Source: Hugging Face repository dejanseo/Intent-XS, file example.py
# (commit a24e9d7, "Upload example.py", 1.71 kB).
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Load tokenizer and model from Hugging Face model hub.
model_name = "dejanseo/Intent-XS"


@st.cache_resource
def _load_model(name):
    """Load the tokenizer and sequence-classification model for ``name``.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level without caching would reload the tokenizer and
    model on every click of "Submit". ``st.cache_resource`` keeps one shared
    copy per ``name`` for the lifetime of the Streamlit server process.

    Args:
        name: Hugging Face model-hub identifier to load.

    Returns:
        A ``(tokenizer, model)`` pair; the model has had ``.eval()`` called.
    """
    tok = AutoTokenizer.from_pretrained(name)
    clf = AutoModelForSequenceClassification.from_pretrained(name)
    clf.eval()  # Set the model to evaluation mode
    return tok, clf


# Module-level names kept so the rest of the script can use them unchanged.
tokenizer, model = _load_model(model_name)
# Human-readable labels, keyed 1..10. The UI below looks up model output
# index i under key i + 1, so this list order is assumed to match the model's
# output order -- NOTE(review): confirm against model.config.id2label.
label_map = dict(
    enumerate(
        (
            'Commercial',
            'Non-Commercial',
            'Branded',
            'Non-Branded',
            'Informational',
            'Navigational',
            'Transactional',
            'Commercial Investigation',
            'Local',
            'Entertainment',
        ),
        start=1,
    )
)
# Function to perform inference
def get_predictions(text, threshold=0.5):
    """Run the multi-label intent classifier on ``text``.

    Args:
        text: The query string to classify.
        threshold: Per-label cut-off applied to the sigmoid probabilities; a
            label is predicted when its probability is strictly greater than
            this value. Defaults to 0.5, the previously hard-coded value.

    Returns:
        A ``(probabilities, predictions)`` pair of NumPy arrays with one entry
        per model output: the sigmoid probabilities (float) and the 0/1
        prediction flags (int).
    """
    # padding=True pads only to the longest sequence in the batch (a no-op for
    # a single query) instead of always padding to 512 tokens. The tokenizer's
    # attention_mask is forwarded via **inputs, so padding is ignored anyway;
    # truncation still caps inputs at 512 tokens.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Independent sigmoid per label (multi-label), not a softmax across labels.
    # squeeze(0) drops only the batch dimension, so a single-output model still
    # yields a 1-D array instead of a 0-d scalar that cannot be enumerated.
    probabilities = torch.sigmoid(logits).squeeze(0)
    predictions = (probabilities > threshold).int()
    return probabilities.numpy(), predictions.numpy()
# Streamlit user interface
st.title('Multi-label Classification with Intent-XS')
query = st.text_input("Enter your query:")
if st.button('Submit'):
    # Whitespace-only input is treated like empty input rather than being
    # sent to the model as a real query.
    if query.strip():
        probabilities, predictions = get_predictions(query)
        # Keep only labels whose prediction flag is set. label_map keys are
        # 1-based, so model output index i is looked up as key i + 1.
        result = {
            label_map[i + 1]: f"Probability: {prob:.2f}"
            for i, (prob, flag) in enumerate(zip(probabilities, predictions))
            if flag == 1
        }
        if result:
            st.write("Predicted Categories:")
            for label, prob in result.items():
                st.write(f"{label}: {prob}")
        else:
            st.write("No relevant categories predicted.")
    else:
        st.write("Please enter a query to get predictions.")