Complexity problem in FUNDEF_WRAPPERTYPE
The following program is OOM-killed with 40GB of available space in elim_alpha_types.c
use Array: all;
use StdIO: all;
use Math: all;
#define CATEGORIES 10
inline
float[d:shp] averageOuter(float[n,d:shp] array)
{
return {iv -> sum({[i] -> array[[i]++iv] | [i] < [n]}) | iv < shp} / tof(n);
}
inline
int MaxPos( float[.,.,.,.,.] output)
{
n = shape( output)[0];
max = output[[0,0,0,0,0]];
res = 0;
for( i=0; i<n; i++)
if( output[[i,0,0,0,0]] > max) {
max = output[[i,0,0,0,0]];
res = i;
}
return res;
}
inline
float MeanSquaredError( float[*] result, float[*] labels)
{
return sum ( 0.5f * ( labels - result) * ( labels - result) );
}
#define sumOuter(xo, xi, e, shpi, shpo) \
{xi -> with { \
(0 * shpo <= xo < shpo): (e)[xi]; \
}: fold(+, 0f) \
| xi < shpi}
/* In `Fortran' notation x(i:i+n+1) where i, n are vectors and the range
* is inclusive, exclusive. */
inline
float[d:n1] slide(int[d] i, float[d:m] x, int[d] n) | all(n1 == n + 1)
| all(n + 1 + i <= m)
{
return {iv -> x[iv + i] | iv < n + 1};
}
inline
float[d:mn] backslide(int[d] i, float[d:n1] y, int[d] mn)
| all(i < 1 + mn - n1)
{
return {iv -> 0f | iv < i;
iv -> y[iv - i] | i <= iv < n1 + i;
iv -> 0f | n1 + i <= iv < mn};
}
inline
float[*] logistics(float[*] x)
{
return {iv -> 1f / (1f + exp(-x[iv]))};
}
inline
float[n:oshp,n:ishp] block(float[d:shp] x, int[n] ishp)
| all(shp % ishp == 0)
| all(oshp * ishp == shp)
| (2 * n == d)
{
return {iv -> tile(ishp, iv * ishp, x) | iv < shp / ishp};
}
inline
float[*] unblock(float[*] a, int[.] bshp)
{
shp = drop(-shape(bshp), shape(a)) * bshp;
return {iv -> a[(iv / bshp) ++ (iv % bshp)] | iv < shp};
}
inline
float[n:ishp] selb(float[d:shp] x, int[n] iv, int[n] ishp)
| (2 * n == d)
| all(iv < shp / ishp)
{
return block(x, ishp)[iv];
}
#define IMAPS(iv, e, shp) {iv -> (e)[[0]] | iv < shp}
float sels(float[d:shp] x, int[d] iv)
{
return x[iv];
}
noinline
float[6, 5, 5], float[6], float[12, 6, 5, 5], float[12],
float[CATEGORIES, 12, 1, 4, 4], float[CATEGORIES], float
TrainZhang(float[28, 28] inp, float[6, 5, 5] k1, float[6] b1,
float[12, 6, 5, 5] k2, float[12] b2,
float[CATEGORIES, 12, 1, 4, 4] fc, float[CATEGORIES] b,
float[CATEGORIES, 1, 1, 1, 1] target)
{
c11 = { x20 -> (sumOuter(x21, x22, (slide(x21, inp, [23, 23])) * (IMAPS(x23, (sels((k1)[x20], x21)), ([24, 24]))), ([5, 5]), ([24, 24]))) + (IMAPS(x24, (sels(b1, x20)), ([24, 24]))) | x20 < [6] };
c1 = logistics(c11);
s1 = { x16 -> IMAPS(x17, ((sumOuter(x18, x19, sels(selb((c1)[x16], x17, [2, 2]), x18), ([2, 2]), ([1]))) / 4), ([12, 12])) | x16 < [6] };
c21 = { x11 -> (sumOuter(x12, x13, (slide(x12, s1, [0, 7, 7])) * (IMAPS(x14, (sels((k2)[x11], x12)), ([1, 8, 8]))), ([6, 5, 5]), ([1, 8, 8]))) + (IMAPS(x15, (sels(b2, x11)), ([1, 8, 8]))) | x11 < [12] };
c2 = logistics(c21);
s2 = { x6 -> { x7 -> IMAPS(x8, ((sumOuter(x9, x10, sels(selb(((c2)[x6])[x7], x8, [2, 2]), x9), ([2, 2]), ([1]))) / 4), ([4, 4])) | x7 < [1] } | x6 < [12] };
r1 = { x1 -> (sumOuter(x2, x3, (slide(x2, s2, [0, 0, 0, 0])) * (IMAPS(x4, (sels((fc)[x1], x2)), ([1, 1, 1, 1]))), ([12, 1, 4, 4]), ([1, 1, 1, 1]))) + (IMAPS(x5, (sels(b, x1)), ([1, 1, 1, 1]))) | x1 < [10] };
r = logistics(r1);
ddr = 1f;
ddr1 = (ddr) * ((r1) * ((1f) + (-(r1))));
dds2 = sumOuter(x1, x2, sumOuter(x3, x4, backslide(x3, IMAPS(x5, ((sels((ddr1)[x1], x5)) * (sels((fc)[x1], x3))), ([1, 1, 1, 1])), [12, 1, 4, 4]), ([12, 1, 4, 4]), ([12, 1, 4, 4])), ([10]), ([12, 1, 4, 4]));
ddc2 = { x1 -> { x2 -> unblock({ x3 -> IMAPS(x4, ((sels(((dds2)[x1])[x2], x3)) / 4), ([2, 2])) | x3 < [4, 4] }, [2, 2]) | x2 < [1] } | x1 < [12] };
ddc21 = (ddc2) * ((c21) * ((1f) + (-(c21))));
dds1 = sumOuter(x1, x2, sumOuter(x3, x4, backslide(x3, IMAPS(x5, ((sels((ddc21)[x1], x5)) * (sels((k2)[x1], x3))), ([1, 8, 8])), [6, 12, 12]), ([6, 5, 5]), ([6, 12, 12])), ([12]), ([6, 12, 12]));
ddc1 = { x1 -> unblock({ x2 -> IMAPS(x3, ((sels((dds1)[x1], x2)) / 4), ([2, 2])) | x2 < [12, 12] }, [2, 2]) | x1 < [6] };
ddc11 = (ddc1) * ((c11) * ((1f) + (-(c11))));
ddb = IMAPS(x1, (sumOuter(x2, x3, sels((ddr1)[x1], x2), ([1, 1, 1, 1]), ([1]))), ([10]));
ddfc = { x1 -> IMAPS(x2, (sumOuter(x3, x4, (sels((ddr1)[x1], x3)) * (sels(slide(x2, s2, [0, 0, 0, 0]), x3)), ([1, 1, 1, 1]), ([1]))), ([12, 1, 4, 4])) | x1 < [10] };
ddb2 = IMAPS(x1, (sumOuter(x2, x3, sels((ddc21)[x1], x2), ([1, 8, 8]), ([1]))), ([12]));
ddk2 = { x1 -> IMAPS(x2, (sumOuter(x3, x4, (sels((ddc21)[x1], x3)) * (sels(slide(x2, s1, [0, 7, 7]), x3)), ([1, 8, 8]), ([1]))), ([6, 5, 5])) | x1 < [12] };
ddb1 = IMAPS(x1, (sumOuter(x2, x3, sels((ddc11)[x1], x2), ([24, 24]), ([1]))), ([6]));
ddk1 = { x1 -> IMAPS(x2, (sumOuter(x3, x4, (sels((ddc11)[x1], x3)) * (sels(slide(x2, inp, [23, 23]), x3)), ([24, 24]), ([1]))), ([5, 5])) | x1 < [6] };
ddinp = sumOuter(x1, x2, sumOuter(x3, x4, backslide(x3, IMAPS(x5, ((sels((ddc11)[x1], x5)) * (sels((k1)[x1], x3))), ([24, 24])), [28, 28]), ([5, 5]), ([28, 28])), ([6]), ([28, 28]));
error = MeanSquaredError(r, target);
return (ddk1, ddb1, ddk2, ddb2, ddfc, ddb, error);
}
int main()
{
k1 = genarray([6, 5, 5], 1f / 25f);
b1 = genarray([6], 1f / 6f);
k2 = genarray([12, 6, 5, 5], 1f / 150f);
b2 = genarray([12], 1f / 12f);
fc = genarray([CATEGORIES, 12, 1, 4, 4], 1f / 192f);
b = genarray([CATEGORIES], 1f / tof(CATEGORIES));
target = genarray([CATEGORIES,1,1,1,1], 0f);
inp = reshape([28, 28], tof(iota(28 * 28)));
d_k1, d_b1, d_k2, d_b2, d_fc, d_b, err = TrainZhang(inp, k1, b1, k2, b2, fc, b, target);
print(d_k1);
return 0;
}