aksell commited on
Commit
87c0dbc
1 Parent(s): f12036e

Add get_protT5

Browse files

Sentencepiece is needed for the T5 tokenizer

Files changed (4) hide show
  1. poetry.lock +370 -1
  2. protention/attention.py +18 -0
  3. pyproject.toml +3 -0
  4. tests/test_attention.py +15 -1
poetry.lock CHANGED
@@ -217,6 +217,17 @@ python-versions = ">=3.7"
217
  [package.dependencies]
218
  colorama = {version = "*", markers = "platform_system == \"Windows\""}
219
 
 
 
 
 
 
 
 
 
 
 
 
220
  [[package]]
221
  name = "colorama"
222
  version = "0.4.6"
@@ -315,6 +326,18 @@ python-versions = "*"
315
  [package.extras]
316
  devel = ["colorama", "jsonschema", "json-spec", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"]
317
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  [[package]]
319
  name = "fqdn"
320
  version = "1.5.1"
@@ -345,6 +368,33 @@ python-versions = ">=3.7"
345
  [package.dependencies]
346
  gitdb = ">=4.0.1,<5"
347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
  [[package]]
349
  name = "idna"
350
  version = "3.4"
@@ -691,6 +741,14 @@ category = "main"
691
  optional = false
692
  python-versions = ">=3.7"
693
 
 
 
 
 
 
 
 
 
694
  [[package]]
695
  name = "markdown-it-py"
696
  version = "2.2.0"
@@ -747,6 +805,20 @@ category = "main"
747
  optional = false
748
  python-versions = "*"
749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
750
  [[package]]
751
  name = "nbclassic"
752
  version = "0.5.3"
@@ -858,6 +930,21 @@ category = "main"
858
  optional = false
859
  python-versions = ">=3.5"
860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
861
  [[package]]
862
  name = "notebook"
863
  version = "6.5.3"
@@ -911,6 +998,94 @@ category = "main"
911
  optional = false
912
  python-versions = ">=3.8"
913
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
914
  [[package]]
915
  name = "packaging"
916
  version = "23.0"
@@ -1239,6 +1414,14 @@ python-versions = ">=3.6"
1239
  [package.dependencies]
1240
  cffi = {version = "*", markers = "implementation_name == \"pypy\""}
1241
 
 
 
 
 
 
 
 
 
1242
  [[package]]
1243
  name = "requests"
1244
  version = "2.28.2"
@@ -1312,6 +1495,14 @@ nativelib = ["pyobjc-framework-cocoa", "pywin32"]
1312
  objc = ["pyobjc-framework-cocoa"]
1313
  win32 = ["pywin32"]
1314
 
 
 
 
 
 
 
 
 
1315
  [[package]]
1316
  name = "six"
1317
  version = "1.16.0"
@@ -1411,6 +1602,17 @@ watchdog = {version = "*", markers = "platform_system != \"Darwin\""}
1411
  [package.extras]
1412
  snowflake = ["snowflake-snowpark-python"]
1413
 
 
 
 
 
 
 
 
 
 
 
 
1414
  [[package]]
1415
  name = "terminado"
1416
  version = "0.17.1"
@@ -1443,6 +1645,19 @@ webencodings = ">=0.4"
1443
  doc = ["sphinx", "sphinx-rtd-theme"]
1444
  test = ["pytest", "isort", "flake8"]
1445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1446
  [[package]]
1447
  name = "toml"
1448
  version = "0.10.2"
@@ -1475,6 +1690,36 @@ category = "main"
1475
  optional = false
1476
  python-versions = ">=3.5"
1477
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1478
  [[package]]
1479
  name = "tornado"
1480
  version = "6.2"
@@ -1483,6 +1728,23 @@ category = "main"
1483
  optional = false
1484
  python-versions = ">= 3.7"
1485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1486
  [[package]]
1487
  name = "traitlets"
1488
  version = "5.9.0"
@@ -1495,6 +1757,88 @@ python-versions = ">=3.7"
1495
  docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
1496
  test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"]
1497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1498
  [[package]]
1499
  name = "typing-extensions"
1500
  version = "4.5.0"
@@ -1639,7 +1983,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
1639
  [metadata]
1640
  lock-version = "1.1"
1641
  python-versions = "^3.10"
1642
- content-hash = "1e79d688b56335b1eafcb169572e0b8983eff0cb2da5ece8807ae02316f25f12"
1643
 
1644
  [metadata.files]
1645
  altair = []
@@ -1688,6 +2032,7 @@ certifi = []
1688
  cffi = []
1689
  charset-normalizer = []
1690
  click = []
 
1691
  colorama = []
1692
  comm = []
1693
  debugpy = []
@@ -1707,9 +2052,11 @@ entrypoints = [
1707
  exceptiongroup = []
1708
  executing = []
1709
  fastjsonschema = []
 
1710
  fqdn = []
1711
  gitdb = []
1712
  gitpython = []
 
1713
  idna = []
1714
  importlib-metadata = []
1715
  iniconfig = []
@@ -1734,19 +2081,33 @@ jupyter-server = []
1734
  jupyter-server-terminals = []
1735
  jupyterlab-pygments = []
1736
  jupyterlab-widgets = []
 
1737
  markdown-it-py = []
1738
  markupsafe = []
1739
  matplotlib-inline = []
1740
  mdurl = []
1741
  mistune = []
 
1742
  nbclassic = []
1743
  nbclient = []
1744
  nbconvert = []
1745
  nbformat = []
1746
  nest-asyncio = []
 
1747
  notebook = []
1748
  notebook-shim = []
1749
  numpy = []
 
 
 
 
 
 
 
 
 
 
 
1750
  packaging = []
1751
  pandas = []
1752
  pandocfilters = [
@@ -1839,6 +2200,7 @@ pyyaml = [
1839
  {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
1840
  ]
1841
  pyzmq = []
 
1842
  requests = []
1843
  rfc3339-validator = []
1844
  rfc3986-validator = []
@@ -1848,6 +2210,7 @@ send2trash = [
1848
  {file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"},
1849
  {file = "Send2Trash-1.8.0.tar.gz", hash = "sha256:d2c24762fd3759860a0aff155e45871447ea58d2be6bdd39b5c8f966a0c99c2d"},
1850
  ]
 
1851
  six = [
1852
  {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
1853
  {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
@@ -1861,8 +2224,10 @@ soupsieve = []
1861
  stack-data = []
1862
  stmol = []
1863
  streamlit = []
 
1864
  terminado = []
1865
  tinycss2 = []
 
1866
  toml = [
1867
  {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
1868
  {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
@@ -1873,8 +2238,12 @@ tomli = [
1873
  ]
1874
  tomlkit = []
1875
  toolz = []
 
1876
  tornado = []
 
1877
  traitlets = []
 
 
1878
  typing-extensions = []
1879
  tzdata = []
1880
  tzlocal = []
 
217
  [package.dependencies]
218
  colorama = {version = "*", markers = "platform_system == \"Windows\""}
219
 
220
+ [[package]]
221
+ name = "cmake"
222
+ version = "3.26.0"
223
+ description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software"
224
+ category = "main"
225
+ optional = false
226
+ python-versions = "*"
227
+
228
+ [package.extras]
229
+ test = ["codecov (>=2.0.5)", "coverage (>=4.2)", "flake8 (>=3.0.4)", "path.py (>=11.5.0)", "pytest (>=3.0.3)", "pytest-cov (>=2.4.0)", "pytest-runner (>=2.9)", "pytest-virtualenv (>=1.7.0)", "scikit-build (>=0.10.0)", "setuptools (>=28.0.0)", "virtualenv (>=15.0.3)", "wheel"]
230
+
231
  [[package]]
232
  name = "colorama"
233
  version = "0.4.6"
 
326
  [package.extras]
327
  devel = ["colorama", "jsonschema", "json-spec", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"]
328
 
329
+ [[package]]
330
+ name = "filelock"
331
+ version = "3.10.0"
332
+ description = "A platform independent file lock."
333
+ category = "main"
334
+ optional = false
335
+ python-versions = ">=3.7"
336
+
337
+ [package.extras]
338
+ docs = ["furo (>=2022.12.7)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)", "sphinx (>=6.1.3)"]
339
+ testing = ["covdefaults (>=2.3)", "coverage (>=7.2.1)", "pytest-cov (>=4)", "pytest-timeout (>=2.1)", "pytest (>=7.2.2)"]
340
+
341
  [[package]]
342
  name = "fqdn"
343
  version = "1.5.1"
 
368
  [package.dependencies]
369
  gitdb = ">=4.0.1,<5"
370
 
371
+ [[package]]
372
+ name = "huggingface-hub"
373
+ version = "0.13.2"
374
+ description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
375
+ category = "main"
376
+ optional = false
377
+ python-versions = ">=3.7.0"
378
+
379
+ [package.dependencies]
380
+ filelock = "*"
381
+ packaging = ">=20.9"
382
+ pyyaml = ">=5.1"
383
+ requests = "*"
384
+ tqdm = ">=4.42.1"
385
+ typing-extensions = ">=3.7.4.3"
386
+
387
+ [package.extras]
388
+ all = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow", "black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)", "types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
389
+ cli = ["InquirerPy (==0.3.4)"]
390
+ dev = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow", "black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)", "types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
391
+ fastai = ["toml", "fastai (>=2.4)", "fastcore (>=1.3.27)"]
392
+ quality = ["black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)"]
393
+ tensorflow = ["tensorflow", "pydot", "graphviz"]
394
+ testing = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow"]
395
+ torch = ["torch"]
396
+ typing = ["types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
397
+
398
  [[package]]
399
  name = "idna"
400
  version = "3.4"
 
741
  optional = false
742
  python-versions = ">=3.7"
743
 
744
+ [[package]]
745
+ name = "lit"
746
+ version = "15.0.7"
747
+ description = "A Software Testing Tool"
748
+ category = "main"
749
+ optional = false
750
+ python-versions = "*"
751
+
752
  [[package]]
753
  name = "markdown-it-py"
754
  version = "2.2.0"
 
805
  optional = false
806
  python-versions = "*"
807
 
808
+ [[package]]
809
+ name = "mpmath"
810
+ version = "1.3.0"
811
+ description = "Python library for arbitrary-precision floating-point arithmetic"
812
+ category = "main"
813
+ optional = false
814
+ python-versions = "*"
815
+
816
+ [package.extras]
817
+ develop = ["pytest (>=4.6)", "pycodestyle", "pytest-cov", "codecov", "wheel"]
818
+ docs = ["sphinx"]
819
+ gmpy = ["gmpy2 (>=2.1.0a4)"]
820
+ tests = ["pytest (>=4.6)"]
821
+
822
  [[package]]
823
  name = "nbclassic"
824
  version = "0.5.3"
 
930
  optional = false
931
  python-versions = ">=3.5"
932
 
933
+ [[package]]
934
+ name = "networkx"
935
+ version = "3.0"
936
+ description = "Python package for creating and manipulating graphs and networks"
937
+ category = "main"
938
+ optional = false
939
+ python-versions = ">=3.8"
940
+
941
+ [package.extras]
942
+ default = ["numpy (>=1.20)", "scipy (>=1.8)", "matplotlib (>=3.4)", "pandas (>=1.3)"]
943
+ developer = ["pre-commit (>=2.20)", "mypy (>=0.991)"]
944
+ doc = ["sphinx (==5.2.3)", "pydata-sphinx-theme (>=0.11)", "sphinx-gallery (>=0.11)", "numpydoc (>=1.5)", "pillow (>=9.2)", "nb2plots (>=0.6)", "texext (>=0.6.7)"]
945
+ extra = ["lxml (>=4.6)", "pygraphviz (>=1.10)", "pydot (>=1.4.2)", "sympy (>=1.10)"]
946
+ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)", "codecov (>=2.1)"]
947
+
948
  [[package]]
949
  name = "notebook"
950
  version = "6.5.3"
 
998
  optional = false
999
  python-versions = ">=3.8"
1000
 
1001
+ [[package]]
1002
+ name = "nvidia-cublas-cu11"
1003
+ version = "11.10.3.66"
1004
+ description = "CUBLAS native runtime libraries"
1005
+ category = "main"
1006
+ optional = false
1007
+ python-versions = ">=3"
1008
+
1009
+ [[package]]
1010
+ name = "nvidia-cuda-cupti-cu11"
1011
+ version = "11.7.101"
1012
+ description = "CUDA profiling tools runtime libs."
1013
+ category = "main"
1014
+ optional = false
1015
+ python-versions = ">=3"
1016
+
1017
+ [[package]]
1018
+ name = "nvidia-cuda-nvrtc-cu11"
1019
+ version = "11.7.99"
1020
+ description = "NVRTC native runtime libraries"
1021
+ category = "main"
1022
+ optional = false
1023
+ python-versions = ">=3"
1024
+
1025
+ [[package]]
1026
+ name = "nvidia-cuda-runtime-cu11"
1027
+ version = "11.7.99"
1028
+ description = "CUDA Runtime native Libraries"
1029
+ category = "main"
1030
+ optional = false
1031
+ python-versions = ">=3"
1032
+
1033
+ [[package]]
1034
+ name = "nvidia-cudnn-cu11"
1035
+ version = "8.5.0.96"
1036
+ description = "cuDNN runtime libraries"
1037
+ category = "main"
1038
+ optional = false
1039
+ python-versions = ">=3"
1040
+
1041
+ [[package]]
1042
+ name = "nvidia-cufft-cu11"
1043
+ version = "10.9.0.58"
1044
+ description = "CUFFT native runtime libraries"
1045
+ category = "main"
1046
+ optional = false
1047
+ python-versions = ">=3"
1048
+
1049
+ [[package]]
1050
+ name = "nvidia-curand-cu11"
1051
+ version = "10.2.10.91"
1052
+ description = "CURAND native runtime libraries"
1053
+ category = "main"
1054
+ optional = false
1055
+ python-versions = ">=3"
1056
+
1057
+ [[package]]
1058
+ name = "nvidia-cusolver-cu11"
1059
+ version = "11.4.0.1"
1060
+ description = "CUDA solver native runtime libraries"
1061
+ category = "main"
1062
+ optional = false
1063
+ python-versions = ">=3"
1064
+
1065
+ [[package]]
1066
+ name = "nvidia-cusparse-cu11"
1067
+ version = "11.7.4.91"
1068
+ description = "CUSPARSE native runtime libraries"
1069
+ category = "main"
1070
+ optional = false
1071
+ python-versions = ">=3"
1072
+
1073
+ [[package]]
1074
+ name = "nvidia-nccl-cu11"
1075
+ version = "2.14.3"
1076
+ description = "NVIDIA Collective Communication Library (NCCL) Runtime"
1077
+ category = "main"
1078
+ optional = false
1079
+ python-versions = ">=3"
1080
+
1081
+ [[package]]
1082
+ name = "nvidia-nvtx-cu11"
1083
+ version = "11.7.91"
1084
+ description = "NVIDIA Tools Extension"
1085
+ category = "main"
1086
+ optional = false
1087
+ python-versions = ">=3"
1088
+
1089
  [[package]]
1090
  name = "packaging"
1091
  version = "23.0"
 
1414
  [package.dependencies]
1415
  cffi = {version = "*", markers = "implementation_name == \"pypy\""}
1416
 
1417
+ [[package]]
1418
+ name = "regex"
1419
+ version = "2022.10.31"
1420
+ description = "Alternative regular expression module, to replace re."
1421
+ category = "main"
1422
+ optional = false
1423
+ python-versions = ">=3.6"
1424
+
1425
  [[package]]
1426
  name = "requests"
1427
  version = "2.28.2"
 
1495
  objc = ["pyobjc-framework-cocoa"]
1496
  win32 = ["pywin32"]
1497
 
1498
+ [[package]]
1499
+ name = "sentencepiece"
1500
+ version = "0.1.97"
1501
+ description = "SentencePiece python wrapper"
1502
+ category = "main"
1503
+ optional = false
1504
+ python-versions = "*"
1505
+
1506
  [[package]]
1507
  name = "six"
1508
  version = "1.16.0"
 
1602
  [package.extras]
1603
  snowflake = ["snowflake-snowpark-python"]
1604
 
1605
+ [[package]]
1606
+ name = "sympy"
1607
+ version = "1.11.1"
1608
+ description = "Computer algebra system (CAS) in Python"
1609
+ category = "main"
1610
+ optional = false
1611
+ python-versions = ">=3.8"
1612
+
1613
+ [package.dependencies]
1614
+ mpmath = ">=0.19"
1615
+
1616
  [[package]]
1617
  name = "terminado"
1618
  version = "0.17.1"
 
1645
  doc = ["sphinx", "sphinx-rtd-theme"]
1646
  test = ["pytest", "isort", "flake8"]
1647
 
1648
+ [[package]]
1649
+ name = "tokenizers"
1650
+ version = "0.13.2"
1651
+ description = "Fast and Customizable Tokenizers"
1652
+ category = "main"
1653
+ optional = false
1654
+ python-versions = "*"
1655
+
1656
+ [package.extras]
1657
+ dev = ["pytest", "requests", "numpy", "datasets", "black (==22.3)"]
1658
+ docs = ["sphinx", "sphinx-rtd-theme", "setuptools-rust"]
1659
+ testing = ["pytest", "requests", "numpy", "datasets", "black (==22.3)"]
1660
+
1661
  [[package]]
1662
  name = "toml"
1663
  version = "0.10.2"
 
1690
  optional = false
1691
  python-versions = ">=3.5"
1692
 
1693
+ [[package]]
1694
+ name = "torch"
1695
+ version = "2.0.0"
1696
+ description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
1697
+ category = "main"
1698
+ optional = false
1699
+ python-versions = ">=3.8.0"
1700
+
1701
+ [package.dependencies]
1702
+ filelock = "*"
1703
+ jinja2 = "*"
1704
+ networkx = "*"
1705
+ nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1706
+ nvidia-cuda-cupti-cu11 = {version = "11.7.101", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1707
+ nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1708
+ nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1709
+ nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1710
+ nvidia-cufft-cu11 = {version = "10.9.0.58", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1711
+ nvidia-curand-cu11 = {version = "10.2.10.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1712
+ nvidia-cusolver-cu11 = {version = "11.4.0.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1713
+ nvidia-cusparse-cu11 = {version = "11.7.4.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1714
+ nvidia-nccl-cu11 = {version = "2.14.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1715
+ nvidia-nvtx-cu11 = {version = "11.7.91", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1716
+ sympy = "*"
1717
+ triton = {version = "2.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
1718
+ typing-extensions = "*"
1719
+
1720
+ [package.extras]
1721
+ opt-einsum = ["opt-einsum (>=3.3)"]
1722
+
1723
  [[package]]
1724
  name = "tornado"
1725
  version = "6.2"
 
1728
  optional = false
1729
  python-versions = ">= 3.7"
1730
 
1731
+ [[package]]
1732
+ name = "tqdm"
1733
+ version = "4.65.0"
1734
+ description = "Fast, Extensible Progress Meter"
1735
+ category = "main"
1736
+ optional = false
1737
+ python-versions = ">=3.7"
1738
+
1739
+ [package.dependencies]
1740
+ colorama = {version = "*", markers = "platform_system == \"Windows\""}
1741
+
1742
+ [package.extras]
1743
+ dev = ["py-make (>=0.1.0)", "twine", "wheel"]
1744
+ notebook = ["ipywidgets (>=6)"]
1745
+ slack = ["slack-sdk"]
1746
+ telegram = ["requests"]
1747
+
1748
  [[package]]
1749
  name = "traitlets"
1750
  version = "5.9.0"
 
1757
  docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
1758
  test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"]
1759
 
1760
+ [[package]]
1761
+ name = "transformers"
1762
+ version = "4.27.1"
1763
+ description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
1764
+ category = "main"
1765
+ optional = false
1766
+ python-versions = ">=3.7.0"
1767
+
1768
+ [package.dependencies]
1769
+ filelock = "*"
1770
+ huggingface-hub = ">=0.11.0,<1.0"
1771
+ numpy = ">=1.17"
1772
+ packaging = ">=20.0"
1773
+ pyyaml = ">=5.1"
1774
+ regex = "!=2019.12.17"
1775
+ requests = "*"
1776
+ tokenizers = ">=0.11.1,<0.11.3 || >0.11.3,<0.14"
1777
+ tqdm = ">=4.27"
1778
+
1779
+ [package.extras]
1780
+ accelerate = ["accelerate (>=0.10.0)"]
1781
+ all = ["tensorflow (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)", "torch (>=1.7,!=1.12.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.4.1)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.2)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm", "pillow", "optuna", "ray", "sigopt", "timm", "torchvision", "codecarbon (==1.2.0)", "accelerate (>=0.10.0)", "decord (==0.6.0)", "av (==9.2.0)"]
1782
+ audio = ["librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
1783
+ codecarbon = ["codecarbon (==1.2.0)"]
1784
+ deepspeed = ["deepspeed (>=0.6.5)", "accelerate (>=0.10.0)"]
1785
+ deepspeed-testing = ["deepspeed (>=0.6.5)", "accelerate (>=0.10.0)", "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "pytest-timeout", "black (>=23.1,<24.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.2)", "sacremoses", "rjieba", "safetensors (>=0.2.1)", "beautifulsoup4", "faiss-cpu", "cookiecutter (==1.7.3)", "optuna", "sentencepiece (>=0.1.91,!=0.1.92)"]
1786
+ dev = ["tensorflow (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)", "torch (>=1.7,!=1.12.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.4.1)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.2)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm", "pillow", "optuna", "ray", "sigopt", "timm", "torchvision", "codecarbon (==1.2.0)", "accelerate (>=0.10.0)", "decord (==0.6.0)", "av (==9.2.0)", "pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "pytest-timeout", "black (>=23.1,<24.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "sacremoses", "rjieba", "safetensors (>=0.2.1)", "beautifulsoup4", "faiss-cpu", "cookiecutter (==1.7.3)", "isort (>=5.5.4)", "ruff (>=0.0.241)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "sudachipy (>=0.6.6)", "sudachidict-core (>=20220729)", "rhoknp (>=1.1.0)", "hf-doc-builder", "scikit-learn"]
1787
+ dev-tensorflow = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "pytest-timeout", "black (>=23.1,<24.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.2)", "sacremoses", "rjieba", "safetensors (>=0.2.1)", "beautifulsoup4", "faiss-cpu", "cookiecutter (==1.7.3)", "tensorflow (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "pillow", "isort (>=5.5.4)", "ruff (>=0.0.241)", "hf-doc-builder", "scikit-learn", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
1788
+ dev-torch = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "pytest-timeout", "black (>=23.1,<24.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.2)", "sacremoses", "rjieba", "safetensors (>=0.2.1)", "beautifulsoup4", "faiss-cpu", "cookiecutter (==1.7.3)", "torch (>=1.7,!=1.12.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm", "pillow", "optuna", "ray", "sigopt", "timm", "torchvision", "codecarbon (==1.2.0)", "isort (>=5.5.4)", "ruff (>=0.0.241)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "sudachipy (>=0.6.6)", "sudachidict-core (>=20220729)", "rhoknp (>=1.1.0)", "hf-doc-builder", "scikit-learn", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
1789
+ docs = ["tensorflow (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)", "torch (>=1.7,!=1.12.0)", "jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.4.1)", "optax (>=0.0.8)", "sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.2)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm", "pillow", "optuna", "ray", "sigopt", "timm", "torchvision", "codecarbon (==1.2.0)", "accelerate (>=0.10.0)", "decord (==0.6.0)", "av (==9.2.0)", "hf-doc-builder"]
1790
+ docs_specific = ["hf-doc-builder"]
1791
+ fairscale = ["fairscale (>0.3)"]
1792
+ flax = ["jax (>=0.2.8,!=0.3.2,<=0.3.6)", "jaxlib (>=0.1.65,<=0.3.6)", "flax (>=0.4.1)", "optax (>=0.0.8)"]
1793
+ flax-speech = ["librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
1794
+ ftfy = ["ftfy"]
1795
+ integrations = ["optuna", "ray", "sigopt"]
1796
+ ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "unidic-lite (>=1.0.7)", "unidic (>=1.0.2)", "sudachipy (>=0.6.6)", "sudachidict-core (>=20220729)", "rhoknp (>=1.1.0)"]
1797
+ modelcreation = ["cookiecutter (==1.7.3)"]
1798
+ natten = ["natten (>=0.14.4)"]
1799
+ onnx = ["onnxconverter-common", "tf2onnx", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
1800
+ onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
1801
+ optuna = ["optuna"]
1802
+ quality = ["black (>=23.1,<24.0)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "ruff (>=0.0.241)", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)"]
1803
+ ray = ["ray"]
1804
+ retrieval = ["faiss-cpu", "datasets (!=2.5.0)"]
1805
+ sagemaker = ["sagemaker (>=2.31.0)"]
1806
+ sentencepiece = ["sentencepiece (>=0.1.91,!=0.1.92)", "protobuf (<=3.20.2)"]
1807
+ serving = ["pydantic", "uvicorn", "fastapi", "starlette"]
1808
+ sigopt = ["sigopt"]
1809
+ sklearn = ["scikit-learn"]
1810
+ speech = ["torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
1811
+ testing = ["pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "pytest-timeout", "black (>=23.1,<24.0)", "sacrebleu (>=1.4.12,<2.0.0)", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "nltk", "GitPython (<3.1.19)", "hf-doc-builder (>=0.3.0)", "protobuf (<=3.20.2)", "sacremoses", "rjieba", "safetensors (>=0.2.1)", "beautifulsoup4", "faiss-cpu", "cookiecutter (==1.7.3)"]
1812
+ tf = ["tensorflow (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)"]
1813
+ tf-cpu = ["tensorflow-cpu (>=2.4,<2.12)", "onnxconverter-common", "tf2onnx", "tensorflow-text", "keras-nlp (>=0.3.1)"]
1814
+ tf-speech = ["librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
1815
+ timm = ["timm"]
1816
+ tokenizers = ["tokenizers (>=0.11.1,!=0.11.3,<0.14)"]
1817
+ torch = ["torch (>=1.7,!=1.12.0)"]
1818
+ torch-speech = ["torchaudio", "librosa", "pyctcdecode (>=0.4.0)", "phonemizer", "kenlm"]
1819
+ torch-vision = ["torchvision", "pillow"]
1820
+ torchhub = ["filelock", "huggingface-hub (>=0.11.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf (<=3.20.2)", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "torch (>=1.7,!=1.12.0)", "tokenizers (>=0.11.1,!=0.11.3,<0.14)", "tqdm (>=4.27)"]
1821
+ video = ["decord (==0.6.0)", "av (==9.2.0)"]
1822
+ vision = ["pillow"]
1823
+
1824
+ [[package]]
1825
+ name = "triton"
1826
+ version = "2.0.0"
1827
+ description = "A language and compiler for custom Deep Learning operations"
1828
+ category = "main"
1829
+ optional = false
1830
+ python-versions = "*"
1831
+
1832
+ [package.dependencies]
1833
+ cmake = "*"
1834
+ filelock = "*"
1835
+ lit = "*"
1836
+ torch = "*"
1837
+
1838
+ [package.extras]
1839
+ tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"]
1840
+ tutorials = ["matplotlib", "pandas", "tabulate"]
1841
+
1842
  [[package]]
1843
  name = "typing-extensions"
1844
  version = "4.5.0"
 
1983
  [metadata]
1984
  lock-version = "1.1"
1985
  python-versions = "^3.10"
1986
+ content-hash = "c748285bd150fadef69123d60f0b4ad96d99715916c7e1ab30214132749f8aed"
1987
 
1988
  [metadata.files]
1989
  altair = []
 
2032
  cffi = []
2033
  charset-normalizer = []
2034
  click = []
2035
+ cmake = []
2036
  colorama = []
2037
  comm = []
2038
  debugpy = []
 
2052
  exceptiongroup = []
2053
  executing = []
2054
  fastjsonschema = []
2055
+ filelock = []
2056
  fqdn = []
2057
  gitdb = []
2058
  gitpython = []
2059
+ huggingface-hub = []
2060
  idna = []
2061
  importlib-metadata = []
2062
  iniconfig = []
 
2081
  jupyter-server-terminals = []
2082
  jupyterlab-pygments = []
2083
  jupyterlab-widgets = []
2084
+ lit = []
2085
  markdown-it-py = []
2086
  markupsafe = []
2087
  matplotlib-inline = []
2088
  mdurl = []
2089
  mistune = []
2090
+ mpmath = []
2091
  nbclassic = []
2092
  nbclient = []
2093
  nbconvert = []
2094
  nbformat = []
2095
  nest-asyncio = []
2096
+ networkx = []
2097
  notebook = []
2098
  notebook-shim = []
2099
  numpy = []
2100
+ nvidia-cublas-cu11 = []
2101
+ nvidia-cuda-cupti-cu11 = []
2102
+ nvidia-cuda-nvrtc-cu11 = []
2103
+ nvidia-cuda-runtime-cu11 = []
2104
+ nvidia-cudnn-cu11 = []
2105
+ nvidia-cufft-cu11 = []
2106
+ nvidia-curand-cu11 = []
2107
+ nvidia-cusolver-cu11 = []
2108
+ nvidia-cusparse-cu11 = []
2109
+ nvidia-nccl-cu11 = []
2110
+ nvidia-nvtx-cu11 = []
2111
  packaging = []
2112
  pandas = []
2113
  pandocfilters = [
 
2200
  {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
2201
  ]
2202
  pyzmq = []
2203
+ regex = []
2204
  requests = []
2205
  rfc3339-validator = []
2206
  rfc3986-validator = []
 
2210
  {file = "Send2Trash-1.8.0-py3-none-any.whl", hash = "sha256:f20eaadfdb517eaca5ce077640cb261c7d2698385a6a0f072a4a5447fd49fa08"},
2211
  {file = "Send2Trash-1.8.0.tar.gz", hash = "sha256:d2c24762fd3759860a0aff155e45871447ea58d2be6bdd39b5c8f966a0c99c2d"},
2212
  ]
2213
+ sentencepiece = []
2214
  six = [
2215
  {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
2216
  {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
 
2224
  stack-data = []
2225
  stmol = []
2226
  streamlit = []
2227
+ sympy = []
2228
  terminado = []
2229
  tinycss2 = []
2230
+ tokenizers = []
2231
  toml = [
2232
  {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
2233
  {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
 
2238
  ]
2239
  tomlkit = []
2240
  toolz = []
2241
+ torch = []
2242
  tornado = []
2243
+ tqdm = []
2244
  traitlets = []
2245
+ transformers = []
2246
+ triton = []
2247
  typing-extensions = []
2248
  tzdata = []
2249
  tzlocal = []
protention/attention.py CHANGED
@@ -1,7 +1,9 @@
1
  from io import StringIO
2
  from urllib import request
3
 
 
4
  from Bio.PDB import PDBParser, Structure
 
5
 
6
 
7
  def get_structure(pdb_code: str) -> Structure:
@@ -16,6 +18,21 @@ def get_structure(pdb_code: str) -> Structure:
16
  return structure
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def get_attention(
20
  pdb_code: str, chain_ids: list[str], layer: int, head: int, min_attn: float = 0.2
21
  ):
@@ -26,6 +43,7 @@ def get_attention(
26
  structure = get_structure(pdb_code)
27
 
28
  # get model
 
29
 
30
  # call model
31
 
 
1
  from io import StringIO
2
  from urllib import request
3
 
4
+ import torch
5
  from Bio.PDB import PDBParser, Structure
6
+ from transformers import T5EncoderModel, T5Tokenizer
7
 
8
 
9
  def get_structure(pdb_code: str) -> Structure:
 
18
  return structure
19
 
20
 
21
+ def get_protT5() -> tuple[T5Tokenizer, T5EncoderModel]:
22
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
+ tokenizer = T5Tokenizer.from_pretrained(
24
+ "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
25
+ )
26
+
27
+ model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(
28
+ device
29
+ )
30
+
31
+ model.full() if device == "cpu" else model.half()
32
+
33
+ return tokenizer, model
34
+
35
+
36
  def get_attention(
37
  pdb_code: str, chain_ids: list[str], layer: int, head: int, min_attn: float = 0.2
38
  ):
 
43
  structure = get_structure(pdb_code)
44
 
45
  # get model
46
+ tokenizer, model = get_protT5()
47
 
48
  # call model
49
 
pyproject.toml CHANGED
@@ -9,6 +9,9 @@ python = "^3.10"
9
  streamlit = "^1.20.0"
10
  stmol = "^0.0.9"
11
  biopython = "^1.81"
 
 
 
12
 
13
  [tool.poetry.dev-dependencies]
14
  pytest = "^7.2.2"
 
9
  streamlit = "^1.20.0"
10
  stmol = "^0.0.9"
11
  biopython = "^1.81"
12
+ transformers = "^4.27.1"
13
+ torch = "^2.0.0"
14
+ sentencepiece = "^0.1.97"
15
 
16
  [tool.poetry.dev-dependencies]
17
  pytest = "^7.2.2"
tests/test_attention.py CHANGED
@@ -1,10 +1,24 @@
1
  from Bio.PDB.Structure import Structure
 
2
 
3
- from protention.attention import get_structure
4
 
5
 
6
  def test_get_structure():
7
  pdb_id = "1AKE"
8
  structure = get_structure(pdb_id)
 
9
  assert structure is not None
10
  assert isinstance(structure, Structure)
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from Bio.PDB.Structure import Structure
2
+ from transformers import T5EncoderModel, T5Tokenizer
3
 
4
+ from protention.attention import get_protT5, get_structure
5
 
6
 
7
  def test_get_structure():
8
  pdb_id = "1AKE"
9
  structure = get_structure(pdb_id)
10
+
11
  assert structure is not None
12
  assert isinstance(structure, Structure)
13
+
14
+
15
+ def test_get_protT5():
16
+ result = get_protT5()
17
+
18
+ assert result is not None
19
+ assert isinstance(result, tuple)
20
+
21
+ tokenizer, model = result
22
+
23
+ assert isinstance(tokenizer, T5Tokenizer)
24
+ assert isinstance(model, T5EncoderModel)