# Hebrew profanity detector — Gradio demo app for the LikoKIko/OpenCensor-Hebrew model.
import re
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Configuration constants (fixed for the lifetime of the process).
KModelId = "LikoKIko/OpenCensor-Hebrew"  # Hugging Face Hub model id
KMaxLength = 256   # max number of tokens fed to the model per input
KThreshold = 0.50  # sigmoid scores >= this are labelled profane (1)
KDevice = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present

# Download (or load from local cache) the tokenizer and the classifier.
tok = AutoTokenizer.from_pretrained(KModelId)  # maps text -> token ids
# num_labels=1 => single-logit binary classifier; eval() disables dropout etc.
model = AutoModelForSequenceClassification.from_pretrained(
    KModelId, num_labels=1).to(KDevice).eval()


def clean(s):
    """Collapse runs of whitespace in *s* to single spaces and trim the ends."""
    return re.sub(r"\s+", " ", str(s)).strip()
@torch.inference_mode()  # skip gradient tracking: faster inference, less memory
def check(txt: str) -> str:
    """Classify *txt* and return "Prob:<p>|Label:<0|1>", or a prompt when empty.

    Label is 1 (profane) when the model's sigmoid score reaches KThreshold.
    """
    txt = clean(txt)  # normalize whitespace before tokenizing
    if not txt:
        return "Type something first."  # nothing to classify
    # Tokenize, truncating to KMaxLength tokens, and move to the model's device.
    batch = tok(txt, return_tensors="pt", truncation=True,
                padding=True, max_length=KMaxLength).to(KDevice)
    prob = torch.sigmoid(model(**batch).logits).item()  # single logit -> probability
    label = 1 if prob >= KThreshold else 0  # threshold decision
    return f"Prob:{prob:.4f}|Label:{label}"
# Build the Gradio web interface.
with gr.Blocks(title="Hebrew Profanity Detector") as demo:
    inp = gr.Textbox(lines=4, label="Hebrew text")  # user input box
    out = gr.Textbox(label="Result")                # classifier output box
    # Wire the button to check(); expose it as the /predict API endpoint.
    gr.Button("Check").click(check, inp, out, api_name="/predict")
    # Clickable example inputs (not cached: the model runs on every click).
    # NOTE(review): the example strings look mojibake-encoded — verify the
    # original Hebrew text against the upstream source.
    gr.Examples([["ืื ืืืจ ืืฆืืื"], ["ืืฉ ืื ืืจื ืืืจ"]],
                inputs=inp, outputs=out, fn=check, cache_examples=False)
# Script entry point: serve the UI on all interfaces, port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)