|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
|
|
|
|
model_name = "dejanseo/Intent-XS"


@st.cache_resource
def _load_model(name):
    """Load and return ``(tokenizer, model)`` for the checkpoint *name*.

    Cached with ``st.cache_resource`` because Streamlit re-executes this whole
    script on every widget interaction. Without caching, the tokenizer and
    model would be re-instantiated on each rerun.
    """
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForSequenceClassification.from_pretrained(name)
    # Inference only: put the model in evaluation mode (disables dropout etc.).
    mdl.eval()
    return tok, mdl


tokenizer, model = _load_model(model_name)
|
|
|
|
|
# Human-readable intent names keyed by 1-based output position.
# Callers look entries up with the 0-based output index plus one.
# NOTE(review): this ordering presumably mirrors the checkpoint's own
# config.id2label; confirm against the model card.
label_map = dict(
    enumerate(
        [
            'Commercial',
            'Non-Commercial',
            'Branded',
            'Non-Branded',
            'Informational',
            'Navigational',
            'Transactional',
            'Commercial Investigation',
            'Local',
            'Entertainment',
        ],
        start=1,
    )
)
|
|
|
|
|
def get_predictions(text, threshold=0.5):
    """Run the module-level ``model`` on *text* and return multi-label outputs.

    Each logit is passed through an independent sigmoid, so every label is
    scored on its own (multi-label), not via a softmax over all labels.

    Args:
        text: Input string (or list of strings) accepted by ``tokenizer``.
        threshold: A label is predicted when its probability is strictly
            greater than this value. Defaults to 0.5.

    Returns:
        ``(probabilities, predictions)`` as NumPy arrays. For a single string
        each array is 1-D with one entry per model output. ``predictions``
        holds 0/1 integers.
    """
    # padding=True pads only to the longest item in the batch. A single query
    # is therefore not padded out to 512 tokens, while list input still
    # produces a rectangular tensor. Inputs longer than 512 tokens are truncated.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # squeeze(0) removes only the batch dimension of size 1. A bare squeeze()
    # would also collapse a single-label output to a 0-d tensor.
    probabilities = torch.sigmoid(logits).squeeze(0)
    predictions = (probabilities > threshold).int()

    return probabilities.numpy(), predictions.numpy()
|
|
|
|
|
st.title('Multi-label Classification with Intent-XS')

query = st.text_input("Enter your query:")

if st.button('Submit'):
    # Treat a whitespace-only query like an empty one rather than sending it
    # to the model.
    cleaned_query = query.strip()

    if cleaned_query:
        probabilities, predictions = get_predictions(cleaned_query)

        # Keep only the labels the model flagged. label_map is keyed by 1-based
        # position. Fall back to a generic name rather than raising KeyError if
        # the model emits more outputs than label_map covers.
        result = {
            label_map.get(position, f"Label {position}"): f"Probability: {prob:.2f}"
            for position, (prob, flagged) in enumerate(
                zip(probabilities, predictions), start=1
            )
            if flagged == 1
        }

        if result:
            st.write("Predicted Categories:")
            for label, prob in result.items():
                st.write(f"{label}: {prob}")
        else:
            st.write("No relevant categories predicted.")
    else:
        st.write("Please enter a query to get predictions.")
|
|