Spaces:
Sleeping
Sleeping
Add get_protT5
Browse filesSentencepiece is needed for the T5 tokenizer
- poetry.lock +370 -1
- protention/attention.py +18 -0
- pyproject.toml +3 -0
- tests/test_attention.py +15 -1
poetry.lock
CHANGED
@@ -217,6 +217,17 @@ python-versions = ">=3.7"
|
|
217 |
[package.dependencies]
|
218 |
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
[[package]]
|
221 |
name = "colorama"
|
222 |
version = "0.4.6"
|
@@ -315,6 +326,18 @@ python-versions = "*"
|
|
315 |
[package.extras]
|
316 |
devel = ["colorama", "jsonschema", "json-spec", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"]
|
317 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
[[package]]
|
319 |
name = "fqdn"
|
320 |
version = "1.5.1"
|
@@ -345,6 +368,33 @@ python-versions = ">=3.7"
|
|
345 |
[package.dependencies]
|
346 |
gitdb = ">=4.0.1,<5"
|
347 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
348 |
[[package]]
|
349 |
name = "idna"
|
350 |
version = "3.4"
|
@@ -691,6 +741,14 @@ category = "main"
|
|
691 |
optional = false
|
692 |
python-versions = ">=3.7"
|
693 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
694 |
[[package]]
|
695 |
name = "markdown-it-py"
|
696 |
version = "2.2.0"
|
@@ -747,6 +805,20 @@ category = "main"
|
|
747 |
optional = false
|
748 |
python-versions = "*"
|
749 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
750 |
[[package]]
|
751 |
name = "nbclassic"
|
752 |
version = "0.5.3"
|
@@ -858,6 +930,21 @@ category = "main"
|
|
858 |
optional = false
|
859 |
python-versions = ">=3.5"
|
860 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
861 |
[[package]]
|
862 |
name = "notebook"
|
863 |
version = "6.5.3"
|
@@ -911,6 +998,94 @@ category = "main"
|
|
911 |
optional = false
|
912 |
python-versions = ">=3.8"
|
913 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
914 |
[[package]]
|
915 |
name = "packaging"
|
916 |
version = "23.0"
|
@@ -1239,6 +1414,14 @@ python-versions = ">=3.6"
|
|
1239 |
[package.dependencies]
|
1240 |
cffi = {version = "*", markers = "implementation_name == \"pypy\""}
|
1241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1242 |
[[package]]
|
1243 |
name = "requests"
|
1244 |
version = "2.28.2"
|
@@ -1312,6 +1495,14 @@ nativelib = ["pyobjc-framework-cocoa", "pywin32"]
|
|
1312 |
objc = ["pyobjc-framework-cocoa"]
|
1313 |
win32 = ["pywin32"]
|
1314 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1315 |
[[package]]
|
1316 |
name = "six"
|
1317 |
version = "1.16.0"
|
@@ -1411,6 +1602,17 @@ watchdog = {version = "*", markers = "platform_system != \"Darwin\""}
|
|
1411 |
[package.extras]
|
1412 |
snowflake = ["snowflake-snowpark-python"]
|
1413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1414 |
[[package]]
|
1415 |
name = "terminado"
|
1416 |
version = "0.17.1"
|
@@ -1443,6 +1645,19 @@ webencodings = ">=0.4"
|
|
1443 |
doc = ["sphinx", "sphinx-rtd-theme"]
|
1444 |
test = ["pytest", "isort", "flake8"]
|
1445 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1446 |
[[package]]
|
1447 |
name = "toml"
|
1448 |
version = "0.10.2"
|
@@ -1475,6 +1690,36 @@ category = "main"
|
|
1475 |
optional = false
|
1476 |
python-versions = ">=3.5"
|
1477 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1478 |
[[package]]
|
1479 |
name = "tornado"
|
1480 |
version = "6.2"
|
@@ -1483,6 +1728,23 @@ category = "main"
|
|
1483 |
optional = false
|
1484 |
python-versions = ">= 3.7"
|
1485 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1486 |
[[package]]
|
1487 |
name = "traitlets"
|
1488 |
version = "5.9.0"
|
@@ -1495,6 +1757,88 @@ python-versions = ">=3.7"
|
|
1495 |
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
|
1496 |
test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"]
|
1497 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1498 |
[[package]]
|
1499 |
name = "typing-extensions"
|
1500 |
version = "4.5.0"
|
@@ -1639,7 +1983,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
|
|
1639 |
[metadata]
|
1640 |
lock-version = "1.1"
|
1641 |
python-versions = "^3.10"
|
1642 |
-
content-hash = "
|
1643 |
|
1644 |
[metadata.files]
|
1645 |
altair = []
|
@@ -1688,6 +2032,7 @@ certifi = []
|
|
1688 |
cffi = []
|
1689 |
charset-normalizer = []
|
1690 |
click = []
|
|
|
1691 |
colorama = []
|
1692 |
comm = []
|
1693 |
debugpy = []
|
@@ -1707,9 +2052,11 @@ entrypoints = [
|
|
1707 |
exceptiongroup = []
|
1708 |
executing = []
|
1709 |
fastjsonschema = []
|
|
|
1710 |
fqdn = []
|
1711 |
gitdb = []
|
1712 |
gitpython = []
|
|
|
1713 |
idna = []
|
1714 |
importlib-metadata = []
|
1715 |
iniconfig = []
|
@@ -1734,19 +2081,33 @@ jupyter-server = []
|
|
1734 |
jupyter-server-terminals = []
|
1735 |
jupyterlab-pygments = []
|
1736 |
jupyterlab-widgets = []
|
|
|
1737 |
markdown-it-py = []
|
1738 |
markupsafe = []
|
1739 |
matplotlib-inline = []
|
1740 |
mdurl = []
|
1741 |
mistune = []
|
|
|
1742 |
nbclassic = []
|
1743 |
nbclient = []
|
1744 |
nbconvert = []
|
1745 |
nbformat = []
|
1746 |
nest-asyncio = []
|
|
|
1747 |
notebook = []
|
1748 |
notebook-shim = []
|
1749 |
numpy = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1750 |
packaging = []
|
1751 |
pandas = []
|
1752 |
pandocfilters = [
|
@@ -1839,6 +2200,7 @@ pyyaml = [
|
|
1839 |
{file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
|
1840 |
]
|
1841 |
pyzmq = []
|
|
|
1842 |
requests = []
|
1843 |
rfc3339-validator = []
|
1844 |
rfc3986-validator = []
|
@@ -1848,6 +2210,7 @@ send2trash = [
|
|
1848 |
{file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"},
|
1849 |
{file = "Send2Trash-1.8.0.tar.gz", hash = "sha256:d2c24762fd3759860a0aff155e45871447ea58d2be6bdd39b5c8f966a0c99c2d"},
|
1850 |
]
|
|
|
1851 |
six = [
|
1852 |
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
|
1853 |
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
|
@@ -1861,8 +2224,10 @@ soupsieve = []
|
|
1861 |
stack-data = []
|
1862 |
stmol = []
|
1863 |
streamlit = []
|
|
|
1864 |
terminado = []
|
1865 |
tinycss2 = []
|
|
|
1866 |
toml = [
|
1867 |
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
|
1868 |
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
|
@@ -1873,8 +2238,12 @@ tomli = [
|
|
1873 |
]
|
1874 |
tomlkit = []
|
1875 |
toolz = []
|
|
|
1876 |
tornado = []
|
|
|
1877 |
traitlets = []
|
|
|
|
|
1878 |
typing-extensions = []
|
1879 |
tzdata = []
|
1880 |
tzlocal = []
|
|
|
217 |
[package.dependencies]
|
218 |
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
219 |
|
220 |
+
[[package]]
|
221 |
+
name = "cmake"
|
222 |
+
version = "3.26.0"
|
223 |
+
description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software"
|
224 |
+
category = "main"
|
225 |
+
optional = false
|
226 |
+
python-versions = "*"
|
227 |
+
|
228 |
+
[package.extras]
|
229 |
+
test = ["codecov (>=2.0.5)", "coverage (>=4.2)", "flake8 (>=3.0.4)", "path.py (>=11.5.0)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)", "pytest-runner (>=2.9)", "pytest-virtualenv (>=1.7.0)", "scikit-build (>=0.10.0)", "setuptools (>=28.0.0)", "virtualenv (>=15.0.3)", "wheel"]
|
230 |
+
|
231 |
[[package]]
|
232 |
name = "colorama"
|
233 |
version = "0.4.6"
|
|
|
326 |
[package.extras]
|
327 |
devel = ["colorama", "jsonschema", "json-spec", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"]
|
328 |
|
329 |
+
[[package]]
|
330 |
+
name = "filelock"
|
331 |
+
version = "3.10.0"
|
332 |
+
description = "A platform independent file lock."
|
333 |
+
category = "main"
|
334 |
+
optional = false
|
335 |
+
python-versions = ">=3.7"
|
336 |
+
|
337 |
+
[package.extras]
|
338 |
+
docs = ["furo (>=2022.12.7)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)", "sphinx (>=6.1.3)"]
|
339 |
+
testing = ["covdefaults (>=2.3)", "coverage (>=7.2.1)", "pytest-cov (>=4)", "pytest-timeout (>=2.1)", "pytest (>=7.2.2)"]
|
340 |
+
|
341 |
[[package]]
|
342 |
name = "fqdn"
|
343 |
version = "1.5.1"
|
|
|
368 |
[package.dependencies]
|
369 |
gitdb = ">=4.0.1,<5"
|
370 |
|
371 |
+
[[package]]
|
372 |
+
name = "huggingface-hub"
|
373 |
+
version = "0.13.2"
|
374 |
+
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
375 |
+
category = "main"
|
376 |
+
optional = false
|
377 |
+
python-versions = ">=3.7.0"
|
378 |
+
|
379 |
+
[package.dependencies]
|
380 |
+
filelock = "*"
|
381 |
+
packaging = ">=20.9"
|
382 |
+
pyyaml = ">=5.1"
|
383 |
+
requests = "*"
|
384 |
+
tqdm = ">=4.42.1"
|
385 |
+
typing-extensions = ">=3.7.4.3"
|
386 |
+
|
387 |
+
[package.extras]
|
388 |
+
all = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow", "black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)", "types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
|
389 |
+
cli = ["InquirerPy (==0.3.4)"]
|
390 |
+
dev = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow", "black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)", "types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
|
391 |
+
fastai = ["toml", "fastai (>=2.4)", "fastcore (>=1.3.27)"]
|
392 |
+
quality = ["black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)"]
|
393 |
+
tensorflow = ["tensorflow", "pydot", "graphviz"]
|
394 |
+
testing = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow"]
|
395 |
+
torch = ["torch"]
|
396 |
+
typing = ["types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
|
397 |
+
|
398 |
[[package]]
|
399 |
name = "idna"
|
400 |
version = "3.4"
|
|
|
741 |
optional = false
|
742 |
python-versions = ">=3.7"
|
743 |
|
744 |
+
[[package]]
|
745 |
+
name = "lit"
|
746 |
+
version = "15.0.7"
|
747 |
+
description = "A Software Testing Tool"
|
748 |
+
category = "main"
|
749 |
+
optional = false
|
750 |
+
python-versions = "*"
|
751 |
+
|
752 |
[[package]]
|
753 |
name = "markdown-it-py"
|
754 |
version = "2.2.0"
|
|
|
805 |
optional = false
|
806 |
python-versions = "*"
|
807 |
|
808 |
+
[[package]]
|
809 |
+
name = "mpmath"
|
810 |
+
version = "1.3.0"
|
811 |
+
description = "Python library for arbitrary-precision floating-point arithmetic"
|
812 |
+
category = "main"
|
813 |
+
optional = false
|
814 |
+
python-versions = "*"
|
815 |
+
|
816 |
+
[package.extras]
|
817 |
+
develop = ["pytest (>=4.6)", "pycodestyle", "pytest-cov", "codecov", "wheel"]
|
818 |
+
docs = ["sphinx"]
|
819 |
+
gmpy = ["gmpy2 (>=2.1.0a4)"]
|
820 |
+
tests = ["pytest (>=4.6)"]
|
821 |
+
|
822 |
[[package]]
|
823 |
name = "nbclassic"
|
824 |
version = "0.5.3"
|
|
|
930 |
optional = false
|
931 |
python-versions = ">=3.5"
|
932 |
|
933 |
+
[[package]]
|
934 |
+
name = "networkx"
|
935 |
+
version = "3.0"
|
936 |
+
description = "Python package for creating and manipulating graphs and networks"
|
937 |
+
category = "main"
|
938 |
+
optional = false
|
939 |
+
python-versions = ">=3.8"
|
940 |
+
|
941 |
+
[package.extras]
|
942 |
+
default = ["numpy (>=1.20)", "scipy (>=1.8)", "matplotlib (>=3.4)", "pandas (>=1.3)"]
|
943 |
+
developer = ["pre-commit (>=2.20)", "mypy (>=0.991)"]
|
944 |
+
doc = ["sphinx (==5.2.3)", "pydata-sphinx-theme (>=0.11)", "sphinx-gallery (>=0.11)", "numpydoc (>=1.5)", "pillow (>=9.2)", "nb2plots (>=0.6)", "texext (>=0.6.7)"]
|
945 |
+
extra = ["lxml (>=4.6)", "pygraphviz (>=1.10)", "pydot (>=1.4.2)", "sympy (>=1.10)"]
|
946 |
+
test = ["pytest (>=7.2)", "pytest-cov (>=4.0)", "codecov (>=2.1)"]
|
947 |
+
|
948 |
[[package]]
|
949 |
name = "notebook"
|
950 |
version = "6.5.3"
|
|
|
998 |
optional = false
|
999 |
python-versions = ">=3.8"
|
1000 |
|
1001 |
+
[[package]]
|
1002 |
+
name = "nvidia-cublas-cu11"
|
1003 |
+
version = "11.10.3.66"
|
1004 |
+
description = "CUBLAS native runtime libraries"
|
1005 |
+
category = "main"
|
1006 |
+
optional = false
|
1007 |
+
python-versions = ">=3"
|
1008 |
+
|
1009 |
+
[[package]]
|
1010 |
+
name = "nvidia-cuda-cupti-cu11"
|
1011 |
+
version = "11.7.101"
|
1012 |
+
description = "CUDA profiling tools runtime libs."
|
1013 |
+
category = "main"
|
1014 |
+
optional = false
|
1015 |
+
python-versions = ">=3"
|
1016 |
+
|
1017 |
+
[[package]]
|
1018 |
+
name = "nvidia-cuda-nvrtc-cu11"
|
1019 |
+
version = "11.7.99"
|
1020 |
+
description = "NVRTC native runtime libraries"
|
1021 |
+
category = "main"
|
1022 |
+
optional = false
|
1023 |
+
python-versions = ">=3"
|
1024 |
+
|
1025 |
+
[[package]]
|
1026 |
+
name = "nvidia-cuda-runtime-cu11"
|
1027 |
+
version = "11.7.99"
|
1028 |
+
description = "CUDA Runtime native Libraries"
|
1029 |
+
category = "main"
|
1030 |
+
optional = false
|
1031 |
+
python-versions = ">=3"
|
1032 |
+
|
1033 |
+
[[package]]
|
1034 |
+
name = "nvidia-cudnn-cu11"
|
1035 |
+
version = "8.5.0.96"
|
1036 |
+
description = "cuDNN runtime libraries"
|
1037 |
+
category = "main"
|
1038 |
+
optional = false
|
1039 |
+
python-versions = ">=3"
|
1040 |
+
|
1041 |
+
[[package]]
|
1042 |
+
name = "nvidia-cufft-cu11"
|
1043 |
+
version = "10.9.0.58"
|
1044 |
+
description = "CUFFT native runtime libraries"
|
1045 |
+
category = "main"
|
1046 |
+
optional = false
|
1047 |
+
python-versions = ">=3"
|
1048 |
+
|
1049 |
+
[[package]]
|
1050 |
+
name = "nvidia-curand-cu11"
|
1051 |
+
version = "10.2.10.91"
|
1052 |
+
description = "CURAND native runtime libraries"
|
1053 |
+
category = "main"
|
1054 |
+
optional = false
|
1055 |
+
python-versions = ">=3"
|
1056 |
+
|
1057 |
+
[[package]]
|
1058 |
+
name = "nvidia-cusolver-cu11"
|
1059 |
+
version = "11.4.0.1"
|
1060 |
+
description = "CUDA solver native runtime libraries"
|
1061 |
+
category = "main"
|
1062 |
+
optional = false
|
1063 |
+
python-versions = ">=3"
|
1064 |
+
|
1065 |
+
[[package]]
|
1066 |
+
name = "nvidia-cusparse-cu11"
|
1067 |
+
version = "11.7.4.91"
|
1068 |
+
description = "CUSPARSE native runtime libraries"
|
1069 |
+
category = "main"
|
1070 |
+
optional = false
|
1071 |
+
python-versions = ">=3"
|
1072 |
+
|
1073 |
+
[[package]]
|
1074 |
+
name = "nvidia-nccl-cu11"
|
1075 |
+
version = "2.14.3"
|
1076 |
+
description = "NVIDIA Collective Communication Library (NCCL) Runtime"
|
1077 |
+
category = "main"
|
1078 |
+
optional = false
|
1079 |
+
python-versions = ">=3"
|
1080 |
+
|
1081 |
+
[[package]]
|
1082 |
+
name = "nvidia-nvtx-cu11"
|
1083 |
+
version = "11.7.91"
|
1084 |
+
description = "NVIDIA Tools Extension"
|
1085 |
+
category = "main"
|
1086 |
+
optional = false
|
1087 |
+
python-versions = ">=3"
|
1088 |
+
|
1089 |
[[package]]
|
1090 |
name = "packaging"
|
1091 |
version = "23.0"
|
|
|
1414 |
[package.dependencies]
|
1415 |
cffi = {version = "*", markers = "implementation_name == \"pypy\""}
|
1416 |
|
1417 |
+
[[package]]
|
1418 |
+
name = "regex"
|
1419 |
+
version = "2022.10.31"
|
1420 |
+
description = "Alternative regular expression module, to replace re."
|
1421 |
+
category = "main"
|
1422 |
+
optional = false
|
1423 |
+
python-versions = ">=3.6"
|
1424 |
+
|
1425 |
[[package]]
|
1426 |
name = "requests"
|
1427 |
version = "2.28.2"
|
|
|
1495 |
objc = ["pyobjc-framework-cocoa"]
|
1496 |
win32 = ["pywin32"]
|
1497 |
|
1498 |
+
[[package]]
|
1499 |
+
name = "sentencepiece"
|
1500 |
+
version = "0.1.97"
|
1501 |
+
description = "SentencePiece python wrapper"
|
1502 |
+
category = "main"
|
1503 |
+
optional = false
|
1504 |
+
python-versions = "*"
|
1505 |
+
|
1506 |
[[package]]
|
1507 |
name = "six"
|
1508 |
version = "1.16.0"
|
|
|
1602 |
[package.extras]
|
1603 |
snowflake = ["snowflake-snowpark-python"]
|
1604 |
|
1605 |
+
[[package]]
|
1606 |
+
name = "sympy"
|
1607 |
+
version = "1.11.1"
|
1608 |
+
description = "Computer algebra system (CAS) in Python"
|
1609 |
+
category = "main"
|
1610 |
+
optional = false
|
1611 |
+
python-versions = ">=3.8"
|
1612 |
+
|
1613 |
+
[package.dependencies]
|
1614 |
+
mpmath = ">=0.19"
|
1615 |
+
|
1616 |
[[package]]
|
1617 |
name = "terminado"
|
1618 |
version = "0.17.1"
|
|
|
1645 |
doc = ["sphinx", "sphinx-rtd-theme"]
|
1646 |
test = ["pytest", "isort", "flake8"]
|
1647 |
|
1648 |
+
[[package]]
|
1649 |
+
name = "tokenizers"
|
1650 |
+
version = "0.13.2"
|
1651 |
+
description = "Fast and Customizable Tokenizers"
|
1652 |
+
category = "main"
|
1653 |
+
optional = false
|
1654 |
+
python-versions = "*"
|
1655 |
+
|
1656 |
+
[package.extras]
|
1657 |
+
dev = ["pytest", "requests", "numpy", "datasets", "black (==22.3)"]
|
1658 |
+
docs = ["sphinx", "sphinx-rtd-theme", "setuptools-rust"]
|
1659 |
+
testing = ["pytest", "requests", "numpy", "datasets", "black (==22.3)"]
|
1660 |
+
|
1661 |
[[package]]
|
1662 |
name = "toml"
|
1663 |
version = "0.10.2"
|
|
|
1690 |
optional = false
|
1691 |
python-versions = ">=3.5"
|
1692 |
|
1693 |
+
[[package]]
|
1694 |
+
name = "torch"
|
1695 |
+
version = "2.0.0"
|
1696 |
+
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
|
1697 |
+
category = "main"
|
1698 |
+
optional = false
|
1699 |
+
python-versions = ">=3.8.0"
|
1700 |
+
|
1701 |
+
[package.dependencies]
|
1702 |
+
filelock = "*"
|
1703 |
+
jinja2 = "*"
|
1704 |
+
networkx = "*"
|
1705 |
+
nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1706 |
+
nvidia-cuda-cupti-cu11 = {version = "11.7.101", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1707 |
+
nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1708 |
+
nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1709 |
+
nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1710 |
+
nvidia-cufft-cu11 = {version = "10.9.0.58", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1711 |
+
nvidia-curand-cu11 = {version = "10.2.10.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1712 |
+
nvidia-cusolver-cu11 = {version = "11.4.0.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1713 |
+
nvidia-cusparse-cu11 = {version = "11.7.4.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1714 |
+
nvidia-nccl-cu11 = {version = "2.14.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1715 |
+
nvidia-nvtx-cu11 = {version = "11.7.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1716 |
+
sympy = "*"
|
1717 |
+
triton = {version = "2.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
1718 |
+
typing-extensions = "*"
|
1719 |
+
|
1720 |
+
[package.extras]
|
1721 |
+
opt-einsum = ["opt-einsum (>=3.3)"]
|
1722 |
+
|
1723 |
[[package]]
|
1724 |
name = "tornado"
|
1725 |
version = "6.2"
|
|
|
1728 |
optional = false
|
1729 |
python-versions = ">= 3.7"
|
1730 |
|
1731 |
+
[[package]]
|
1732 |
+
name = "tqdm"
|
1733 |
+
version = "4.65.0"
|
1734 |
+
description = "Fast, Extensible Progress Meter"
|
1735 |
+
category = "main"
|
1736 |
+
optional = false
|
1737 |
+
python-versions = ">=3.7"
|
1738 |
+
|
1739 |
+
[package.dependencies]
|
1740 |
+
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
1741 |
+
|
1742 |
+
[package.extras]
|
1743 |
+
dev = ["py-make (>=0.1.0)", "twine", "wheel"]
|
1744 |
+
notebook = ["ipywidgets (>=6)"]
|
1745 |
+
slack = ["slack-sdk"]
|
1746 |
+
telegram = ["requests"]
|
1747 |
+
|
1748 |
[[package]]
|
1749 |
name = "traitlets"
|
1750 |
version = "5.9.0"
|
|
|
1757 |
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
|
1758 |
test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"]
|
1759 |
|
1760 |
+
[[package]]
|
1761 |
+
name = "transformers"
|
1762 |
+
version = "4.27.1"
|
1763 |
+
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
|
1764 |
+
category = "main"
|
1765 |
+
optional = false
|
1766 |
+
python-versions = ">=3.7.0"
|
1767 |
+
|
1768 |
+
[package.dependencies]
|
1769 |
+
filelock = "*"
|
1770 |
+
huggingface-hub = ">=0.11.0,<1.0"
|
1771 |
+
numpy = ">=1.17"
|
1772 |
+
packaging = ">=20.0"
|
1773 |
+
pyyaml = ">=5.1"
|
1774 |
+
regex = "!=2019.12.17"
|
1775 |
+
requests = "*"
|
1776 |
+
tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.14"
|
1777 |
+
tqdm = ">=4.27"
|
1778 |
+
|
1779 |
+
[package.extras]
|
1780 |
+
accelerate = ["accelerate (>=0.10.0)"]
|
1781 |
+
all = ["tensorflow (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)", "torch (>=1.7,!=1.12.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.4.1)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.2)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm", "pillow", "optuna", "ray", "sigopt", "timm", "torchvision", "codecarbon (==1.2.0)", "accelerate (>=0.10.0)", "decord (==0.6.0)", "av (==9.2.0)"]
|
1782 |
+
audio = ["librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
|
1783 |
+
codecarbon = ["codecarbon (==1.2.0)"]
|
1784 |
+
deepspeed = ["deepspeed (>=0.6.5)", "accelerate (>=0.10.0)"]
|
1785 |
+
deepspeed-testing = ["deepspeed (>=0.6.5)", "accelerate (>=0.10.0)", "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "pytest-timeout", "black (>=23.1,<24.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.2)", "sacremoses", "rjieba", "safetensors (>=0.2.1)", "beautifulsoup4", "faiss-cpu", "cookiecutter (==1.7.3)", "optuna", "sentencepiece (>=0.1.91,!=0.1.92)"]
|
1786 |
+
dev = ["tensorflow (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)", "torch (>=1.7,!=1.12.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.4.1)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.2)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm", "pillow", "optuna", "ray", "sigopt", "timm", "torchvision", "codecarbon (==1.2.0)", "accelerate (>=0.10.0)", "decord (==0.6.0)", "av (==9.2.0)", "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "pytest-timeout", "black (>=23.1,<24.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "safetensors (>=0.2.1)", "beautifulsoup4", "faiss-cpu", "cookiecutter (==1.7.3)", "isort (>=5.5.4)", "ruff (>=0.0.241)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "sudachipy (>=0.6.6)", "sudachidict-core (>=20220729)", "rhoknp (>=1.1.0)", "hf-doc-builder", "scikit-learn"]
|
1787 |
+
dev-tensorflow = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "pytest-timeout", "black (>=23.1,<24.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.2)", "sacremoses", "rjieba", "safetensors (>=0.2.1)", "beautifulsoup4", "faiss-cpu", "cookiecutter (==1.7.3)", "tensorflow (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "pillow", "isort (>=5.5.4)", "ruff (>=0.0.241)", "hf-doc-builder", "scikit-learn", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
|
1788 |
+
dev-torch = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "pytest-timeout", "black (>=23.1,<24.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.2)", "sacremoses", "rjieba", "safetensors (>=0.2.1)", "beautifulsoup4", "faiss-cpu", "cookiecutter (==1.7.3)", "torch (>=1.7,!=1.12.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm", "pillow", "optuna", "ray", "sigopt", "timm", "torchvision", "codecarbon (==1.2.0)", "isort (>=5.5.4)", "ruff (>=0.0.241)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "sudachipy (>=0.6.6)", "sudachidict-core (>=20220729)", "rhoknp (>=1.1.0)", "hf-doc-builder", "scikit-learn", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
|
1789 |
+
docs = ["tensorflow (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)", "torch (>=1.7,!=1.12.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.4.1)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.2)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm", "pillow", "optuna", "ray", "sigopt", "timm", "torchvision", "codecarbon (==1.2.0)", "accelerate (>=0.10.0)", "decord (==0.6.0)", "av (==9.2.0)", "hf-doc-builder"]
|
1790 |
+
docs_specific = ["hf-doc-builder"]
|
1791 |
+
fairscale = ["fairscale (>0.3)"]
|
1792 |
+
flax = ["jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.4.1)", "optax (>=0.0.8)"]
|
1793 |
+
flax-speech = ["librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
|
1794 |
+
ftfy = ["ftfy"]
|
1795 |
+
integrations = ["optuna", "ray", "sigopt"]
|
1796 |
+
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "sudachipy (>=0.6.6)", "sudachidict-core (>=20220729)", "rhoknp (>=1.1.0)"]
|
1797 |
+
modelcreation = ["cookiecutter (==1.7.3)"]
|
1798 |
+
natten = ["natten (>=0.14.4)"]
|
1799 |
+
onnx = ["onnxconverter-common", "tf2onnx", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
|
1800 |
+
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
|
1801 |
+
optuna = ["optuna"]
|
1802 |
+
quality = ["black (>=23.1,<24.0)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (>=0.0.241)", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)"]
|
1803 |
+
ray = ["ray"]
|
1804 |
+
retrieval = ["faiss-cpu", "datasets (!=2.5.0)"]
|
1805 |
+
sagemaker = ["sagemaker (>=2.31.0)"]
|
1806 |
+
sentencepiece = ["sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.2)"]
|
1807 |
+
serving = ["pydantic", "uvicorn", "fastapi", "starlette"]
|
1808 |
+
sigopt = ["sigopt"]
|
1809 |
+
sklearn = ["scikit-learn"]
|
1810 |
+
speech = ["torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
|
1811 |
+
testing = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "pytest-timeout", "black (>=23.1,<24.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.2)", "sacremoses", "rjieba", "safetensors (>=0.2.1)", "beautifulsoup4", "faiss-cpu", "cookiecutter (==1.7.3)"]
|
1812 |
+
tf = ["tensorflow (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)"]
|
1813 |
+
tf-cpu = ["tensorflow-cpu (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)"]
|
1814 |
+
tf-speech = ["librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
|
1815 |
+
timm = ["timm"]
|
1816 |
+
tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.14)"]
|
1817 |
+
torch = ["torch (>=1.7,!=1.12.0)"]
|
1818 |
+
torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
|
1819 |
+
torch-vision = ["torchvision", "pillow"]
|
1820 |
+
torchhub = ["filelock", "huggingface-hub (>=0.11.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf (<=3.20.2)", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.7,!=1.12.0)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "tqdm (>=4.27)"]
|
1821 |
+
video = ["decord (==0.6.0)", "av (==9.2.0)"]
|
1822 |
+
vision = ["pillow"]
|
1823 |
+
|
1824 |
+
[[package]]
|
1825 |
+
name = "triton"
|
1826 |
+
version = "2.0.0"
|
1827 |
+
description = "A language and compiler for custom Deep Learning operations"
|
1828 |
+
category = "main"
|
1829 |
+
optional = false
|
1830 |
+
python-versions = "*"
|
1831 |
+
|
1832 |
+
[package.dependencies]
|
1833 |
+
cmake = "*"
|
1834 |
+
filelock = "*"
|
1835 |
+
lit = "*"
|
1836 |
+
torch = "*"
|
1837 |
+
|
1838 |
+
[package.extras]
|
1839 |
+
tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"]
|
1840 |
+
tutorials = ["matplotlib", "pandas", "tabulate"]
|
1841 |
+
|
1842 |
[[package]]
|
1843 |
name = "typing-extensions"
|
1844 |
version = "4.5.0"
|
|
|
1983 |
[metadata]
|
1984 |
lock-version = "1.1"
|
1985 |
python-versions = "^3.10"
|
1986 |
+
content-hash = "c748285bd150fadef69123d60f0b4ad96d99715916c7e1ab30214132749f8aed"
|
1987 |
|
1988 |
[metadata.files]
|
1989 |
altair = []
|
|
|
2032 |
cffi = []
|
2033 |
charset-normalizer = []
|
2034 |
click = []
|
2035 |
+
cmake = []
|
2036 |
colorama = []
|
2037 |
comm = []
|
2038 |
debugpy = []
|
|
|
2052 |
exceptiongroup = []
|
2053 |
executing = []
|
2054 |
fastjsonschema = []
|
2055 |
+
filelock = []
|
2056 |
fqdn = []
|
2057 |
gitdb = []
|
2058 |
gitpython = []
|
2059 |
+
huggingface-hub = []
|
2060 |
idna = []
|
2061 |
importlib-metadata = []
|
2062 |
iniconfig = []
|
|
|
2081 |
jupyter-server-terminals = []
|
2082 |
jupyterlab-pygments = []
|
2083 |
jupyterlab-widgets = []
|
2084 |
+
lit = []
|
2085 |
markdown-it-py = []
|
2086 |
markupsafe = []
|
2087 |
matplotlib-inline = []
|
2088 |
mdurl = []
|
2089 |
mistune = []
|
2090 |
+
mpmath = []
|
2091 |
nbclassic = []
|
2092 |
nbclient = []
|
2093 |
nbconvert = []
|
2094 |
nbformat = []
|
2095 |
nest-asyncio = []
|
2096 |
+
networkx = []
|
2097 |
notebook = []
|
2098 |
notebook-shim = []
|
2099 |
numpy = []
|
2100 |
+
nvidia-cublas-cu11 = []
|
2101 |
+
nvidia-cuda-cupti-cu11 = []
|
2102 |
+
nvidia-cuda-nvrtc-cu11 = []
|
2103 |
+
nvidia-cuda-runtime-cu11 = []
|
2104 |
+
nvidia-cudnn-cu11 = []
|
2105 |
+
nvidia-cufft-cu11 = []
|
2106 |
+
nvidia-curand-cu11 = []
|
2107 |
+
nvidia-cusolver-cu11 = []
|
2108 |
+
nvidia-cusparse-cu11 = []
|
2109 |
+
nvidia-nccl-cu11 = []
|
2110 |
+
nvidia-nvtx-cu11 = []
|
2111 |
packaging = []
|
2112 |
pandas = []
|
2113 |
pandocfilters = [
|
|
|
2200 |
{file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
|
2201 |
]
|
2202 |
pyzmq = []
|
2203 |
+
regex = []
|
2204 |
requests = []
|
2205 |
rfc3339-validator = []
|
2206 |
rfc3986-validator = []
|
|
|
2210 |
{file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"},
|
2211 |
{file = "Send2Trash-1.8.0.tar.gz", hash = "sha256:d2c24762fd3759860a0aff155e45871447ea58d2be6bdd39b5c8f966a0c99c2d"},
|
2212 |
]
|
2213 |
+
sentencepiece = []
|
2214 |
six = [
|
2215 |
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
|
2216 |
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
|
|
|
2224 |
stack-data = []
|
2225 |
stmol = []
|
2226 |
streamlit = []
|
2227 |
+
sympy = []
|
2228 |
terminado = []
|
2229 |
tinycss2 = []
|
2230 |
+
tokenizers = []
|
2231 |
toml = [
|
2232 |
{file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
|
2233 |
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
|
|
|
2238 |
]
|
2239 |
tomlkit = []
|
2240 |
toolz = []
|
2241 |
+
torch = []
|
2242 |
tornado = []
|
2243 |
+
tqdm = []
|
2244 |
traitlets = []
|
2245 |
+
transformers = []
|
2246 |
+
triton = []
|
2247 |
typing-extensions = []
|
2248 |
tzdata = []
|
2249 |
tzlocal = []
|
protention/attention.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
from io import StringIO
|
2 |
from urllib import request
|
3 |
|
|
|
4 |
from Bio.PDB import PDBParser, Structure
|
|
|
5 |
|
6 |
|
7 |
def get_structure(pdb_code: str) -> Structure:
|
@@ -16,6 +18,21 @@ def get_structure(pdb_code: str) -> Structure:
|
|
16 |
return structure
|
17 |
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def get_attention(
|
20 |
pdb_code: str, chain_ids: list[str], layer: int, head: int, min_attn: float = 0.2
|
21 |
):
|
@@ -26,6 +43,7 @@ def get_attention(
|
|
26 |
structure = get_structure(pdb_code)
|
27 |
|
28 |
# get model
|
|
|
29 |
|
30 |
# call model
|
31 |
|
|
|
1 |
from io import StringIO
|
2 |
from urllib import request
|
3 |
|
4 |
+
import torch
|
5 |
from Bio.PDB import PDBParser, Structure
|
6 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
7 |
|
8 |
|
9 |
def get_structure(pdb_code: str) -> Structure:
|
|
|
18 |
return structure
|
19 |
|
20 |
|
21 |
+
def get_protT5() -> tuple[T5Tokenizer, T5EncoderModel]:
|
22 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
23 |
+
tokenizer = T5Tokenizer.from_pretrained(
|
24 |
+
"Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
|
25 |
+
)
|
26 |
+
|
27 |
+
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
|
28 |
+
device
|
29 |
+
)
|
30 |
+
|
31 |
+
model.full() if device == "cpu" else model.half()
|
32 |
+
|
33 |
+
return tokenizer, model
|
34 |
+
|
35 |
+
|
36 |
def get_attention(
|
37 |
pdb_code: str, chain_ids: list[str], layer: int, head: int, min_attn: float = 0.2
|
38 |
):
|
|
|
43 |
structure = get_structure(pdb_code)
|
44 |
|
45 |
# get model
|
46 |
+
tokenizer, model = get_protT5()
|
47 |
|
48 |
# call model
|
49 |
|
pyproject.toml
CHANGED
@@ -9,6 +9,9 @@ python = "^3.10"
|
|
9 |
streamlit = "^1.20.0"
|
10 |
stmol = "^0.0.9"
|
11 |
biopython = "^1.81"
|
|
|
|
|
|
|
12 |
|
13 |
[tool.poetry.dev-dependencies]
|
14 |
pytest = "^7.2.2"
|
|
|
9 |
streamlit = "^1.20.0"
|
10 |
stmol = "^0.0.9"
|
11 |
biopython = "^1.81"
|
12 |
+
transformers = "^4.27.1"
|
13 |
+
torch = "^2.0.0"
|
14 |
+
sentencepiece = "^0.1.97"
|
15 |
|
16 |
[tool.poetry.dev-dependencies]
|
17 |
pytest = "^7.2.2"
|
tests/test_attention.py
CHANGED
@@ -1,10 +1,24 @@
|
|
1 |
from Bio.PDB.Structure import Structure
|
|
|
2 |
|
3 |
-
from protention.attention import get_structure
|
4 |
|
5 |
|
6 |
def test_get_structure():
|
7 |
pdb_id = "1AKE"
|
8 |
structure = get_structure(pdb_id)
|
|
|
9 |
assert structure is not None
|
10 |
assert isinstance(structure, Structure)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from Bio.PDB.Structure import Structure
|
2 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
3 |
|
4 |
+
from protention.attention import get_protT5, get_structure
|
5 |
|
6 |
|
7 |
def test_get_structure():
|
8 |
pdb_id = "1AKE"
|
9 |
structure = get_structure(pdb_id)
|
10 |
+
|
11 |
assert structure is not None
|
12 |
assert isinstance(structure, Structure)
|
13 |
+
|
14 |
+
|
15 |
+
def test_get_protT5():
|
16 |
+
result = get_protT5()
|
17 |
+
|
18 |
+
assert result is not None
|
19 |
+
assert isinstance(result, tuple)
|
20 |
+
|
21 |
+
tokenizer, model = result
|
22 |
+
|
23 |
+
assert isinstance(tokenizer, T5Tokenizer)
|
24 |
+
assert isinstance(model, T5EncoderModel)
|