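I was surprised to see optimize="optimal" return a sub-optimal path for a fairly simple contraction of three arrays, each with 2 or 3 dimensions.
The greedy and dynamic-programming optimizers both find the cheaper path, shown below as "Better path" (5.390e+8 vs. 5.704e+8 FLOPs, with a largest intermediate of 2.097e+6 vs. 3.355e+7 elements).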
Script:
import opt_einsum
print(f"OE version: {opt_einsum.__version__}\n")
expr = 'bgk,bkd,bk->bgd'
b = 64
g = 8
k = 4096
d = 128
a_shape = [b, g, k]
v_shape = [b, k, d]
s_shape = [b, k]
print('"OPTIMAL" path')
print(opt_einsum.contract_path(expr, a_shape, v_shape, s_shape, shapes=True, optimize="optimal"))
print()
print("Better path")
print(opt_einsum.contract_path(expr, a_shape, v_shape, s_shape, shapes=True, optimize=[(0, 2), (0, 1)]))
Output:
OE version: 3.4.0
"OPTIMAL" path
([(1, 2), (0, 1)], Complete contraction: bgk,bkd,bk->bgd
Naive scaling: 4
Optimized scaling: 4
Naive FLOP count: 8.053e+8
Optimized FLOP count: 5.704e+8
Theoretical speedup: 1.412e+0
Largest intermediate: 3.355e+7 elements
--------------------------------------------------------------------------------
scaling BLAS current remaining
--------------------------------------------------------------------------------
3 0 bk,bkd->bkd bgk,bkd->bgd
4 0 bkd,bgk->bgd bgd->bgd)
Better path
([(0, 2), (0, 1)], Complete contraction: bgk,bkd,bk->bgd
Naive scaling: 4
Optimized scaling: 4
Naive FLOP count: 8.053e+8
Optimized FLOP count: 5.390e+8
Theoretical speedup: 1.494e+0
Largest intermediate: 2.097e+6 elements
--------------------------------------------------------------------------------
scaling BLAS current remaining
--------------------------------------------------------------------------------
3 0 bk,bgk->bkg bkd,bkg->bgd
4 0 bkg,bkd->bgd bgd->bgd)
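For reference, the two FLOP counts above can be reproduced by hand. This is only a sketch: it assumes the cost of a pairwise step is the product of the sizes of all indices involved, doubled when an index is summed away. That assumption matches the reported numbers, but it is not taken from the library source. The helpers `pair_cost` and `path_cost` are hypothetical names for illustration, not part of opt_einsum.

```python
from math import prod

# Index sizes from the script above.
sizes = {"b": 64, "g": 8, "k": 4096, "d": 128}


def pair_cost(inputs, output):
    """Assumed cost model for one pairwise step (hypothetical helper):
    product of all involved index sizes, times 2 if any index is summed away."""
    involved = set("".join(inputs))
    summed = involved - set(output)
    factor = 2 if summed else 1
    return factor * prod(sizes[i] for i in involved)


def path_cost(steps):
    """Total cost of a path given as a list of (inputs, output) steps."""
    return sum(pair_cost(inputs, output) for inputs, output in steps)


# Path reported for optimize="optimal": bk,bkd->bkd then bkd,bgk->bgd
reported_optimal = [(("bk", "bkd"), "bkd"), (("bkd", "bgk"), "bgd")]
# Manually supplied path: bk,bgk->bkg then bkg,bkd->bgd
better = [(("bk", "bgk"), "bkg"), (("bkg", "bkd"), "bgd")]

cost_reported = path_cost(reported_optimal)
cost_better = path_cost(better)

# Size of the intermediate produced by the first step of each path.
intermediate_reported = prod(sizes[i] for i in "bkd")
intermediate_better = prod(sizes[i] for i in "bkg")

print(cost_reported, cost_better)
print(intermediate_reported, intermediate_better)
```

Under this model the two paths differ only in their first step: both finish with the same scaling-4 contraction over b, g, k and d, so the difference comes from forming a bkd intermediate instead of a bkg one.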
Seems related to #248