File size: 5,619 Bytes
3d4d40d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
# Number of slots in one repeating schedule unit.
# NOTE(review): nothing in the visible code reads this constant; schedule()
# hard-codes the same value (`6 * m`, `mid * 6`). Presumably the two are meant
# to agree — confirm before changing either.
pattern_size = 6
from collections import Counter
from dataclasses import dataclass
@dataclass(frozen=True)
class ScheduledNode:
    """A single operation placed on the pipeline timeline.

    Immutable, compared by value, and hashable, so instances can be used as
    dict keys or set members.
    """

    # Operation code; transform_schedule() stores 'F', 'B' or 'W' here.
    type: str
    # Index of the pipeline stage the operation runs on.
    stage: int
    # Per-stage occurrence index of this operation kind (micro-batch number).
    minibatch: int
    # Time at which the operation begins.
    start_time: int
    # Time at which the operation ends.
    completion_time: int
def transform_schedule(schedule, f, b, w, c):
    """Assign start and completion times to every operation of a schedule.

    Args:
        schedule: One sequence per pipeline stage. Each element is a
            one-character operation code ('F', 'B' or 'W'); whitespace
            entries are idle slots and are skipped.
        f: Duration of an 'F' operation.
        b: Duration of a 'B' operation.
        w: Duration of a 'W' operation.
        c: Delay added on every cross-stage dependency.

    Returns:
        A list with one list per stage of ``ScheduledNode`` objects, in the
        stage's execution order. The micro-batch index of an operation is the
        number of operations of the same kind that precede it on its stage.

    Timing model (start = max of the applicable bounds, 0 if none):
        * the completion time of the previous operation on the same stage;
        * for 'F' on stage s > 0: completion of the same 'F' on stage s-1, plus c;
        * for 'B' on a stage other than the last: completion of the same 'B'
          on stage s+1, plus c.
        If the neighbouring stage has no such operation, its time is computed
        as though it had no predecessors (duration only).

    Raises:
        KeyError: if an operation code is not one of 'F', 'B', 'W'.
        ValueError: if the dependencies form a cycle.
    """
    n_stages = len(schedule)
    cost = {'F': f, 'B': b, 'W': w}

    # Per-stage execution order as (kind, micro-batch) pairs, plus, for each
    # operation, the (kind, micro-batch) that ran right before it on its stage.
    stage_order = []
    local_prev = {}
    for sid, stage in enumerate(schedule):
        seen = Counter()
        order = []
        for kind in stage:
            if not kind.strip():
                continue  # idle slot
            mb = seen[kind]
            if order:
                local_prev[(sid, kind, mb)] = order[-1]
            order.append((kind, mb))
            seen[kind] += 1
        stage_order.append(order)

    time_map = {}

    def deps_of(node):
        """Return (same-stage predecessor or None, cross-stage predecessor or None)."""
        sid, kind, mb = node
        prev = local_prev.get(node)
        local = (sid,) + prev if prev is not None else None
        cross = None
        if kind == 'F' and sid > 0:
            cross = (sid - 1, kind, mb)
        elif kind == 'B' and sid + 1 < n_stages:
            cross = (sid + 1, kind, mb)
        return local, cross

    def completion_time(node):
        """Memoised completion time of ``node``.

        Evaluated with an explicit DFS path instead of recursion so that long
        same-stage dependency chains cannot exceed Python's recursion limit.
        """
        if node in time_map:
            return time_map[node]
        path = [node]
        on_path = {node}
        while path:
            cur = path[-1]
            local, cross = deps_of(cur)
            missing = next(
                (d for d in (local, cross) if d is not None and d not in time_map),
                None,
            )
            if missing is not None:
                if missing in on_path:
                    raise ValueError(f'cyclic dependency involving {missing}')
                path.append(missing)
                on_path.add(missing)
                continue
            # All predecessors are known: finalise this node.
            path.pop()
            on_path.discard(cur)
            start = time_map[local] if local is not None else 0
            if cross is not None:
                start = max(start, time_map[cross] + c)
            time_map[cur] = start + cost[cur[1]]
        return time_map[node]

    result = []
    for sid, order in enumerate(stage_order):
        row = []
        for kind, mb in order:
            end = completion_time((sid, kind, mb))
            row.append(ScheduledNode(kind.upper(), sid, mb, end - cost[kind], end))
        result.append(row)
    return result
