Spaces:
Runtime error
Runtime error
File size: 3,370 Bytes
2bbf92c e289356 2bbf92c e289356 690384a 2bbf92c 690384a 2bbf92c 690384a 2bbf92c 690384a 2bbf92c 690384a 2bbf92c 690384a 2bbf92c 690384a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import json
from asyncio import Event
from typing import Tuple

import ray
from mtranslate.core import translate
from ray.actor import ActorHandle
from tqdm import tqdm

# Initialize Ray once, after all imports and before any remote actor or task
# defined below is created.
ray.init()
# Once the remote Ray tasks have been launched, call
# `ProgressBar.print_until_done` on the local (driver) node; it feeds the
# progress reported by those tasks back into a `tqdm` counter.
@ray.remote
class ProgressBarActor:
    """Ray actor that accumulates progress reported by remote tasks.

    Tasks call ``update``; a driver-side poller calls ``wait_for_update``
    to receive the progress made since its previous call.
    """

    counter: int  # running total of items reported as complete
    delta: int  # items reported since the last wait_for_update() call
    event: Event  # set by update(), cleared by wait_for_update()

    def __init__(self) -> None:
        self.event = Event()
        self.counter = self.delta = 0

    def update(self, num_items_completed: int) -> None:
        """Record that ``num_items_completed`` more items have finished.

        Grows both the running total and the pending delta by that amount,
        then sets the event so a pending ``wait_for_update`` can return.
        """
        self.counter = self.counter + num_items_completed
        self.delta = self.delta + num_items_completed
        self.event.set()

    async def wait_for_update(self) -> Tuple[int, int]:
        """Wait until ``update`` has been called, then drain the delta.

        Returns:
            ``(delta, counter)``: the number of items reported since the
            previous call, and the running total of completed items.
        """
        await self.event.wait()
        self.event.clear()
        # Hand back the pending delta and reset it so the next call reports
        # only progress made after this point.
        pending, self.delta = self.delta, 0
        return pending, self.counter

    def get_counter(self) -> int:
        """Return the running total of completed items."""
        total_so_far = self.counter
        return total_so_far
class ProgressBar:
    """Driver-side wrapper that renders a ``ProgressBarActor``'s progress.

    Pass ``actor`` to remote tasks; each task calls ``actor.update.remote(n)``.
    Then call ``print_until_done`` on the driver to show a tqdm bar until
    ``total`` items have been reported.
    """

    progress_actor: ActorHandle
    total: int
    description: str
    pbar: tqdm  # NOTE(review): declared but never assigned by this class

    def __init__(self, total: int, description: str = ""):
        # Ray actors don't seem to play nice with mypy, which reports a
        # spurious error on the following line; suppress it. The code is fine.
        self.progress_actor = ProgressBarActor.remote()  # type: ignore
        self.total = total
        self.description = description

    @property
    def actor(self) -> ActorHandle:
        """Returns a reference to the remote `ProgressBarActor`.

        When you complete tasks, call `update` on the actor.
        """
        return self.progress_actor

    def print_until_done(self) -> None:
        """Block, advancing a tqdm bar, until ``total`` items are complete.

        Call this after starting remote Ray tasks that were given the actor
        handle; each of them calls `update` on the actor. Returns once the
        reported total reaches ``total``. With ``total <= 0`` it returns
        immediately instead of waiting for an update that would never come.
        """
        # The context manager closes the bar even if ray.get raises or the
        # wait is interrupted.
        with tqdm(desc=self.description, total=self.total) as pbar:
            counter = 0
            while counter < self.total:
                delta, counter = ray.get(self.actor.wait_for_update.remote())
                pbar.update(delta)
# Load the mapping whose values are the English answers to translate.
# NOTE(review): the values become dict keys later, so they must be hashable
# (presumably strings) — confirm against the JSON file.
# JSON is UTF-8; don't depend on the platform's default locale encoding.
with open("answer_reverse_mapping.json", encoding="utf-8") as f:
    answer_reverse_mapping = json.load(f)
@ray.remote
def translate_answer(value, pba):
    """Translate one English answer into French, Spanish and German.

    Runs as a Ray task. Exactly one unit of progress is reported to ``pba``
    per call, even when translation fails, so the driver's progress bar
    cannot stall waiting for this task.

    Args:
        value: Answer text passed to ``translate`` (presumably a string —
            TODO confirm against answer_reverse_mapping.json).
        pba: ``ProgressBarActor`` handle exposing ``update.remote(n)``.

    Returns:
        dict mapping each language code ("fr", "es", "de") to its translation.
    """
    try:
        # Assumes mtranslate's positional order (text, to_language,
        # from_language), i.e. English -> lang.
        return {lang: translate(value, lang, "en") for lang in ("fr", "es", "de")}
    finally:
        # Report progress even on failure; the exception itself still
        # surfaces when the driver calls ray.get() on this task's result.
        pba.update.remote(1)
# Translate each distinct answer once: duplicates would otherwise be
# translated repeatedly and then collapse to a single key in the output dict.
unique_answers = list(dict.fromkeys(answer_reverse_mapping.values()))
pb = ProgressBar(len(unique_answers))
actor = pb.actor
# One Ray task per answer; each reports one unit of progress to `actor`.
translation_dicts = [
    translate_answer.remote(value, actor) for value in unique_answers
]
pb.print_until_done()
# ray.get on a list returns results in the same order as the refs, so they
# line up with unique_answers.
translation_dict = dict(zip(unique_answers, ray.get(translation_dicts)))
# ensure_ascii=False keeps accented fr/es/de characters readable; the file
# is therefore written as explicit UTF-8.
with open("translation_dict.json", "w", encoding="utf-8") as f:
    json.dump(translation_dict, f, ensure_ascii=False)
|