aksell commited on
Commit
9938060
1 Parent(s): b7bfd49

Use python 3.8.9

Browse files

Spaces runs on Python 3.8.9

Files changed (4) hide show
  1. .python-version +1 -0
  2. hexviz/attention.py +18 -17
  3. poetry.lock +47 -2
  4. pyproject.toml +1 -1
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.8.9
hexviz/attention.py CHANGED
@@ -1,5 +1,6 @@
1
  from enum import Enum
2
  from io import StringIO
 
3
  from urllib import request
4
 
5
  import streamlit as st
@@ -32,7 +33,7 @@ def get_structure(pdb_code: str) -> Structure:
32
  structure = parser.get_structure(pdb_code, file)
33
  return structure
34
 
35
- def get_sequences(structure: Structure) -> list[str]:
36
  """
37
  Get list of sequences with residues on a single letter format
38
 
@@ -47,7 +48,7 @@ def get_sequences(structure: Structure) -> list[str]:
47
  sequences.append(list(residues_single_letter))
48
  return sequences
49
 
50
- def get_protT5() -> tuple[T5Tokenizer, T5EncoderModel]:
51
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
52
  tokenizer = T5Tokenizer.from_pretrained(
53
  "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
@@ -61,32 +62,32 @@ def get_protT5() -> tuple[T5Tokenizer, T5EncoderModel]:
61
 
62
  return tokenizer, model
63
 
64
- def get_tape_bert() -> tuple[TAPETokenizer, ProteinBertModel]:
65
  tokenizer = TAPETokenizer()
66
  model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
67
  return tokenizer, model
68
 
69
  @st.cache_data
70
  def get_attention(
71
- sequence: list[str], model_type: ModelType = ModelType.TAPE_BERT
72
  ):
73
- match model_type:
74
- case ModelType.TAPE_BERT:
75
- tokenizer, model = get_tape_bert()
76
- token_idxs = tokenizer.encode(sequence).tolist()
77
- inputs = torch.tensor(token_idxs).unsqueeze(0)
78
- with torch.no_grad():
79
- attns = model(inputs)[-1]
80
- # Remove attention from <CLS> (first) and <SEP> (last) token
81
- attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
82
- attns = torch.stack([attn.squeeze(0) for attn in attns])
83
- case ModelType.PROT_T5:
84
  attns = None
85
  # Space separate sequences
86
  sequences = [" ".join(sequence) for sequence in sequences]
87
  tokenizer, model = get_protT5()
88
- case _:
89
- raise ValueError(f"Model {model_type} not supported")
 
90
  return attns
91
 
92
  def unidirectional_sum_filtered(attention, layer, head, threshold):
 
1
  from enum import Enum
2
  from io import StringIO
3
+ from typing import List, Tuple
4
  from urllib import request
5
 
6
  import streamlit as st
 
33
  structure = parser.get_structure(pdb_code, file)
34
  return structure
35
 
36
+ def get_sequences(structure: Structure) -> List[str]:
37
  """
38
  Get list of sequences with residues on a single letter format
39
 
 
48
  sequences.append(list(residues_single_letter))
49
  return sequences
50
 
