사용 예시 (Usage example)

import onnxruntime as ort
import numpy as np
from transformers import MobileBertTokenizer

# Model and tokenizer locations.
# NOTE(review): the model lives in a TinyBERT directory, the tokenizer is read
# from a DistilBERT directory, and it is loaded with MobileBertTokenizer. This
# only classifies correctly if that tokenizer produces the same token ids the
# ONNX model was trained/exported with — confirm the vocab files match.
model_path = r'C:\NEW_tinybert\AI\tinybert_model.onnx'  # ONNX model file
tokenizer_path = r'C:\NEW_distilbert\AI'  # local tokenizer directory

# Create the ONNX Runtime session. Since ORT 1.9, builds that ship non-CPU
# execution providers (e.g. onnxruntime-gpu) raise ValueError unless
# `providers` is passed explicitly. Prefer CUDA when this build offers it and
# always keep CPU as the fallback; on a CPU-only build this resolves to
# ["CPUExecutionProvider"], i.e. the previous default behaviour.
_available_providers = ort.get_available_providers()
_providers = [
    provider
    for provider in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if provider in _available_providers
]
ort_session = ort.InferenceSession(model_path, providers=_providers)

# Load the tokenizer (MobileBERT tokenizer class) from the local directory.
tokenizer = MobileBertTokenizer.from_pretrained(tokenizer_path)

# Text classification helper
def test_model(text):
    """Classify *text* with the ONNX model and return a verdict message.

    The text is tokenized to exactly 128 tokens and run through the
    module-level ``ort_session``. The arg-max over the first graph output then
    selects the class. Class 1 maps to the "possible romance scam" message, and
    any other class maps to the "not a romance scam" message.

    Args:
        text (str): The text to analyse.

    Returns:
        str: The prediction message (in Korean).

    Raises:
        KeyError: If the ONNX graph declares an input that the tokenizer did
            not produce.
    """
    # Pad short input and truncate long input so the sequence is exactly
    # 128 tokens; return NumPy arrays, which is what ONNX Runtime consumes.
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="np",
    )

    # Build the feed from the inputs the exported graph actually declares
    # instead of hard-coding names: BERT-family exports may additionally
    # require "token_type_ids" or omit "attention_mask", and ONNX Runtime
    # rejects both missing and unexpected feed names.
    # Cast explicitly to int64 because the tokenizer's arrays can be int32
    # where NumPy's default integer is 32-bit (e.g. Windows with NumPy < 2).
    # NOTE(review): like the original code, this assumes every graph input
    # is int64 — confirm against the export.
    ort_inputs = {}
    for graph_input in ort_session.get_inputs():
        if graph_input.name not in encoded:
            raise KeyError(
                f"ONNX model expects input {graph_input.name!r}, "
                "which the tokenizer did not produce"
            )
        ort_inputs[graph_input.name] = encoded[graph_input.name].astype(np.int64)

    # run(None, ...) returns every graph output; the first one is taken as
    # the logits. NOTE(review): assumes shape (batch, num_classes) — confirm.
    logits = ort_session.run(None, ort_inputs)[0]

    # No softmax is applied: softmax is monotonic, so argmax over the raw
    # logits selects the same class as argmax over probabilities.
    # .item() requires exactly one prediction, i.e. a single input text.
    predicted_class = np.argmax(logits, axis=1).item()

    return "로맨스 스캠일 가능성 있음" if predicted_class == 1 else "로맨스 스캠이 아님"

# Sample conversation snippets to classify (Korean test inputs).
test_texts = [
    "너 엄마 없냐?",
    "저는 금융 상품을 소개하는 사람입니다. 투자하면 이익이 큽니다.",
    "내 보지가 달아올랐어",
    "내 가슴 만질래??",
]

# Classify each sample and print the verdict next to its input.
for text in test_texts:
    result = test_model(text)
    print(f"입력: {text} => 결과: {result}")
Downloads last month
5
Safetensors
Model size
14.4M params
Tensor type
F32
·
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.