"""Streamlit app that labels a message "Spam" or "Not Spam" with a fine-tuned DistilBERT model."""

import os

import streamlit as st
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast

# Set page configuration (Streamlit expects this before other Streamlit output).
st.set_page_config(page_title="Spam Detection", page_icon="📧")

# Build an absolute path so `from_pretrained` treats it as a local directory.
# NOTE: the path is resolved against the current working directory, so the app
# has to be launched from the directory that contains "fine_tuned_model".
model_path = os.path.join(os.getcwd(), "fine_tuned_model")

# Class index that `predict_spam` reports as "Spam"; any other index is "Not Spam".
_SPAM_LABEL_ID = 1


def _identity(func):
    """Return *func* unchanged; stand-in decorator when caching is unavailable."""
    return func


# Streamlit re-executes this script on every widget interaction. Caching the
# loaded objects avoids re-reading the model from disk on each rerun. Streamlit
# releases without `st.cache_resource` fall back to loading on every run.
_cache_resource = getattr(st, "cache_resource", _identity)


@_cache_resource
def _load_model_and_tokenizer(path):
    """Load the fine-tuned classifier and its tokenizer from *path*.

    Args:
        path: Directory holding the saved model and tokenizer files.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = DistilBertForSequenceClassification.from_pretrained(path)
    loaded_tokenizer = DistilBertTokenizerFast.from_pretrained(path)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(model_path)


def predict_spam(text: str) -> str:
    """Classify *text* with the module-level model.

    Args:
        text: The message to classify.

    Returns:
        ``"Spam"`` when the highest-scoring class index is 1, otherwise
        ``"Not Spam"``.
    """
    # Inputs longer than 512 tokens are truncated by the tokenizer.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    prediction = torch.argmax(logits, dim=-1).item()
    return "Spam" if prediction == _SPAM_LABEL_ID else "Not Spam"


def main() -> None:
    """Render the page: title, description, input box and prediction result."""
    st.title("Spam Detection")
    st.write("This is a Spam Detection App using a fine-tuned DistilBERT model.")

    # Input text box
    message = st.text_area("Enter message to classify as spam or not:")

    if not st.button("Predict"):
        return

    # A whitespace-only message carries nothing to classify, so treat it as empty.
    if message.strip():
        prediction = predict_spam(message)
        st.write(f"The message is: **{prediction}**")
    else:
        st.write("Please enter a message to classify.")


if __name__ == "__main__":
    main()