1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def compress_spec(g, v, out):
    """Loop-based reference specification for ``compress``.

    Walks ``g`` in order and, for every truthy gate, copies the matching
    ``v`` value into the next free slot of ``out``. Selected values end up
    packed at the front of ``out`` in their original order. Slots of ``out``
    past the last copied value are not written. Mutates ``out`` in place
    and returns ``None``.
    """
    write_pos = 0
    for idx, keep in enumerate(g):
        if not keep:
            continue
        out[write_pos] = v[idx]
        write_pos += 1
def compress(g: TT["i", bool], v: TT["i"], i: int) -> TT["i"]:
    """Pack the entries of ``v`` selected by ``g`` to the front of the output.

    Vectorized counterpart of ``compress_spec``. The k-th True position in
    ``g`` contributes its ``v`` value to output slot k. Slots after the last
    selected value come out as 0. For example, ``v = [1, 2, 3]`` with
    ``g = [0, 1, 1]`` yields ``[2, 3, 0]``.

    The copy is expressed as a matrix product ``v @ routing``. Here
    ``routing[r, c]`` is nonzero exactly when source position ``r`` is kept
    and lands in output slot ``c``. A column with no nonzero entry sums to 0.

    Args:
        g: gate mask of length ``i``; truthy entries select values to keep.
        v: values of length ``i``.
        i: the length of ``g`` and ``v`` (side of the routing matrix).

    Returns:
        A length-``i`` tensor holding the selected values, front-packed.

    NOTE(review): ``cumsum``, ``arange`` and ``where`` are helpers defined
    elsewhere (presumably the puzzle-set versions). The result dtype follows
    whatever ``where`` and ``@`` produce for these operands; confirm it
    matches ``v`` for non-integer inputs.
    """
    # Running count of kept entries minus one gives the target slot of each
    # kept position, e.g. g = [1, 0, 1, 0, 1] -> [0, 0, 1, 1, 2].
    dest = cumsum(1 * g) - 1
    # Broadcast the (i, 1) column against arange(i) to get an (i, i) grid:
    # one_hot[r, c] is True iff dest[r] == c. Memory is O(i**2).
    one_hot = arange(i) == dest[:, None]
    # Zero the rows of positions whose gate is False so they copy nothing.
    routing = where(g[:, None], one_hot, 0)
    return v @ routing
|