# Hugging Face Spaces page status at time of capture: Runtime error
import re

import neattext.functions as nfx
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Class names predicted by the model. main() indexes this list with the
# argmax of the model's output, so the order must match the model's
# classification head: index 0 -> bug, 1 -> enhancement, 2 -> question.
labels = ['bug', 'enhancement', 'question']
# Model path | |
# LOCAL | |
# MODEL_DIR = "./model/distil-bert-uncased-finetuned-github-issues/" | |
# REMOTE | |
MODEL_DIR = "ivanlau/distil-bert-uncased-finetuned-github-issues" | |
# Requires Streamlit >= 1.18 (the release that introduced st.cache_resource).
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the fine-tuned sequence classifier and its tokenizer.

    Streamlit re-executes the whole script on every widget interaction, so
    an uncached loader would rebuild the model from ``MODEL_DIR`` on every
    button press. ``st.cache_resource`` keeps the loaded objects across
    reruns. The spinner is disabled here because the caller already shows
    its own.

    Returns:
        tuple: ``(model, tokenizer)`` loaded from ``MODEL_DIR``.
    """
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    return model, tokenizer
# Helpers

# A run of one or more characters outside 7-bit ASCII (U+0000..U+007F).
reg_obj = re.compile(r'[^\u0000-\u007F]+', re.UNICODE)


def is_english_text(text):
    """Heuristically decide whether ``text`` is English.

    Returns False when the string begins with a non-ASCII character and
    True otherwise (including for the empty string).

    NOTE(review): ``re.match`` anchors at position 0, so only the start of
    the string is inspected -- "héllo" is accepted while "été" is rejected.
    Confirm whether a whole-string check (``re.search``) was intended.
    """
    leading_non_ascii = reg_obj.match(text)
    return leading_non_ascii is None
def neatify_text(text):
    """Normalise free text before it is fed to the classifier.

    ``text`` is coerced with ``str()`` and lower-cased. The result is then
    passed through neattext's stopword remover followed by its emoji
    remover, in that order.

    Args:
        text: Object to clean; anything accepted by ``str()``.

    Returns:
        The cleaned text as returned by the final neattext step.
    """
    cleaned = str(text).lower()
    for cleaning_step in (nfx.remove_stopwords, nfx.remove_emojis):
        cleaned = cleaning_step(cleaned)
    return cleaned
def _classify(text, model, tokenizer):
    """Run the classifier on ``text`` and return per-class probabilities.

    Args:
        text: Pre-processed input string.
        model: Sequence-classification model, as returned by ``load_model``.
        tokenizer: The tokenizer paired with ``model``.

    Returns:
        list[float]: Softmax probabilities for the single input, one entry
        per output logit (index-aligned with ``labels``).
    """
    # truncation=True caps the input at the tokenizer's model_max_length.
    # Without it, a long issue body exceeds the model's maximum sequence
    # length and the forward pass fails instead of predicting.
    encoded = tokenizer(text, return_tensors='pt', truncation=True)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        logits = model(**encoded).logits
    return torch.nn.functional.softmax(logits, dim=-1)[0].tolist()


def main():
    """Render the IntelliLabel page and classify the text the user submits.

    The page shows a text area pre-filled with an example issue title. When
    the Predict button is pressed, the function:

    1. Strips the text and rejects empty input.
    2. Checks it with ``is_english_text``.
    3. Cleans it with ``neatify_text``.
    4. Classifies it and displays one probability metric per class,
       followed by the most likely label.
    """
    # The original source contained mojibake ("π·") here: the UTF-8 bytes
    # of the label emoji decoded as cp1253. The emoji is spelled out as
    # escapes so it cannot be mangled again.
    label_icon = "\U0001F3F7\ufe0f"

    # Streamlit has historically required set_page_config to be the first
    # st.* call of a script run.
    st.set_page_config(
        page_title="IntelliLabel",
        page_icon=label_icon,
        layout="centered",
        initial_sidebar_state="auto",
    )
    st.title("IntelliLabel")
    st.write("IntelliLabel is a github issue classification app. It classifies issue into 3 categories (Bug, Enhancement, Question).")

    with st.spinner("Downloading model (takes ~1 min)"):
        model, tokenizer = load_model()

    default_text = "Unable to run Speech2Text example in documentation"
    text = st.text_area('Enter text here:', value=default_text)
    if not st.button(f'Predict {label_icon}'):
        return

    text = text.strip()
    if not text:
        st.error("Please enter some text.")
        return
    if not is_english_text(text):
        st.error("Please input english text.")
        return

    probabilities = _classify(neatify_text(text), model, tokenizer)
    predicted = labels[probabilities.index(max(probabilities))]

    # One metric column per class, named from `labels` so the display
    # cannot drift out of sync with the model's class order.
    for column, name, prob in zip(st.columns(len(labels)), labels, probabilities):
        column.metric(label=name.capitalize(), value=round(prob, 3))
    st.info("Prediction")
    st.write(predicted.capitalize())
# Entry point: `streamlit run` executes this file with __name__ set to
# "__main__", so the app is launched through this guard.
if __name__ == "__main__":
    main()