MATLAB: Can A = A + B'*B be sped up somehow? It is seriously bottlenecking the for-loop

MATLAB, matrix manipulation, matrix multiplication, optimization, transpose, vectorization

I have a script where a very large matrix A (square, up to about 10000 x 10000 values) is initialized outside of a for-loop and is then overwritten many times within the loop like this:
A = zeros(6588,6588);
for i = 1:1000
    % B changes to a new value here and is a 27x6588 matrix
    A = A + B' * B;
end
I wanted to remove the loop altogether but I think it’s impossible for me to calculate all instances of transpose(B)*B outside of the loop beforehand since I run out of memory.
Is there anything I could do to this code segment to speed it up? Computationally it’s just a few simple operations, but they still take over 70% of my script’s runtime and I can’t figure out a way to improve this. Is it possible?

Best Answer

If you have a C compiler available, another option is to call a BLAS routine from within a mex function to do this calculation. It may not save operation time on the B'*B part (not sure if MATLAB is smart enough to call DSYRK instead of DGEMM), but it may save on the time spent in the other parts. That is because A = A + B'*B can be done with one BLAS routine call (this type of update is built into the BLAS library code), as opposed to computing B'*B into temporary storage, then adding it to A, then freeing the temporary storage. E.g., the doc for DSYRK shows that it can do the following operation in one call, without forming B' explicitly first:
A := alpha * B^T * B + beta * A
So for your case, alpha = 1 and beta = 1. Also, by utilizing this BLAS call everything can be done in-place (i.e., no need to temporarily form the right-hand-side result first and then assign it to A).
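To make the one-call, in-place update a bit more concrete, here is a rough plain-C sketch of what the lower-triangle form of that operation computes. It is only an unoptimized reference, not how BLAS implements it; the name syrk_lower_ref is made up for illustration, and column-major storage (as MATLAB uses) is assumed.

```c
#include <stddef.h>

/* Reference sketch (hypothetical helper, not a BLAS routine) of
 *     C := alpha * B' * B + beta * C
 * restricted to the lower triangle of C.
 * B is k x n, C is n x n, both stored column-major.
 * Entries above the diagonal are left untouched. */
void syrk_lower_ref(size_t n, size_t k, double alpha, const double *B,
                    double beta, double *C)
{
    for (size_t col = 0; col < n; col++) {
        for (size_t row = col; row < n; row++) {
            /* (B'*B)(row,col) is the dot product of columns row and col of B */
            double dot = 0.0;
            for (size_t p = 0; p < k; p++) {
                dot += B[p + row * k] * B[p + col * k];
            }
            /* Update C directly; no temporary n x n matrix is formed */
            C[row + col * n] = alpha * dot + beta * C[row + col * n];
        }
    }
}
```

Because B'*B is symmetric, updating only the entries on and below the diagonal loses no information; the other half can be recovered later by mirroring.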
Let me know if you want to explore this route and if so then I can code it up for you.
EDIT: 9/4/2015
SORRY this has taken so long! I wrote the code actually fairly quickly but had problems in debugging. Code worked in 32-bit but not in 64-bit, but identical code I had in another program worked in both 32-bit and 64-bit. Finally I figured out it was because for this new program I was forgetting to include the -largeArrayDims compilation flag in the mex command (needed to get the mwSignedIndex macro to return the correct type). In any event, here is the code (I changed the integers to ptrdiff_t to match blas.h). Hopefully it will shave some execution time off of your runs.
Two things to keep in mind. First, it does all the calculations IN-PLACE! So make sure your accumulator matrix ("A" in your example) is not shared with any other variable. If you set it with the zeros function prior to the loop, then it will not be shared and you will be OK. Second, it only computes one half of the matrix (the lower triangle) when called with two arguments. After you are done with the loop, call it once with only one argument to fill in the other half. E.g., a calling sequence to accomplish your above example would be:
A = zeros(6588,6588);
for i = 1:1000
    % B changes to a new value here and is a 27x6588 matrix
    dsyrk(A,B); % Does A = A + B'*B in-place, but lower triangle only
end
dsyrk(A); % Fills in the upper triangle from the lower triangle
And here is the mex routine dsyrk.c
/*--------------------------------------------------------------------------------------
 * DSYRK does the operation C = C + A' * A in place
 *
 * Syntax: dsyrk(C,A) --> does C = C + A' * A (only lower triangle part)
 *         dsyrk(C)   --> fills the upper triangle with the lower triangle
 *
 * C = a real double full N x N matrix
 * A = a real double full K x N matrix
 *
 * The intent is to first initialize C, then call dsyrk(C,A) in a loop for various A
 * matrices, then after the loop call dsyrk(C) once to fill in the upper triangle part.
 * The code does the operations on C in place, so it is up to the user to make sure
 * that C is not shared with any other variable prior to calling dsyrk.
 *
 * The code uses ptrdiff_t for the integer types because that is what the header file
 * blas.h uses for the dsyrk function arguments.
 *
 * Programmer: James Tursa
 *-------------------------------------------------------------------------------------- */
#include "mex.h"
#include "blas.h"

void xFILLPOS(double *Cpr, ptrdiff_t n);

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    double alpha = 1.0, beta = 1.0;
    double *Apr, *Cpr;
    char uplo = 'L';
    char trans = 'T';
    ptrdiff_t m, n, k, p, lda, ldc;

    if( nlhs ) {
        mexErrMsgTxt("Too many outputs ... this routine does all operations in-place");
    }
    if( nrhs == 2 ) {
        Cpr = mxGetPr(prhs[0]);
        m = mxGetM(prhs[0]);
        n = mxGetN(prhs[0]);
        Apr = mxGetPr(prhs[1]);
        k = mxGetM(prhs[1]);
        p = mxGetN(prhs[1]);
        if( m != n || !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]) ||
            mxGetNumberOfDimensions(prhs[0]) != 2 ) {
            mexErrMsgTxt("1st input must be real double non-sparse square matrix");
        }
        if( p != n || !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxIsSparse(prhs[1]) ||
            mxGetNumberOfDimensions(prhs[1]) != 2 ) {
            mexErrMsgTxt("2nd input must be real double non-sparse matrix compatible with 1st input");
        }
        lda = k;
        ldc = n;
        dsyrk( &uplo, &trans, &n, &k, &alpha, Apr, &lda, &beta, Cpr, &ldc );
    } else if( nrhs == 1 ) {
        Cpr = mxGetPr(prhs[0]);
        m = mxGetM(prhs[0]);
        n = mxGetN(prhs[0]);
        if( m != n || !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]) ||
            mxGetNumberOfDimensions(prhs[0]) != 2 ) {
            mexErrMsgTxt("1st input must be real double non-sparse square matrix");
        }
        xFILLPOS(Cpr,n);
    } else {
        mexErrMsgTxt("Need exactly 1 or 2 inputs");
    }
}
/*--------------------------------------------------------------------------------------
 * Fill the upper triangle with contents of the lower triangle
 *-------------------------------------------------------------------------------------- */
void xFILLPOS(double *Cpr, ptrdiff_t n)
{
    double *source, *target;
    register ptrdiff_t i, j;

    source = Cpr + 1;
    target = Cpr + n;
    for( i=1; i<n; i++ ) {
        for( j=i; j<n; j++ ) {
            *target = *source;
            target += n;
            source++;
        }
        source += i + 1;
        target = source + n - 1;
    }
}
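
If the pointer walk in xFILLPOS is hard to follow, an index-based version of the same mirroring step might look like the sketch below. It is meant only as a readable equivalent, assuming the same column-major layout; mirror_lower_to_upper is a hypothetical name and is not part of the mex routine above.

```c
#include <stddef.h>

/* Copy the strict lower triangle of a column-major n x n matrix into
 * its upper triangle (hypothetical helper, for illustration only). */
void mirror_lower_to_upper(double *C, size_t n)
{
    for (size_t col = 0; col < n; col++) {
        for (size_t row = col + 1; row < n; row++) {
            /* Element (row,col) is stored at C[row + col*n];
             * its mirror image (col,row) is stored at C[col + row*n]. */
            C[col + row * n] = C[row + col * n];
        }
    }
}
```

The pointer version reads down one column of the lower triangle while writing across the matching row of the upper triangle, which is the same traversal expressed without index arithmetic.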