# Multilingual-VQA / translate_answer_mapping.py
# gchhablani's picture
# Fix style
# e289356
# raw history blame
# No virus
# 3.37 kB
import json
from asyncio import Event
import ray
from mtranslate.core import translate
from ray.actor import ActorHandle
from tqdm import tqdm
# Initialise Ray at import time, before any of the `@ray.remote` actors/tasks
# defined below are instantiated or launched.
ray.init()
from typing import Tuple
# Back on the local node, once you launch your remote Ray tasks, call
# `print_until_done`, which will feed everything back into a `tqdm` counter.
@ray.remote
class ProgressBarActor:
    """Ray actor that accumulates progress reported by remote tasks.

    Tasks call `update` to report finished items; a consumer awaits
    `wait_for_update` to learn how much progress was made since it last
    asked, together with the running total.
    """

    counter: int  # total number of items reported so far
    delta: int  # items reported since the last `wait_for_update` returned
    event: Event  # set by `update`, cleared by `wait_for_update`

    def __init__(self) -> None:
        self.counter = self.delta = 0
        self.event = Event()

    def update(self, num_items_completed: int) -> None:
        """Record `num_items_completed` newly finished items.

        Both the running total and the not-yet-consumed delta grow by the
        same amount, and the event is set so a pending `wait_for_update`
        can proceed.
        """
        self.delta += num_items_completed
        self.counter += num_items_completed
        self.event.set()

    async def wait_for_update(self) -> Tuple[int, int]:
        """Wait until `update` has been called, then report progress.

        Returns a `(delta, counter)` tuple: the number of items reported
        since the previous call to this method, and the total number of
        items reported overall. The delta is reset to zero on return.
        """
        await self.event.wait()
        self.event.clear()
        # Hand out the accumulated delta and start a fresh one.
        pending, self.delta = self.delta, 0
        return pending, self.counter

    def get_counter(self) -> int:
        """Return the total number of items reported so far."""
        return self.counter
class ProgressBar:
    """Local-side progress meter driven by a remote `ProgressBarActor`.

    Create one with the expected number of items, pass `actor` to every
    remote task so each can call `update` on it, then call
    `print_until_done` to render a `tqdm` bar until all items are reported.
    """

    progress_actor: ActorHandle
    total: int
    description: str
    # Declared for type checkers only; the bar itself is created locally
    # inside `print_until_done` and never stored on the instance.
    pbar: tqdm

    def __init__(self, total: int, description: str = ""):
        """Spawn the backing actor and remember the bar's configuration.

        Args:
            total: Number of items that must be reported before
                `print_until_done` returns.
            description: Label passed to `tqdm` as `desc`.
        """
        # Ray actors don't seem to play nice with mypy, generating
        # a spurious warning for the following line,
        # which we need to suppress. The code is fine.
        self.progress_actor = ProgressBarActor.remote()  # type: ignore
        self.total = total
        self.description = description

    @property
    def actor(self) -> ActorHandle:
        """Return the handle of the remote `ProgressBarActor`.

        Remote tasks call `update` on this handle as they complete items.
        """
        return self.progress_actor

    def print_until_done(self) -> None:
        """Block, updating a `tqdm` bar, until `total` items are reported.

        Call this after launching the remote tasks that were given the
        actor handle. If `total` is zero (or negative) the method returns
        immediately: no task would ever call `update`, so waiting on the
        actor would otherwise block forever.
        """
        # The context manager closes the bar on normal completion and also
        # when `ray.get` raises, so the terminal is never left with a
        # dangling bar.
        with tqdm(desc=self.description, total=self.total) as pbar:
            counter = 0
            while counter < self.total:
                delta, counter = ray.get(self.actor.wait_for_update.remote())
                pbar.update(delta)
# Load the answer mapping whose values are the strings to translate.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's default locale encoding.
with open("answer_reverse_mapping.json", encoding="utf-8") as f:
    answer_reverse_mapping = json.load(f)
@ray.remote
def translate_answer(value, pba):
    """Translate one answer string and report completion to the progress actor.

    Calls `translate(value, lang, "en")` once for each of "fr", "es" and
    "de" (presumably translating from English into that language - confirm
    against the mtranslate signature).

    Args:
        value: The answer text to translate.
        pba: Handle of the progress actor; `update` is invoked on it once
            with a count of 1 after all translations are done.

    Returns:
        A dict mapping each language code to the translated text.
    """
    target_languages = ("fr", "es", "de")
    translations = {
        code: translate(value, code, "en") for code in target_languages
    }
    # One progress tick per answer, not per language.
    pba.update.remote(1)
    return translations
# One progress tick is expected per answer value.
pb = ProgressBar(len(answer_reverse_mapping))
actor = pb.actor

# Launch one remote translation task per answer; every task reports to the
# same progress actor. The list holds the pending Ray object references.
translation_dicts = [
    translate_answer.remote(value, actor)
    for value in answer_reverse_mapping.values()
]

# Render the progress bar until all tasks have reported.
pb.print_until_done()

# Pair each source answer with its resolved per-language translations.
translation_dict = {
    answer: translations
    for answer, translations in zip(
        answer_reverse_mapping.values(), ray.get(translation_dicts)
    )
}

with open("translation_dict.json", "w") as f:
    json.dump(translation_dict, f)