gchhablani committed on
Commit
7fe8d4e
1 Parent(s): 324f080

Add missing files

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ *mic_env/*
2
+ **__pycache__**
3
+ *.pyc
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
1
+ plotly==5.1.0
2
+ streamlit==0.84.1
3
+ git+https://github.com/huggingface/transformers.git
4
+ torchvision==0.10.0
5
+ mtranslate==1.8
6
+ black==21.7b0
7
+ flax==0.3.4
8
+ sentencepiece==0.1.96
sections/abstract.md ADDED
File without changes
sections/acknowledgements.md ADDED
File without changes
sections/caveats.md ADDED
File without changes
sections/challenges.md ADDED
File without changes
sections/pretraining.md ADDED
File without changes
sections/references.md ADDED
File without changes
sections/social_impact.md ADDED
File without changes
sections/usage.md ADDED
File without changes
session.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ # Code for managing session state, which is needed for multi-input forms
3
+ # See https://github.com/streamlit/streamlit/issues/1557
4
+ #
5
+ # This code is taken from
6
+ # https://gist.github.com/okld/0aba4869ba6fdc8d49132e6974e2e662
7
+ #
8
+ from streamlit.hashing import _CodeHasher
9
+ from streamlit.report_thread import get_report_ctx
10
+ from streamlit.server.server import Server
11
+
12
+
13
class _SessionState:
    """Key/value store attached to a Streamlit session.

    User values are kept in ``self._state["data"]`` and can be read or
    written either as attributes (``state.foo``) or as items
    (``state["foo"]``). Reading a key that was never set yields ``None``.

    Bookkeeping kept next to the data in ``self._state``:

    * ``hash`` -- serialized form of the data as of the last ``sync()``.
    * ``hasher`` -- the ``_CodeHasher`` used to serialize the data.
    * ``is_rerun`` -- True while a rerun requested by ``sync()`` is pending.
    * ``session`` -- the session object whose ``request_rerun()`` is called.
    """

    def __init__(self, session, hash_funcs):
        """Initialize SessionState instance."""
        # Write through __dict__ directly: __setattr__ is overridden to store
        # into self._state["data"], which does not exist yet at this point.
        self.__dict__["_state"] = {
            "data": {},
            "hash": None,
            "hasher": _CodeHasher(hash_funcs),
            "is_rerun": False,
            "session": session,
        }

    def __call__(self, **kwargs):
        """Initialize state data once.

        Each keyword becomes a state key, but only if that key is not already
        present, so calling this on every script run does not overwrite
        values changed in the meantime.
        """
        for item, value in kwargs.items():
            if item not in self._state["data"]:
                self._state["data"][item] = value

    def __getitem__(self, item):
        """Return a saved state value, None if item is undefined."""
        return self._state["data"].get(item, None)

    def __getattr__(self, item):
        """Return a saved state value, None if item is undefined.

        Raises:
            AttributeError: if ``_state`` itself is requested, which only
                happens on an instance whose ``__init__`` has not run.
        """
        # __getattr__ is only invoked when normal lookup fails. If "_state"
        # gets here, the instance was created without __init__ (for example
        # via __new__, as copy/pickle reconstruction does); reading
        # self._state below would then re-enter this method forever and end
        # in RecursionError instead of a regular AttributeError.
        if item == "_state":
            raise AttributeError(item)
        return self._state["data"].get(item, None)

    def __setitem__(self, item, value):
        """Set state value."""
        self._state["data"][item] = value

    def __setattr__(self, item, value):
        """Set state value."""
        self._state["data"][item] = value

    def clear(self):
        """Clear session state and request a rerun."""
        self._state["data"].clear()
        self._state["session"].request_rerun()

    def sync(self):
        """
        Rerun the app with all state values up to date from the beginning to
        fix rollbacks.
        """
        data_to_bytes = self._state["hasher"].to_bytes(self._state["data"], None)

        # Ensure to rerun only once to avoid infinite loops
        # caused by a constantly changing state value at each run.
        #
        # Example: state.value += 1
        if self._state["is_rerun"]:
            self._state["is_rerun"] = False

        elif self._state["hash"] is not None:
            if self._state["hash"] != data_to_bytes:
                self._state["is_rerun"] = True
                self._state["session"].request_rerun()

        # Always remember the latest serialized data for the next comparison.
        self._state["hash"] = data_to_bytes
71
+
72
+
73
def _get_session():
    """Return the session object registered for the current report context.

    The session id is read from ``get_report_ctx()`` and looked up on the
    current ``Server`` instance.

    Raises:
        RuntimeError: if the server returns no session info for that id.
    """
    current_id = get_report_ctx().session_id
    server = Server.get_current()
    info = server._get_session_info(current_id)

    if info is not None:
        return info.session

    raise RuntimeError("Couldn't get your Streamlit Session object.")
81
+
82
+
83
def _get_state(hash_funcs=None):
    """Return the ``_SessionState`` bound to the current session.

    The state object is cached on the session as ``_custom_session_state``;
    it is created on the first call and reused afterwards, so ``hash_funcs``
    only takes effect when the state is first created.
    """
    current = _get_session()

    try:
        return current._custom_session_state
    except AttributeError:
        # First call for this session: create the state and cache it.
        current._custom_session_state = _SessionState(current, hash_funcs)

    return current._custom_session_state
utils.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  import numpy as np
4
  from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
5
  from torchvision.transforms.functional import InterpolationMode
6
- from transformers import MBart50TokenizerFast
7
  from PIL import Image
8
 
9
 
@@ -28,24 +27,8 @@ class Transform(torch.nn.Module):
28
 
29
  transform = Transform(224)
30
 
31
-
32
  def get_transformed_image(image):
33
  if image.shape[-1] == 3 and isinstance(image, np.ndarray):
34
  image = image.transpose(2, 0, 1)
35
  image = torch.tensor(image)
36
- return transform(image).unsqueeze(0).permute(0, 2, 3, 1).numpy()
37
-
38
- tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")
39
-
40
- language_mapping = {
41
- "english": "en_XX",
42
- "german": "de_DE",
43
- "french": "fr_XX",
44
- "spanish": "es_XX"
45
- }
46
-
47
- def generate_sequence(model, pixel_values, lang_code):
48
- lang_code = language_mapping[lang_code]
49
- output_ids = model.generate(input_ids=pixel_values, decoder_start_token_id=tokenizer.lang_code_to_id[lang_code], max_length=64, num_beams=4)
50
- output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=64)
51
- return output_sequence
3
  import numpy as np
4
  from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
5
  from torchvision.transforms.functional import InterpolationMode
 
6
  from PIL import Image
7
 
8
 
27
 
28
  transform = Transform(224)
29
 
 
30
  def get_transformed_image(image):
31
  if image.shape[-1] == 3 and isinstance(image, np.ndarray):
32
  image = image.transpose(2, 0, 1)
33
  image = torch.tensor(image)
34
+ return transform(image).unsqueeze(0).permute(0, 2, 3, 1).numpy()