Loop lifting fails
Commit c0ba0c98
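Consider the following program. In each iteration of the j loop, matmulT
only has to be computed N / Br
times (once per value of i), because the call matmulT(Qb[i], Kb[j]) does not depend on a.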
use Array: all;
#define d 128
#define N 1024
#define Br 128
#define Bc 256
/* Identity function: returns its argument x unchanged, for any shape.
 * Declared noinline; presumably this hides the constant inputs built in
 * main from the optimiser -- TODO confirm. */
noinline float[n:shp] id(float[n:shp] x) { return x; }
/* Stub standing in for a matrix product of A ([m, k]) with B ([n, k]).
 * The element values of A and B are ignored: the result is an [m, n] array
 * in which every element is tof(0), so only the shapes of the arguments
 * matter. Declared noinline; presumably so that calls stay explicit in the
 * optimised code -- TODO confirm. */
noinline
float[m, n] matmulT(float[m, k] A, float[n, k] B)
{
return {iv -> tof(0) | iv < [m, n]};
}
/* Reduced reproducer shaped like a blocked attention kernel.
 * Q is viewed as N / Br blocks of Br rows and K as N / Bc blocks of Bc rows.
 * For every key block j, Pj is built from matmulT(Qb[i], Kb[j]) for all
 * query blocks i, and sum(Pj) is accumulated into the scalar result O.
 * V is accepted but never used. Declared inline. */
inline
float FlashAttention(float[N, d] Q, float[N, d] K, float[N, d] V)
{
/* Split the rows of Q and K into blocks of Br and Bc rows respectively. */
Qb = reshape([N / Br, Br, d], Q);
Kb = reshape([N / Bc, Bc, d], K);
O = tof(0);
for (j = 0; j < N / Bc; j++) {
/* The matmulT call depends on i and j only; a merely selects one part of
 * its result, so the call is invariant with respect to a. */
Pj = {[i, a] -> matmulT(Qb[i], Kb[j])[a]
| [i, a] < [N / Br, Br]};
O += sum(Pj);
}
return O;
}
/* Driver for the reproducer: builds Q, K and V as [N, d] arrays filled with
 * tof(1), passes each through id, runs FlashAttention on them and returns
 * the scalar result converted with _toi_S_. */
int main()
{
Q = id({iv -> tof(1) | iv < [N, d]});
K = id({iv -> tof(1) | iv < [N, d]});
V = id({iv -> tof(1) | iv < [N, d]});
O = FlashAttention(Q, K, V);
return _toi_S_(O);
}
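However, the optimised code contains the following with-loop partition, in which matmulT is still called inside the combined [i, a] loop: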
/* Partn */
([ 0, 0 ] <= _flat_82=[i, a] (IDXS:_wlidx_920_Pj) < [ 8, 128 ] genwidth [ 8, 128 ])
{
_ivesli_930 = _idxs2offset_( [ 8, 128, 128 ], i, _iveras_1039, _iveras_1040);
_flat_84 = with /** FOLDABLE (all gen's const) **/
/** REFERENCED: 1 (total num refs) **/
{
/* Partn */
([ 0, 0 ] <= _pinl_540_iv=[_pinl_543__eat_146, _pinl_542__eat_145] (IDXS:_wlidx_921__flat_84) < [ 128, 128 ] genwidth [ 128, 128 ])
{
_ivesli_932 = _idxs2offset_( [ 8, 128, 128 ], _iveras_1041, _pinl_543__eat_146, _pinl_542__eat_145);
_ivesli_933 = _add_SxS_( _ivesli_930, _ivesli_932);
_pinl_538__flat_396 = _idx_sel_( _ivesli_933, Qb);
} : _pinl_538__flat_396 ;
} :
genarray( [ 128, 128 ], _pinl_450__flat_393, IDX(_wlidx_921__flat_84));
_flat_83 = _MAIN::matmulT( _flat_84, _flat_56) ;
As a result, matmulT is computed N / Br * Br
times instead of N / Br times. The [i, a]
loop should have been split into an [i]
loop and an [a]
loop, with the matmulT
call lifted out of the [a]
loop; inside the [a]
loop a suballoc can then be done.