51
+ def get_protT5() -> Tuple[T5Tokenizer, T5EncoderModel]:
52
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
53
  tokenizer = T5Tokenizer.from_pretrained(
54
  "Rostlab/prot_t5_xl_half_uniref50-enc", do_lower_case=False
 
62
 
63
  return tokenizer, model
64
 
65
+ def get_tape_bert() -> Tuple[TAPETokenizer, ProteinBertModel]:
66
  tokenizer = TAPETokenizer()
67
  model = ProteinBertModel.from_pretrained('bert-base', output_attentions=True)
68
  return tokenizer, model
69
 
70
  @st.cache_data
71
  def get_attention(
72
+ sequence: List[str], model_type: ModelType = ModelType.TAPE_BERT
73
  ):
74
+ if model_type == ModelType.TAPE_BERT:
75
+ tokenizer, model = get_tape_bert()
76
+ token_idxs = tokenizer.encode(sequence).tolist()
77
+ inputs = torch.tensor(token_idxs).unsqueeze(0)
78
+ with torch.no_grad():
79
+ attns = model(inputs)[-1]
80
+ # Remove attention from <CLS> (first) and <SEP> (last) token
81
+ attns = [attn[:, :, 1:-1, 1:-1] for attn in attns]
82
+ attns = torch.stack([attn.squeeze(0) for attn in attns])
83
+ elif model_type == ModelType.PROT_T5:
 
84
  attns = None
85
  # Space separate sequences
86
  sequences = [" ".join(sequence) for sequence in sequences]
87
  tokenizer, model = get_protT5()
88
+ else:
89
+ raise ValueError(f"Model {model_type} not supported")
90
+
91
  return attns
92
 
93
  def unidirectional_sum_filtered(attention, layer, head, threshold):
poetry.lock CHANGED
@@ -122,6 +122,17 @@ category = "main"
122
  optional = false
123
  python-versions = "*"
124
 
 
 
 
 
 
 
 
 
 
 
 
125
  [[package]]
126
  name = "beautifulsoup4"
127
  version = "4.11.2"
@@ -451,6 +462,21 @@ docs = ["sphinx (>=3.5)", "jaraco.packaging (>=9)", "rst.linker (>=1.9)", "furo"
451
  perf = ["ipython"]
452
  testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-cov", "pytest-enabler (>=1.3)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "pytest-flake8", "importlib-resources (>=1.3)"]
453
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  [[package]]
455
  name = "iniconfig"
456
  version = "2.0.0"
@@ -632,8 +658,10 @@ python-versions = ">=3.7"
632
  attrs = ">=17.4.0"
633
  fqdn = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
634
  idna = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
 
635
  isoduration = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
636
  jsonpointer = {version = ">1.13", optional = true, markers = "extra == \"format-nongpl\""}
 
637
  pyrsistent = ">=0.14.0,<0.17.0 || >0.17.0,<0.17.1 || >0.17.1,<0.17.2 || >0.17.2"
638
  rfc3339-validator = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
639
  rfc3986-validator = {version = ">0.1.0", optional = true, markers = "extra == \"format-nongpl\""}
@@ -653,6 +681,7 @@ optional = false
653
  python-versions = ">=3.8"
654
 
655
  [package.dependencies]
 
656
  jupyter-core = ">=4.12,<5.0.0 || >=5.1.0"
657
  python-dateutil = ">=2.8.2"
658
  pyzmq = ">=23.0"
@@ -930,6 +959,7 @@ python-versions = ">=3.7"
930
  beautifulsoup4 = "*"
931
  bleach = "*"
932
  defusedxml = "*"
 
933
  jinja2 = ">=3.0"
934
  jupyter-core = ">=4.7"
935
  jupyterlab-pygments = "*"
@@ -1152,6 +1182,7 @@ python-versions = ">=3.8"
1152
 
1153
  [package.dependencies]
1154
  numpy = [
 
1155
  {version = ">=1.21.0", markers = "python_version >= \"3.10\""},
1156
  {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
1157
  ]
@@ -1212,6 +1243,14 @@ python-versions = ">=3.7"
1212
  docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-issues (>=3.0.1)", "sphinx-removed-in", "sphinxext-opengraph"]
1213
  tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"]
1214
 
 
 
 
 
 
 
 
 
1215
  [[package]]
1216
  name = "platformdirs"
1217
  version = "3.1.1"
@@ -1425,6 +1464,7 @@ optional = false
1425
  python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
1426
 
1427
  [package.dependencies]
 
1428
  tzdata = {version = "*", markers = "python_version >= \"3.6\""}
1429
 
1430
  [[package]]
@@ -1518,6 +1558,7 @@ python-versions = ">=3.7.0"
1518
  [package.dependencies]
1519
  markdown-it-py = ">=2.2.0,<3.0.0"
1520
  pygments = ">=2.13.0,<3.0.0"
 
1521
 
1522
  [package.extras]
1523
  jupyter = ["ipywidgets (>=7.5.1,<9)"]
@@ -1973,6 +2014,7 @@ optional = false
1973
  python-versions = ">=3.6"
1974
 
1975
  [package.dependencies]
 
1976
  pytz-deprecation-shim = "*"
1977
  tzdata = {version = "*", markers = "platform_system == \"Windows\""}
1978
 
@@ -2091,8 +2133,8 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-co
2091
 
2092
  [metadata]
2093
  lock-version = "1.1"
2094
- python-versions = "^3.10"
2095
- content-hash = "ad6054ae4a119d961e9941f135489d1b89310303aefc27d3132fbd1ed1c35a0f"
2096
 
2097
  [metadata.files]
2098
  altair = []
@@ -2132,6 +2174,7 @@ backcall = [
2132
  {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"},
2133
  {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"},
2134
  ]
 
2135
  beautifulsoup4 = []
2136
  biopython = []
2137
  bleach = []
@@ -2170,6 +2213,7 @@ gitpython = []
2170
  huggingface-hub = []
2171
  idna = []
2172
  importlib-metadata = []
 
2173
  iniconfig = []
2174
  ipykernel = []
2175
  ipyspeck = []
@@ -2240,6 +2284,7 @@ pickleshare = [
2240
  {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"},
2241
  ]
2242
  pillow = []
 
2243
  platformdirs = []
2244
  pluggy = [
2245
  {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
 
122
  optional = false
123
  python-versions = "*"
124
 
125
+ [[package]]
126
+ name = "backports.zoneinfo"
127
+ version = "0.2.1"
128
+ description = "Backport of the standard library zoneinfo module"
129
+ category = "main"
130
+ optional = false
131
+ python-versions = ">=3.6"
132
+
133
+ [package.extras]
134
+ tzdata = ["tzdata"]
135
+
136
  [[package]]
137
  name = "beautifulsoup4"
138
  version = "4.11.2"
 
462
  perf = ["ipython"]
463
  testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-cov", "pytest-enabler (>=1.3)", "packaging", "pyfakefs", "flufl.flake8", "pytest-perf (>=0.9.2)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "pytest-flake8", "importlib-resources (>=1.3)"]
464
 
465
+ [[package]]
466
+ name = "importlib-resources"
467
+ version = "5.12.0"
468
+ description = "Read resources from Python packages"
469
+ category = "main"
470
+ optional = false
471
+ python-versions = ">=3.7"
472
+
473
+ [package.dependencies]
474
+ zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
475
+
476
+ [package.extras]
477
+ docs = ["sphinx (>=3.5)", "jaraco.packaging (>=9)", "rst.linker (>=1.9)", "furo", "sphinx-lint", "jaraco.tidelift (>=1.4)"]
478
+ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "flake8 (<5)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-black (>=0.3.7)", "pytest-mypy (>=0.9.1)", "pytest-flake8"]
479
+
480
  [[package]]
481
  name = "iniconfig"
482
  version = "2.0.0"
 
658
  attrs = ">=17.4.0"
659
  fqdn = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
660
  idna = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
661
+ importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
662
  isoduration = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
663
  jsonpointer = {version = ">1.13", optional = true, markers = "extra == \"format-nongpl\""}
664
+ pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""}
665
  pyrsistent = ">=0.14.0,<0.17.0 || >0.17.0,<0.17.1 || >0.17.1,<0.17.2 || >0.17.2"
666
  rfc3339-validator = {version = "*", optional = true, markers = "extra == \"format-nongpl\""}
667
  rfc3986-validator = {version = ">0.1.0", optional = true, markers = "extra == \"format-nongpl\""}
 
681
  python-versions = ">=3.8"
682
 
683
  [package.dependencies]
684
+ importlib-metadata = {version = ">=4.8.3", markers = "python_version < \"3.10\""}
685
  jupyter-core = ">=4.12,<5.0.0 || >=5.1.0"
686
  python-dateutil = ">=2.8.2"
687
  pyzmq = ">=23.0"
 
959
  beautifulsoup4 = "*"
960
  bleach = "*"
961
  defusedxml = "*"
962
+ importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.10\""}
963
  jinja2 = ">=3.0"
964
  jupyter-core = ">=4.7"
965
  jupyterlab-pygments = "*"
 
1182
 
1183
  [package.dependencies]
1184
  numpy = [
1185
+ {version = ">=1.20.3", markers = "python_version < \"3.10\""},
1186
  {version = ">=1.21.0", markers = "python_version >= \"3.10\""},
1187
  {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
1188
  ]
 
1243
  docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-issues (>=3.0.1)", "sphinx-removed-in", "sphinxext-opengraph"]
1244
  tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"]
1245
 
1246
+ [[package]]
1247
+ name = "pkgutil-resolve-name"
1248
+ version = "1.3.10"
1249
+ description = "Resolve a name to an object."
1250
+ category = "main"
1251
+ optional = false
1252
+ python-versions = ">=3.6"
1253
+
1254
  [[package]]
1255
  name = "platformdirs"
1256
  version = "3.1.1"
 
1464
  python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7"
1465
 
1466
  [package.dependencies]
1467
+ "backports.zoneinfo" = {version = "*", markers = "python_version >= \"3.6\" and python_version < \"3.9\""}
1468
  tzdata = {version = "*", markers = "python_version >= \"3.6\""}
1469
 
1470
  [[package]]
 
1558
  [package.dependencies]
1559
  markdown-it-py = ">=2.2.0,<3.0.0"
1560
  pygments = ">=2.13.0,<3.0.0"
1561
+ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
1562
 
1563
  [package.extras]
1564
  jupyter = ["ipywidgets (>=7.5.1,<9)"]
 
2014
  python-versions = ">=3.6"
2015
 
2016
  [package.dependencies]
2017
+ "backports.zoneinfo" = {version = "*", markers = "python_version < \"3.9\""}
2018
  pytz-deprecation-shim = "*"
2019
  tzdata = {version = "*", markers = "platform_system == \"Windows\""}
2020
 
 
2133
 
2134
  [metadata]
2135
  lock-version = "1.1"
2136
+ python-versions = "^3.8.9"
2137
+ content-hash = "c4902982dd427349993b8019f7130f81583d025190b4bd2a2452166c7d146249"
2138
 
2139
  [metadata.files]
2140
  altair = []
 
2174
  {file = "backcall-0.2.0-py2.py3-none-any.whl", hash = "sha256:fbbce6a29f263178a1f7915c1940bde0ec2b2a967566fe1c65c1dfb7422bd255"},
2175
  {file = "backcall-0.2.0.tar.gz", hash = "sha256:5cbdbf27be5e7cfadb448baf0aa95508f91f2bbc6c6437cd9cd06e2a4c215e1e"},
2176
  ]
2177
+ "backports.zoneinfo" = []
2178
  beautifulsoup4 = []
2179
  biopython = []
2180
  bleach = []
 
2213
  huggingface-hub = []
2214
  idna = []
2215
  importlib-metadata = []
2216
+ importlib-resources = []
2217
  iniconfig = []
2218
  ipykernel = []
2219
  ipyspeck = []
 
2284
  {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"},
2285
  ]
2286
  pillow = []
2287
+ pkgutil-resolve-name = []
2288
  platformdirs = []
2289
  pluggy = [
2290
  {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
pyproject.toml CHANGED
@@ -5,7 +5,7 @@ description = "Visualize and analyze attention patterns for protein language mod
5
  authors = ["Aksel Lenes <aksel.lenes@gmail.com>"]
6
 
7
  [tool.poetry.dependencies]
8
- python = "^3.10"
9
  streamlit = "^1.20.0"
10
  stmol = "^0.0.9"
11
  biopython = "^1.81"
 
5
  authors = ["Aksel Lenes <aksel.lenes@gmail.com>"]
6
 
7
  [tool.poetry.dependencies]
8
+ python = "^3.8.9"
9
  streamlit = "^1.20.0"
10
  stmol = "^0.0.9"
11
  biopython = "^1.81"