Spaces:
Runtime error
Runtime error
import json | |
from asyncio import Event | |
import ray | |
from mtranslate.core import translate | |
from ray.actor import ActorHandle | |
from tqdm import tqdm | |
ray.init() | |
from typing import Tuple | |
# Back on the local node, once you launch your remote Ray tasks, call | |
# `print_until_done`, which will feed everything back into a `tqdm` counter. | |
@ray.remote
class ProgressBarActor:
    """Ray actor that accumulates progress reports from remote tasks.

    The class must be a Ray actor (``@ray.remote``) because ``ProgressBar``
    instantiates it with ``ProgressBarActor.remote()`` and workers report
    progress through ``update.remote(...)``.

    Attributes are documented inline below.
    """

    # Total number of items reported as completed so far.
    counter: int
    # Items completed since the last `wait_for_update` call returned.
    delta: int
    # Set by `update`; awaited (and then cleared) by `wait_for_update`.
    event: Event

    def __init__(self) -> None:
        self.counter = 0
        self.delta = 0
        self.event = Event()

    def update(self, num_items_completed: int) -> None:
        """Record that `num_items_completed` more items have finished.

        Adds to both the running total and the pending delta, then sets the
        event so a pending `wait_for_update` can return.
        """
        self.counter += num_items_completed
        self.delta += num_items_completed
        self.event.set()

    async def wait_for_update(self) -> Tuple[int, int]:
        """Wait until `update` has been called, then report progress.

        Returns:
            A tuple ``(delta, counter)``: the number of items completed since
            the previous call to `wait_for_update`, and the total number of
            completed items.
        """
        await self.event.wait()
        # Re-arm the event and reset the delta so the next call only reports
        # progress made after this point.
        self.event.clear()
        saved_delta = self.delta
        self.delta = 0
        return saved_delta, self.counter

    def get_counter(self) -> int:
        """Return the total number of completed items."""
        return self.counter
class ProgressBar:
    """Local-side wrapper that renders a `tqdm` bar fed by a `ProgressBarActor`."""

    # Handle to the remote actor that remote tasks report progress to.
    progress_actor: ActorHandle
    # Number of items after which the bar is considered complete.
    total: int
    # Label passed to tqdm as `desc`.
    description: str
    pbar: tqdm

    def __init__(self, total: int, description: str = ""):
        # Ray actors don't seem to play nice with mypy, generating
        # a spurious warning for the following line,
        # which we need to suppress. The code is fine.
        self.progress_actor = ProgressBarActor.remote()  # type: ignore
        self.total = total
        self.description = description

    @property
    def actor(self) -> ActorHandle:
        """Return the handle of the remote `ProgressBarActor`.

        Exposed as a property because it is read as an attribute
        (``pb.actor``, ``self.actor.wait_for_update``). Pass the handle to
        remote tasks; when they complete work they call `update` on it.
        """
        return self.progress_actor

    def print_until_done(self) -> None:
        """Block until the reported progress reaches `total`.

        Call this after starting the remote Ray tasks that were given the
        actor handle. Each of them calls `update` on the actor; this method
        mirrors those updates into a `tqdm` bar and returns once the total
        number of completed items is at least `self.total`.
        """
        # The context manager closes the bar on normal return and on error.
        with tqdm(desc=self.description, total=self.total) as pbar:
            # With nothing to wait for, no task will ever call `update`, so
            # waiting on the actor would block forever.
            if self.total <= 0:
                return
            while True:
                delta, counter = ray.get(self.actor.wait_for_update.remote())
                pbar.update(delta)
                if counter >= self.total:
                    return
# Load the mapping whose values are the answer strings to translate.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's default locale encoding.
with open("answer_reverse_mapping.json", encoding="utf-8") as f:
    answer_reverse_mapping = json.load(f)
@ray.remote
def translate_answer(value, pba):
    """Translate one answer into French, Spanish and German as a Ray task.

    Declared with ``@ray.remote`` because the driver schedules it through
    ``translate_answer.remote(...)``.

    Args:
        value: The text to translate; passed to `translate` with "en" as the
            source language.
        pba: Handle of the progress actor; its `update` method is invoked
            remotely once all translations for `value` are done.

    Returns:
        A dict mapping each target language code ("fr", "es", "de") to the
        translation of `value` in that language.
    """
    translations = {
        lang: translate(value, lang, "en") for lang in ("fr", "es", "de")
    }
    # Report one completed item so the driver-side progress bar advances.
    pba.update.remote(1)
    return translations
# Fan out one remote translation task per answer, sharing a single progress
# actor between them.
pb = ProgressBar(len(answer_reverse_mapping))
actor = pb.actor
translation_dicts = [
    translate_answer.remote(answer, actor)
    for answer in answer_reverse_mapping.values()
]

# Block here until every task has reported completion.
pb.print_until_done()

# Pair each source answer with its resolved translations, in submission order.
translation_dict = {
    answer: translated
    for answer, translated in zip(
        answer_reverse_mapping.values(), ray.get(translation_dicts)
    )
}

with open("translation_dict.json", "w") as f:
    json.dump(translation_dict, f)