|
import unittest |
|
|
|
from transformers.testing_utils import Expectations |
|
|
|
|
|
class ExpectationsTest(unittest.TestCase):
    """Tests for `Expectations.find_expectation` key-matching behaviour."""

    def test_expectations(self):
        """Check which stored expectation is selected for various device keys."""
        expectations = Expectations(
            {
                (None, None): 1,
                ("cuda", 8): 2,
                ("cuda", 7): 3,
                ("rocm", 8): 4,
                ("rocm", None): 5,
                ("cpu", None): 6,
                ("xpu", 3): 7,
            }
        )

        def check(value, key):
            # assertEqual instead of a bare `assert`: it is not stripped under
            # `python -O` and reports both values on failure. subTest keeps the
            # remaining keys checked after a single mismatch.
            with self.subTest(key=key):
                self.assertEqual(expectations.find_expectation(key), value)

        # No "npu" entry exists, so the (None, None) default (value 1) is expected.
        check(1, ("npu", None))

        # Keys that are present verbatim in the mapping.
        check(7, ("xpu", 3))
        check(2, ("cuda", 8))
        check(3, ("cuda", 7))

        # Keys whose major version is not stored: expected to resolve to another
        # entry of the same device type.
        # NOTE(review): ("rocm", None) expects 4 (the ("rocm", 8) entry), not 5,
        # even though an exact ("rocm", None) key exists; likewise ("cuda", 2)
        # expects the first-listed cuda entry. Presumably these are score ties
        # broken by insertion order — confirm against find_expectation.
        check(4, ("rocm", 9))
        check(4, ("rocm", None))
        check(2, ("cuda", 2))

        # With no compatible entry and no (None, None) default, lookup must raise.
        expectations = Expectations({("cuda", 8): 1})
        with self.assertRaises(ValueError):
            expectations.find_expectation(("xpu", None))
|
|