Spaces:
Running
Running
File size: 2,061 Bytes
58b232d c4a722e 58b232d c4a722e 58b232d c4a722e 58b232d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import re,torch,gradio as gr
from transformers import AutoTokenizer,AutoModelForSequenceClassification
#configuration: these values are fixed for the lifetime of the process
KModelId = "LikoKIko/OpenCensor-Hebrew" #Hugging Face Hub id of the classifier
KMaxLength = 256 #max number of tokens (not words) fed to the model per pass
KThreshold = 0.50 #a probability >= this value is labelled 1 (contains bad words)
if torch.cuda.is_available():
    KDevice = "cuda" #a CUDA GPU is available
else:
    KDevice = "cpu" #no GPU: run on the CPU
#fetch (or reuse the locally cached copy of) the tokenizer and model weights
tok = AutoTokenizer.from_pretrained(KModelId) #maps text to token ids
model = AutoModelForSequenceClassification.from_pretrained(KModelId,num_labels = 1) #one output logit
model = model.to(KDevice) #move the weights onto the selected device
model.eval() #evaluation mode (e.g. disables dropout)
#text cleaning function
KWhitespace = re.compile(r"\s+") #any run of whitespace (spaces, tabs, newlines)
def clean(s) -> str:
    """Collapse every whitespace run in `s` to one space and trim both ends.

    Example: "  hello \t\n  world " -> "hello world".
    Non-string values are converted with str(). None becomes "" so that a
    missing input is treated as empty text, not as the word "None".
    """
    if s is None:
        return ""
    return KWhitespace.sub(" ",str(s)).strip()
@torch.inference_mode() #skip gradient tracking: faster inference, less memory
def check(txt:str)->str:
    """Classify `txt` and return a "Prob:<p>|Label:<0 or 1>" string.

    The cleaned text is split into overlapping windows of at most
    KMaxLength tokens instead of being truncated. Text beyond the first
    KMaxLength tokens is therefore still scored. The reported probability
    is the highest window probability. The label is 1 when it is
    >= KThreshold, else 0.

    Returns "Type something first." if the cleaned text is empty.
    """
    txt = clean(txt)
    if not txt:
        return "Type something first."
    #stride = number of tokens shared by consecutive windows, so a word cut
    #at the edge of one window also appears in the next window
    enc = tok(txt,return_tensors = "pt",truncation = True,padding = True,
              max_length = KMaxLength,stride = KMaxLength // 4,
              return_overflowing_tokens = True)
    #forward only the model inputs: with overflow enabled the tokenizer also
    #returns bookkeeping keys (e.g. overflow_to_sample_mapping) that the
    #model's forward() does not accept
    keys = [k for k in ("input_ids","attention_mask","token_type_ids") if k in enc]
    n_windows = enc["input_ids"].shape[0]
    batch_size = 16 #windows per forward pass; bounds memory on very long inputs
    prob = 0.0
    for start in range(0,n_windows,batch_size):
        part = {k:enc[k][start:start + batch_size].to(KDevice) for k in keys}
        prob = max(prob,torch.sigmoid(model(**part).logits).max().item())
    label = 1 if prob >= KThreshold else 0
    return f"Prob:{prob:.4f}|Label:{label}"
#build the web UI: an input box, a result box, a button and examples
with gr.Blocks(title = "Hebrew Profanity Detector") as demo:
    inp = gr.Textbox(lines = 4,label = "Hebrew text") #text to classify
    out = gr.Textbox(label = "Result") #shows the string returned by check()
    check_btn = gr.Button("Check")
    #clicking runs check(inp) and writes the result to out; also exposed as a
    #named API endpoint.
    #NOTE(review): api_name is usually given without a leading "/"; confirm
    #the endpoint name clients actually see before changing it
    check_btn.click(fn = check,inputs = inp,outputs = out,api_name="/predict")
    #NOTE(review): these example strings look mis-decoded (UTF-8 Hebrew shown
    #as Thai characters); restore the intended Hebrew text
    gr.Examples(examples = [["ืื ืืืจ ืืฆืืื"],["ืืฉ ืื ืืจื ืืืจ"]],
                inputs = inp,outputs = out,fn = check,cache_examples = False)
#entry point: start the server only when run as a script, not when imported
if __name__ == "__main__":
    #listen on all network interfaces (0.0.0.0) at port 7860; show_error
    #reports exceptions raised by event handlers in the browser UI
    demo.launch(show_error = True,server_name = "0.0.0.0",server_port = 7860)
|