mkoot007 commited on
Commit
18792e2
1 Parent(s): c945277

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -14
app.py CHANGED
@@ -1,19 +1,25 @@
1
- import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
4
- model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
5
 
6
- def classify_text(text):
 
 
7
  encoded_text = tokenizer(text, truncation=True, padding='max_length', max_length=512, return_tensors='pt')
8
  predictions = model(**encoded_text)
9
- predicted_label = predictions.logits.argmax(-1).item()
10
- predicted_class = model.config.id2label[predicted_label]
11
-
12
  return predicted_class
13
- interface = gr.Interface(
14
- fn=classify_text,
15
- inputs=[gr.Textbox(label="Input Text")],
16
- outputs=[gr.Textbox(label="Predicted Class")],
17
- title="Text Classification App"
18
- )
19
- interface.launch(share=True)
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
 
3
 
4
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/bert-tiny-finetuned-sms-spam-detection")
5
+ model = AutoModelForSequenceClassification.from_pretrained("mrm8488/bert-tiny-finetuned-sms-spam-detection")
6
+ def classify_spam(text):
7
  encoded_text = tokenizer(text, truncation=True, padding='max_length', max_length=512, return_tensors='pt')
8
  predictions = model(**encoded_text)
9
+ predicted_probabilities = predictions.logits.softmax(dim=1)
10
+ predicted_class = "Spam" if predicted_probabilities[0, 1] > 0.5 else "Not Spam"
 
11
  return predicted_class
12
+ def main():
13
+ st.title("SMS Spam Classification App")
14
+ st.text("Made by Moneeb Ahmad with Lil Love ❤️ ")
15
+ text_input = st.text_area("Enter SMS text for classification:", "")
16
+ if st.button("Classify"):
17
+ if text_input:
18
+ result = classify_spam(text_input)
19
+ st.subheader("Predicted Class:")
20
+ st.write(result)
21
+ else:
22
+ st.warning("Please enter some text for classification.")
23
+
24
+ if __name__ == "__main__":
25
+ main()