def process_warmup_without_increasing_peak_mem(schedules, m):
    """Pull operations earlier into idle slots without raising peak memory.

    ``schedules`` is a list of rows, one per pipeline stage, each as long as
    ``schedules[0]``. Every slot holds 'F', 'B', 'W' or ' ' (idle). Memory on
    a stage is modelled as +1 per 'F' and -1 per 'W'. The peak is the maximum
    over all stages and slots.

    Columns are scanned left to right, and stages top to bottom within a
    column. Each operation is moved to the earliest idle slot after the
    latest slot it must follow. The constraints are:

    * the previous operation of the same kind on the same stage;
    * for 'W', the 'B' with the same occurrence index on the same stage;
    * for 'F' on stage > 0, the 'F' with the same index on the previous stage;
    * for 'B', the 'B' with the same index on the next stage; on the last
      stage, it is the 'F' with the same index on that stage.

    'B' and 'W' take the earliest such slot. An 'F' only moves back over slots
    whose memory is strictly below the peak, so adding one live activation
    there never raises the peak.

    Args:
        schedules: Per-stage slot rows. Modified in place.
        m: Number of micro-batches; sizes the per-stage occurrence table
            (m + 2 entries).

    Returns:
        ``schedules`` itself, after modification.
    """
    if not schedules:
        return schedules  # nothing to move; avoids indexing schedules[0]

    n_stages = len(schedules)
    n_slots = len(schedules[0])
    kinds = ('F', 'B', 'W')

    # mem[sid][i]: modelled memory on stage sid after slot i.
    mem = [[0] * n_slots for _ in range(n_stages)]
    peak_mem = 0
    for sid, row in enumerate(schedules):
        cur = 0
        for i, op in enumerate(row):
            if op == 'F':
                cur += 1
            elif op == 'W':
                cur -= 1
            mem[sid][i] = cur
            peak_mem = max(peak_mem, cur)

    # loc[sid][k][kind]: final slot of the k-th `kind` op on stage sid (-1 = not placed yet).
    loc = [[{kind: -1 for kind in kinds} for _ in range(m + 2)] for _ in range(n_stages)]
    # cntr[sid][kind]: number of `kind` ops visited so far on stage sid.
    cntr = [{kind: 0 for kind in kinds} for _ in range(n_stages)]

    for i in range(n_slots):
        for sid in range(n_stages):
            row = schedules[sid]
            op = row[i]
            if op == ' ':
                continue
            cntr[sid][op] += 1
            cnt = cntr[sid][op]

            # Latest slot this op has to come after.
            pos = -1
            if cnt > 1:
                pos = loc[sid][cnt - 1][op]
            if op == 'W':
                pos = max(pos, loc[sid][cnt]['B'])
            elif op == 'F' and sid > 0:
                pos = max(pos, loc[sid - 1][cnt]['F'])
            elif op == 'B':
                if sid != n_stages - 1:
                    pos = max(pos, loc[sid + 1][cnt]['B'])
                else:
                    pos = max(pos, loc[sid][cnt]['F'])

            # First idle slot after that bound (stops at the op's own slot).
            pos += 1
            while pos < i and row[pos] != ' ':
                pos += 1
            if pos == i:
                loc[sid][cnt][op] = i
                continue

            if op in ('B', 'W'):
                row[pos], row[i] = op, ' '
                if op == 'W':
                    # W lowers memory, so usage drops over the slots it skips.
                    for j in range(pos, i):
                        mem[sid][j] -= 1
                loc[sid][cnt][op] = pos
                continue

            # F: only jump back over slots strictly below peak memory.
            place = i
            while place > pos and mem[sid][place - 1] < peak_mem:
                place -= 1
            while place < i and row[place] != ' ':
                place += 1
            if place == i:
                loc[sid][cnt][op] = i
                continue
            row[place], row[i] = op, ' '
            for j in range(place, i):
                mem[sid][j] += 1
            loc[sid][cnt][op] = place

    return schedules
def schedule(p, m, cost):
    """Build a p-stage slot grid for m micro-batches, compact it, and time it.

    Args:
        p: Number of pipeline stages.
        m: Number of micro-batches.
        cost: Iterable unpacked as the (f, b, w, c) arguments of
            transform_schedule().

    Returns:
        The result of transform_schedule() on the compacted grid.
    """
    width = 6 * m + 2 * p + 6
    grid = [[' '] * width for _ in range(p)]

    for sid in range(p):
        row = grid[sid]
        # Even/odd micro-batches use two interleaved slot tracks per stage.
        # Forward offsets grow with sid and backward offsets shrink with it,
        # so later stages start F later and B earlier.
        f_start = (sid, sid + 2)
        b_start = (2 * p - 1 - sid, 2 * p + 1 - sid)
        for mb in range(m):
            group, track = divmod(mb, 2)
            row[f_start[track] + 6 * group] = 'F'
            row[b_start[track] + 6 * group] = 'B'

        # Each 'B' queues one 'W', dropped into the next idle slot(s) after it.
        pending_w = 0
        for i in range(width):
            if row[i] == 'B':
                pending_w += 1
            elif row[i] == ' ' and pending_w:
                pending_w -= 1
                row[i] = 'W'

    compacted = process_warmup_without_increasing_peak_mem(grid, m)
    return transform_schedule(compacted, *cost)