File size: 892 Bytes
eaebd8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a34d6ee
eaebd8c
 
 
a34d6ee
eaebd8c
 
 
 
 
 
a34d6ee
eaebd8c
 
 
855f01e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import gradio as gr

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face Hub repository that provides both the tokenizer and the classifier.
MODEL_ID = "WebOrganizer/TopicClassifier-NoURL"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# trust_remote_code=True allows transformers to execute the model's custom code
# from the Hub repository; only enable it for repositories you trust.
# use_memory_efficient_attention=False is forwarded to the model; presumably it
# disables an optional fused-attention path in that custom code — TODO confirm.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    use_memory_efficient_attention=False,
    trust_remote_code=True,
)

def predict(text):
    """Classify the topic of ``text`` with the loaded sequence classifier.

    Args:
        text: Raw input text (from the Gradio text box).

    Returns:
        A dict with ``'topic'``, the predicted label name taken from
        ``model.config.id2label``, and ``'confidence'``, the softmax
        probability of that label rounded to 4 decimal places.
    """
    # return_tensors="pt" already requires torch, so this import adds no new
    # dependency; it is local to keep the file's import block unchanged.
    import torch

    # Truncate to the tokenizer's maximum length so very long inputs do not
    # make the forward pass fail.
    inputs = tokenizer([text], return_tensors="pt", truncation=True)

    # Inference only: disabling autograd avoids building a gradient graph,
    # saving memory and time on every request.
    with torch.no_grad():
        outputs = model(**inputs)

    probs = outputs.logits.softmax(dim=-1)
    pred_index = probs.argmax(dim=-1).item()
    confidence_score = probs[0, pred_index].item()
    pred_label = model.config.id2label[pred_index]

    return {'topic': pred_label, 'confidence': round(confidence_score, 4)}
  
# Heading displayed at the top of the Gradio page.
title = "URL content Topic Categorizer"

# One text box in, one JSON panel out (the dict returned by predict).
topic = gr.Interface(
    predict,
    "text",
    gr.JSON(),
    title=title,
)

# show_error=True makes errors raised inside predict visible in the browser UI
# instead of failing silently.
topic.launch(show_error=True)