MATLAB: Optimize/speed up a big and slow matrix operation with addition and bsxfun

bsxfun, element wise operations, execution-time, matrix addition, optimization, performance

Hi. I have a fairly long line of code in my script that takes about 7 seconds to run. It was originally three separate calculations, but they are now done in one line that looks like this:
X = bsxfun(@times,reshape(bsxfun(@times,A,B),[1440,1,2251]),(1-C)./2)...
+bsxfun(@times,reshape(E,1440,1,2251),D)...
+bsxfun(@times,reshape(F,1440,1,[]),((1+C)/2));
Since I need to run it several tens of thousands of times, it is causing a fair amount of anxiety in my life at the moment. I don’t know how to speed it up further, or whether that is even possible. So I would really appreciate it if any of you optimization gurus here could give me some advice on whether and how I can make this faster.
The input variables are of size:
A = 1x2251 (a row vector; bsxfun(@times,A,B) needs this orientation to match the 1440x2251 B)
B = 1440x2251
C = 1x181
D = 1440x181
E = 1440x2251
F = 1440x2251
Thanks!

Best Answer

Assuming you have a C compiler available, here is a simple C mex routine to do the calculation. It could possibly be sped up even more with multi-threading (e.g., OpenMP), but I did not include that code below. Let me know if you have an OpenMP-compliant compiler, since that addition would be fairly simple.
As is (with no C-mex multi-threading) here is an example run on my machine:
Elapsed time is 6.655025 seconds. % M-code with bsxfun
Elapsed time is 1.700077 seconds. % C-mex code (not multi-threaded)
CAUTION: Code below is bare bones with no argument checking.
// X = bsxfun(@times,reshape(bsxfun(@times,A,B),[1440,1,2251]),(1-C)./2)...
// + bsxfun(@times,reshape(E,1440,1,2251),D)...
// + bsxfun(@times,reshape(F,1440,1,[]),((1+C)/2));
//
// The input variables are of size:
//
// A = 1 x 2251
// B = 1440 x 2251
// C = 1 x 181
// D = 1440 x 181
// E = 1440 x 2251
// F = 1440 x 2251
//
// Output:
//
// X = 1440 x 181 x 2251
//
// Calling sequence:
//
// X = this_file_name(A,B,C,D,E,F);
#include "mex.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    mwSize i, j, k, m, n, p, km, ikm, jm;
    mwSize dims[3];
    double *A, *B, *C, *D, *E, *F, *X, *Cm, *Cp;
    double a, cp, cm;

    // Sizes are taken from D (m x n) and E (m x p)
    m = mxGetM(prhs[3]);
    n = mxGetN(prhs[3]);
    p = mxGetN(prhs[4]);

    A = mxGetPr(prhs[0]);
    B = mxGetPr(prhs[1]);
    C = mxGetPr(prhs[2]);
    D = mxGetPr(prhs[3]);
    E = mxGetPr(prhs[4]);
    F = mxGetPr(prhs[5]);

    // Output X is m x n x p
    dims[0] = m;
    dims[1] = n;
    dims[2] = p;
    plhs[0] = mxCreateNumericArray(3, dims, mxDOUBLE_CLASS, mxREAL);
    X = mxGetPr(plhs[0]);

    // Precompute (1-C)/2 and (1+C)/2 once
    Cm = (double *) mxMalloc(n*sizeof(*Cm));
    Cp = (double *) mxMalloc(n*sizeof(*Cp));
    for( j=0; j<n; j++ ) {
        Cm[j] = (1.0 - C[j]) / 2.0;
        Cp[j] = (1.0 + C[j]) / 2.0;
    }

    // X(i,j,k) = A(k)*B(i,k)*Cm(j) + E(i,k)*D(i,j) + F(i,k)*Cp(j)
    // Loop order k, j, i matches column-major storage, so X is written sequentially
    for( k=0; k<p; k++ ) {
        a = A[k];
        km = k * m;
        for( j=0; j<n; j++ ) {
            jm = j * m;
            cm = Cm[j] * a;
            cp = Cp[j];
            for( i=0; i<m; i++ ) {
                ikm = i+km;
                *X++ = B[ikm] * cm + E[ikm] * D[i+jm] + F[ikm] * cp;
            }
        }
    }

    mxFree(Cp);
    mxFree(Cm);
}
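For the multi-threading mentioned earlier, here is a minimal sketch of how an OpenMP variant might look. It is written as a standalone C function rather than a full mex file, and the name combine_omp and its signature are placeholders of my own, not part of the mex API. Inside mexFunction you would call it with the pointers and sizes obtained above. It assumes column-major storage and relies on each k slab of X being independent of the others, which holds here because no two iterations write the same element.

```c
#include <stddef.h>

/* Hypothetical helper (name and signature are illustrative only).
   Column-major arrays, as in MATLAB:
     A: p          C: n          D: m x n
     B, E, F: m x p
     X: m x n x p  (output, allocated by the caller)
   Computes X(i,j,k) = A(k)*B(i,k)*(1-C(j))/2 + E(i,k)*D(i,j) + F(i,k)*(1+C(j))/2 */
void combine_omp(const double *A, const double *B, const double *C,
                 const double *D, const double *E, const double *F,
                 double *X, size_t m, size_t n, size_t p)
{
    long k; /* signed index, accepted by older OpenMP implementations too */

    /* Each k writes its own m*n slab of X, so the iterations are independent.
       Without OpenMP enabled the pragma is skipped and the loop runs serially. */
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (k = 0; k < (long)p; k++) {
        const double a = A[k];
        const double *Bk = B + (size_t)k * m;
        const double *Ek = E + (size_t)k * m;
        const double *Fk = F + (size_t)k * m;
        double *Xk = X + (size_t)k * m * n;
        size_t i, j;

        for (j = 0; j < n; j++) {
            /* Recomputed per k to avoid shared temporary arrays */
            const double cm = ((1.0 - C[j]) / 2.0) * a;
            const double cp = (1.0 + C[j]) / 2.0;
            const double *Dj = D + j * m;
            double *Xjk = Xk + j * m;

            for (i = 0; i < m; i++) {
                Xjk[i] = Bk[i] * cm + Ek[i] * Dj[i] + Fk[i] * cp;
            }
        }
    }
}
```

For the pragma to take effect, the file has to be built with the compiler's OpenMP option switched on (for GCC that is -fopenmp). With the sizes in the question the output holds 586,700,640 doubles, roughly 4.7 GB, so the loop is likely limited by memory bandwidth and the speed-up may be well below the number of cores.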