# rifatramadhani — feat: basic topic classification (commit 4de50d8)
import gradio as gr
import spaces
import torch
from transformers import pipeline
import datetime
import json
import logging
# Hugging Face model id; the same id is used for the model weights and the tokenizer.
model_path = "cardiffnlp/twitter-roberta-base-dec2021-tweet-topic-multi-all"
# Build the pipeline once at import time so the model is loaded (and cached)
# before the first request arrives.
# NOTE(review): classify() constructs its own pipeline on every call and does
# not read this object, so this instance appears to exist only for the warm-up
# side effect — confirm before removing or reusing it.
topic_classification_task = pipeline("text-classification", model=model_path, tokenizer=model_path)
def _parse_queries(query):
    """Return the list of inputs to classify for a raw request value.

    A JSON array is expanded into one input per element. Anything else
    (plain text, a JSON scalar or object, a non-string value) is classified
    as a single input, using the raw value unchanged.
    """
    try:
        parsed = json.loads(query)
    except (ValueError, TypeError, RecursionError):
        # json.JSONDecodeError is a ValueError. Plain text is the normal case
        # for the UI, so failing to parse is expected and not worth printing.
        logging.debug("query is not JSON; classifying it as a single text")
        return [query]
    return parsed if isinstance(parsed, list) else [query]


@spaces.GPU
def classify(query):
    """Classify the topic(s) of one text or of a JSON array of texts.

    Args:
        query: Either plain text, or a JSON-encoded array whose elements are
            each classified separately.

    Returns:
        A JSON string with three keys:
            "time":   elapsed prediction time, formatted as ``str(timedelta)``.
            "device": 0 when CUDA was available for this call, otherwise -1.
            "result": the pipeline output for the inputs, requested with
                ``top_k=3``.
    """
    torch_device = 0 if torch.cuda.is_available() else -1
    # Inputs longer than 512 tokens are truncated by the tokenizer.
    tokenizer_kwargs = {"truncation": True, "max_length": 512}

    # The pipeline is built inside the decorated function so it can be placed
    # on whatever device is available for this particular call.
    classifier = pipeline(
        "text-classification",
        model=model_path,
        tokenizer=model_path,
        device=torch_device,
    )

    data = _parse_queries(query)

    start_time = datetime.datetime.now()
    result = classifier(data, batch_size=128, top_k=3, **tokenizer_kwargs)
    elapsed_time = datetime.datetime.now() - start_time

    logging.debug("elapsed predict time: %s", str(elapsed_time))
    print("elapsed predict time:", str(elapsed_time))

    output = {
        "time": str(elapsed_time),
        "device": torch_device,
        "result": result,
    }
    return json.dumps(output)
# Expose classify() through a Gradio interface: one text input, one text
# output (the JSON string returned by classify()).
demo = gr.Interface(fn=classify, inputs=["text"], outputs="text")
# Start the Gradio app.
demo.launch()