rifatramadhani commited on
Commit
4de50d8
1 Parent(s): adca723

feat: basic topic classification

Browse files
Files changed (1) hide show
  1. app.py +47 -4
app.py CHANGED
@@ -1,7 +1,50 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import spaces
3
+ import torch
4
+ from transformers import pipeline
5
+ import datetime
6
+ import json
7
+ import logging
8
 
9
+ model_path = "cardiffnlp/twitter-roberta-base-dec2021-tweet-topic-multi-all"
10
+ # Load model for first time cache
11
+ topic_classification_task = pipeline("text-classification", model=model_path, tokenizer=model_path)
12
 
13
+ @spaces.GPU
14
+ def classify(query):
15
+ torch_device = 0 if torch.cuda.is_available() else -1
16
+ tokenizer_kwargs = {'truncation':True,'max_length':512}
17
+
18
+ topic_classification_task = pipeline("text-classification", model=model_path, tokenizer=model_path, device=torch_device)
19
+
20
+ request_type = type(query)
21
+ try:
22
+ data = json.loads(query)
23
+ if type(data) != list:
24
+ data = [query]
25
+ else:
26
+ request_type = type(data)
27
+ except Exception as e:
28
+ print(e)
29
+ data = [query]
30
+ pass
31
+
32
+ start_time = datetime.datetime.now()
33
+
34
+ result = topic_classification_task(data, batch_size=128, top_k=3, **tokenizer_kwargs)
35
+
36
+ end_time = datetime.datetime.now()
37
+ elapsed_time = end_time - start_time
38
+
39
+ logging.debug("elapsed predict time: %s", str(elapsed_time))
40
+ print("elapsed predict time:", str(elapsed_time))
41
+
42
+ output = {}
43
+ output["time"] = str(elapsed_time)
44
+ output["device"] = torch_device
45
+ output["result"] = result
46
+
47
+ return json.dumps(output)
48
+
49
+ demo = gr.Interface(fn=classify, inputs=["text"], outputs="text")
50
+ demo.launch()