# chiikawa-yonezu/app.py
# Author: Hiroaki OGASAWARA
import gradio as gr
import torch
from safetensors import safe_open
from transformers import BertTokenizer
from utils.ClassifierModel import ClassifierModel
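
# ClassifierModel is defined in utils/ClassifierModel.py (not shown here).
# A minimal sketch of the interface this app assumes -- a hypothetical
# two-class head over a BERT encoder:
#
#   class ClassifierModel(torch.nn.Module):
#       def forward(self, input_ids, attention_mask):
#           ...  # returns logits of shape (batch_size, 2)
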
def _classify_text(text, model, device, tokenizer, max_length=20):
"""
テキストが、'ちいかわ' と '米津玄師' のどちらに該当するかの確率を出力する。
"""
# テキストをトークナイズし、PyTorchのテンソルに変換
    inputs = tokenizer(
text,
add_special_tokens=True,
max_length=max_length,
padding="max_length",
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
pprint(f"inputs: {inputs}")
# モデルの推論
model.eval()
with torch.no_grad():
outputs = model(
inputs["input_ids"].to(device), inputs["attention_mask"].to(device)
)
pprint(f"outputs: {outputs}")
probabilities = torch.nn.functional.softmax(outputs, dim=1)
# 確率の取得
chiikawa_prob = probabilities[0][0].item()
yonezu_prob = probabilities[0][1].item()
return chiikawa_prob, yonezu_prob
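
# Hypothetical usage:
#   chii, yone = _classify_text("...", model, device, tokenizer)
# The two probabilities come from a softmax over two classes, so they sum to 1.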


is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")
print(f"device: {device}")

# Load the trained weights from the safetensors checkpoint
model_save_path = "models/model.safetensors"
tensors = {}
with safe_open(model_save_path, framework="pt", device="cpu") as f:
for key in f.keys():
tensors[key] = f.get_tensor(key)
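
# Equivalent, using the load_file helper from safetensors.torch:
#   from safetensors.torch import load_file
#   tensors = load_file(model_save_path, device="cpu")
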
inference_model: torch.nn.Module = ClassifierModel().to(device)
inference_model.load_state_dict(tensors)
tokenizer = BertTokenizer.from_pretrained(
"cl-tohoku/bert-base-japanese-whole-word-masking"
)
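
# Note: cl-tohoku BERT models are usually paired with BertJapaneseTokenizer,
# which runs MeCab word segmentation before WordPiece (and needs the fugashi
# and ipadic packages). BertTokenizer loads the same vocabulary but skips the
# MeCab step, so its tokenization may differ from what the model saw in training.
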
def classify_text(text):
chii_prob, yone_prob = _classify_text(text, inference_model, device, tokenizer)
return {"ちいかわ": chii_prob, "米津玄師": yone_prob}

demo = gr.Interface(
fn=classify_text,
inputs="textbox",
outputs="label",
examples=[
"守りたいんだ",
"どうしてどうしてどうして",
"そこから見ていてね",
"ヤンパパン"
],
)

demo.launch(share=True)  # share=True serves the demo through a public link