# Imports
import pickle  # stdlib pickle reads protocol-5 pickles natively on Python 3.8+, so the pickle5 backport is not needed
import re
import string

import gradio as gr
import nltk
from nltk.corpus import stopwords

# scikit-learn must be installed so the pickled model and vectorizer can be
# unpickled, but its classes do not need to be imported here.

# Fetch the NLTK English stopword list (skipped if already up to date)
nltk.download('stopwords', quiet=True)

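# Load the pickled Logistic Regression model and its fitted vectorizer
MODEL_PATH = 'lr_021823.pkl'
VECTORIZER_PATH = 'vectorizer_021823.pkl'

# Context managers close the files once the objects are loaded
with open(MODEL_PATH, 'rb') as f:
    model_loaded = pickle.load(f)
with open(VECTORIZER_PATH, 'rb') as f:
    vectorizer_loaded = pickle.load(f)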


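# Process input text: convert to lowercase, replace punctuation with spaces,
# remove English stopwords, and collapse repeated whitespace.
# Lowercasing first ensures capitalized stopwords (e.g. "The") are also removed.
# NOTE: this must match the preprocessing applied when the vectorizer was fitted.
stop = set(stopwords.words('english'))  # set gives O(1) membership checks

def process_text(text):
    text = text.lower()
    text = re.sub(f"[{re.escape(string.punctuation)}]", " ", text)
    words = [word for word in text.split() if word not in stop]
    return " ".join(words)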

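# Clean the text and transform it with the already-fitted vectorizer.
# Calling transform() on the loaded vectorizer (rather than fitting a new one)
# keeps the number of features identical to what the model was trained on.
def vectorize_text(text):
    return vectorizer_loaded.transform([process_text(text)])

# Predict the class label for a single input string
def class_predict(text):
    prediction = model_loaded.predict(vectorize_text(text))
    # predict() returns one label per input row; display the single label as text
    return str(prediction[0])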


# Define interface
demo = gr.Interface(
    fn=class_predict,
    title="Text Classification Demo",
    description="This is a demo of a text classification model using Logistic Regression.",
    inputs=gr.Textbox(lines=10, placeholder='Input text here...', label="Input Text"),
    outputs=gr.Textbox(
        label="Predicted Label: Healthcare: 0, Other: 1, Technology: 2",
        lines=2,
        placeholder='Predicted label will appear here...',
    ),
    allow_flagging='never',
)

demo.launch()