"""Translate every answer in ``answer_reverse_mapping.json`` using Ray tasks.

Each distinct answer value is translated from English into a fixed set of
target languages by a remote Ray task, a `tqdm` progress bar is driven from
the local node while the tasks run, and the collected translations are
written to ``translation_dict.json``.
"""
import json
from asyncio import Event
from typing import Tuple

import ray
from mtranslate.core import translate
from ray.actor import ActorHandle
from tqdm import tqdm

ray.init()

# Languages every answer is translated into (source language is English).
TARGET_LANGUAGES = ("fr", "es", "de")

# Back on the local node, once you launch your remote Ray tasks, call
# `print_until_done`, which will feed everything back into a `tqdm` counter.


@ray.remote
class ProgressBarActor:
    """Remote counter that tasks update and the driver polls for progress."""

    counter: int  # total number of items reported so far
    delta: int  # items reported since the last `wait_for_update` returned
    event: Event  # set by `update`, cleared by `wait_for_update`

    def __init__(self) -> None:
        self.counter = 0
        self.delta = 0
        self.event = Event()

    def update(self, num_items_completed: int) -> None:
        """Record that `num_items_completed` more items have finished.

        Adds the amount to both the running total and the pending delta,
        then wakes up any caller waiting in `wait_for_update`.
        """
        self.counter += num_items_completed
        self.delta += num_items_completed
        self.event.set()

    async def wait_for_update(self) -> Tuple[int, int]:
        """Wait until somebody calls `update`, then report progress.

        Returns:
            A tuple ``(delta, counter)``: the number of items reported since
            the previous call to `wait_for_update`, and the total number of
            completed items.
        """
        await self.event.wait()
        self.event.clear()
        # Hand the pending delta to the caller and start a fresh one.
        saved_delta = self.delta
        self.delta = 0
        return saved_delta, self.counter

    def get_counter(self) -> int:
        """Return the total number of completed items."""
        return self.counter


class ProgressBar:
    """Local-side wrapper that renders a `ProgressBarActor` as a `tqdm` bar."""

    progress_actor: ActorHandle
    total: int
    description: str
    # Declared for type checkers only; the bar itself is created (and closed)
    # inside `print_until_done`.
    pbar: tqdm

    def __init__(self, total: int, description: str = ""):
        """Create the remote actor and remember the bar's total and label.

        Args:
            total: Number of items after which the bar is considered done.
            description: Optional label shown next to the bar.
        """
        # Ray actors don't seem to play nice with mypy, generating
        # a spurious warning for the following line,
        # which we need to suppress. The code is fine.
        self.progress_actor = ProgressBarActor.remote()  # type: ignore
        self.total = total
        self.description = description

    @property
    def actor(self) -> ActorHandle:
        """Return a reference to the remote `ProgressBarActor`.

        When you complete tasks, call `update` on the actor.
        """
        return self.progress_actor

    def print_until_done(self) -> None:
        """Block until the remote tasks have reported `total` items.

        Do this after starting a series of remote Ray tasks, to which you've
        passed the actor handle. Each of them calls `update` on the actor.
        When the progress meter reaches 100%, this method returns.

        If `total` is zero or negative there is nothing to wait for, so the
        method returns right away instead of waiting for an update that no
        task will ever send.
        """
        # The context manager closes the bar even if `ray.get` raises.
        with tqdm(desc=self.description, total=self.total) as pbar:
            counter = 0
            while counter < self.total:
                delta, counter = ray.get(self.actor.wait_for_update.remote())
                pbar.update(delta)


@ray.remote
def translate_answer(value, pba, languages=TARGET_LANGUAGES):
    """Translate `value` from English into each of `languages`.

    Args:
        value: The English text to translate.
        pba: Handle to a `ProgressBarActor`; it is told that one item has
            finished once this task is done, whether or not it succeeded.
        languages: Target language codes; defaults to `TARGET_LANGUAGES`.

    Returns:
        A dict mapping each language code to the translated text.
    """
    try:
        return {lang: translate(value, lang, "en") for lang in languages}
    finally:
        # Report progress even when a translation fails; otherwise the
        # driver's progress loop would never reach its total and would hang.
        # The exception itself still propagates to the driver's `ray.get`.
        pba.update.remote(1)


with open("answer_reverse_mapping.json", encoding="utf-8") as f:
    answer_reverse_mapping = json.load(f)

# The result is keyed by answer value, so a repeated value only needs to be
# translated once. `dict.fromkeys` keeps first-occurrence order, which is the
# same key order `dict(zip(...))` over all values would produce.
unique_values = list(dict.fromkeys(answer_reverse_mapping.values()))

pb = ProgressBar(len(unique_values))
actor = pb.actor
translation_dicts = [
    translate_answer.remote(value, actor) for value in unique_values
]
pb.print_until_done()

translation_dict = dict(zip(unique_values, ray.get(translation_dicts)))

with open("translation_dict.json", "w", encoding="utf-8") as f:
    json.dump(translation_dict, f)