Efficient in-place transpose of multiple square float matrices

I have seen a few articles on efficient matrix transposition, but they all work out-of-place, with distinct source and destination memory locations.

I am in a situation where I have multiple moderately sized square matrices and need to transpose all of them in-place, for example 640 matrices of size (640, 640).

For this task I am already using 2 GB of memory on a GTX 780 Ti, so I cannot afford to allocate an additional memory buffer.

The simplest kernel I initially came up with is this (for one square matrix in this example):

__global__ void inplace_square_transpose_simple(
			float *Matrix,
			const int N){
	const int i = blockIdx.x*blockDim.x + threadIdx.x;
	const int j = blockIdx.y;

	if(i<N && j<i){
		const float temp=Matrix[j*N+i];
		Matrix[j*N+i]=Matrix[i*N+j];
		Matrix[i*N+j]=temp;
	}
}

where i ranges from 0 to N-1 and j effectively from 0 to N-2, with the launch:

inplace_square_transpose_simple<<<dim3((NN+THREADS_SMALL-1)/THREADS_SMALL,NN,1),THREADS_SMALL>>>(D_matrix,NN);

which runs at a bandwidth ratio of roughly 16:1 in favor of the GPU.

I am not sure whether shared memory will make a difference for the in-place transposition, but it evidently did for the out-of-place implementation, so maybe there is a better way?
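
For the full batch of 640 matrices, I would presumably extend this with a third grid dimension, something like the following (untested) sketch, where inplace_square_transpose_batched and D_matrices are just names I made up:

__global__ void inplace_square_transpose_batched(
			float *matrices,
			const int N){
	// blockIdx.z selects which matrix of the batch this block works on
	float *Matrix = matrices + (size_t)blockIdx.z*N*N;
	const int i = blockIdx.x*blockDim.x + threadIdx.x;
	const int j = blockIdx.y;

	if(i<N && j<i){	// swap the strict upper and lower triangles
		const float temp=Matrix[j*N+i];
		Matrix[j*N+i]=Matrix[i*N+j];
		Matrix[i*N+j]=temp;
	}
}

// launched over all 640 matrices at once:
// inplace_square_transpose_batched<<<dim3((NN+THREADS_SMALL-1)/THREADS_SMALL,NN,640),THREADS_SMALL>>>(D_matrices,NN);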

At this year’s PPoPP conference, two papers on in-place matrix transposition were presented: one by the group of Wen-Mei W. Hwu titled “In-place transposition of rectangular matrices on accelerators” (PDF), and one by NVIDIA titled “A decomposition for in-place matrix transposition” (PDF).

A single 640x640 matrix transposition should be large enough to provide enough threads to keep a GPU busy. You don’t need new storage for all the matrices at once; one extra 640x640 buffer is enough to avoid the in-place requirement and use the efficient code described here:

http://devblogs.nvidia.com/parallelforall/efficient-matrix-transpose-cuda-cc/

Take the first matrix, transpose it out-of-place to your new set of 640x640 storage. Then take the next matrix and transpose it out-of-place to the location occupied originally by the first matrix. Repeat this process for each matrix in sequence.
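
A rough host-side sketch of that rotation (untested; d_mats and d_scratch are placeholder device pointers, and transposeCoalesced, TILE_DIM and BLOCK_ROWS are from the blog code / the code further below):

// d_mats: 640 contiguous 640x640 matrices on the device
// d_scratch: one extra 640x640 device buffer
const int n = 640, cnt = 640;
const size_t elems = (size_t)n*n;
dim3 threads(TILE_DIM, BLOCK_ROWS);
dim3 blocks(n/TILE_DIM, n/TILE_DIM);

float *dst = d_scratch;                  // first transpose lands in the scratch buffer
for (int k = 0; k < cnt; k++){
  float *src = d_mats + (size_t)k*elems;
  transposeCoalesced<<<blocks, threads>>>(dst, src);
  dst = src;                             // the slot just read becomes the next destination
}
// Afterwards the transpose of matrix k sits where matrix k-1 used to be, and the
// transpose of matrix 0 sits in d_scratch, so either track that one-slot rotation
// or finish with a device-to-device copy if the original layout must be preserved.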

Even if you don’t like this approach, if you study the description given in the parallel forall blog you will see that the matrix is transposed one tile at a time. For this square case, and since the tiles are also square, the source tile and the destination tile effectively trade places. It should therefore be possible to modify that algorithm into an in-place one, as long as both tiles (source and destination) are loaded into shared memory before any writes occur. Each threadblock then performs two tile-processing operations (and only half as many threadblocks are needed), and the operation can be done in-place.

Threadblocks on the main diagonal will require special-casing. Those tiles (blockIdx.x == blockIdx.y) have a source and destination tile location that is already the same, so they can be processed in a fashion identical to what is given in the parallel forall blog post. There is no need to load separate source and destination tiles before processing, because they are one and the same.

Here’s a fully-worked example, demonstrating the last case I mention above, i.e. the in-place transpose modification to the parallel forall blog method. The parallel forall blog method is handled by the transposeCoalesced kernel, and the in-place variant is the iptransposeCoalesced kernel:

$ cat t469.cu
#include <stdio.h>
#include <cublas_v2.h>
#include <time.h>
#include <sys/time.h>
#define uS_PER_SEC 1000000
#define uS_PER_mS 1000
#define N 4096
#define M 4096
#define TILE_DIM 32
#define BLOCK_ROWS 8

__global__ void transposeCoalesced(float *odata, const float *idata)
{
  __shared__ float tile[TILE_DIM][TILE_DIM+1];

  int x = blockIdx.x * TILE_DIM + threadIdx.x;
  int y = blockIdx.y * TILE_DIM + threadIdx.y;
  int width = gridDim.x * TILE_DIM;

  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
     tile[threadIdx.y+j][threadIdx.x] = idata[(y+j)*width + x];

  __syncthreads();

  x = blockIdx.y * TILE_DIM + threadIdx.x;  // transpose block offset
  y = blockIdx.x * TILE_DIM + threadIdx.y;

  for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
     odata[(y+j)*width + x] = tile[threadIdx.x][threadIdx.y + j];
}

__global__ void iptransposeCoalesced(float *data)
{
  __shared__ float tile_s[TILE_DIM][TILE_DIM+1];
  __shared__ float tile_d[TILE_DIM][TILE_DIM+1];

  int x = blockIdx.x * TILE_DIM + threadIdx.x;
  int y = blockIdx.y * TILE_DIM + threadIdx.y;
  int width = gridDim.x * TILE_DIM;

  if (blockIdx.y>blockIdx.x) { // handle off-diagonal case
    int dx = blockIdx.y * TILE_DIM + threadIdx.x;
    int dy = blockIdx.x * TILE_DIM + threadIdx.y;
    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
      tile_s[threadIdx.y+j][threadIdx.x] = data[(y+j)*width + x];
    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
      tile_d[threadIdx.y+j][threadIdx.x] = data[(dy+j)*width + dx];
    __syncthreads();
    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
      data[(dy+j)*width + dx] = tile_s[threadIdx.x][threadIdx.y + j];
    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
      data[(y+j)*width + x] = tile_d[threadIdx.x][threadIdx.y + j];
  }

  else if (blockIdx.y==blockIdx.x){ // handle on-diagonal case
    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
      tile_s[threadIdx.y+j][threadIdx.x] = data[(y+j)*width + x];
    __syncthreads();
    for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS)
      data[(y+j)*width + x] = tile_s[threadIdx.x][threadIdx.y + j];
  }
}

int validate(const float *mat, const float *mat_t, int n, int m){
   int result = 1;
   for (int i = 0; i < n; i++)
     for (int j = 0; j < m; j++)
       if (mat[(i*m)+j] != mat_t[(j*n)+i]) result = 0;
   return result;
}

int main(){

    timeval t1, t2;
    float *matrix = (float *) malloc (N * M * sizeof(float));
    for (int i = 0; i < N; i ++)
      for (int j = 0; j < M; j++)
        matrix[(i*M) + j] = i;
// Starting the timer
    gettimeofday(&t1, NULL);
    float *matrixT = (float *) malloc (N * M * sizeof(float));
    for (int i = 0; i < N; i++)
        for (int j = 0; j < M; j++)
            matrixT[(j*N)+i] = matrix[(i*M)+j]; // matrix is obviously filled
//Ending the timer
    gettimeofday(&t2, NULL);
    if (!validate(matrix, matrixT, N, M)) {printf("fail!\n"); return 1;}
    float et1 = (((t2.tv_sec*uS_PER_SEC)+t2.tv_usec) - ((t1.tv_sec*uS_PER_SEC)+t1.tv_usec))/(float)uS_PER_mS;
    printf("CPU time = %fms\n", et1);

    float *h_matrixT , *d_matrixT , *d_matrix;
    h_matrixT = (float *) (malloc (N * M * sizeof(float)));
    cudaMalloc((void **)&d_matrixT , N * M * sizeof(float));
    cudaMalloc((void**)&d_matrix , N * M * sizeof(float));
    cudaMemcpy(d_matrix , matrix , N * M * sizeof(float) , cudaMemcpyHostToDevice);

//Starting the timer
    gettimeofday(&t1, NULL);

    const float alpha = 1.0;
    const float beta  = 0.0;
    cublasHandle_t handle;
    //gettimeofday(&t1, NULL);
    cublasCreate(&handle);
    gettimeofday(&t1, NULL);
    cublasSgeam(handle, CUBLAS_OP_T, CUBLAS_OP_N, N, M, &alpha, d_matrix, M, &beta, d_matrix, N, d_matrixT, N);
    cudaDeviceSynchronize();
    gettimeofday(&t2, NULL);
    cublasDestroy(handle);

//Ending the timer
    float et2 = (((t2.tv_sec*uS_PER_SEC)+t2.tv_usec) - ((t1.tv_sec*uS_PER_SEC)+t1.tv_usec))/(float)uS_PER_mS;
    printf("GPU Sgeam time = %fms\n", et2);

    cudaMemcpy(h_matrixT , d_matrixT , N * M * sizeof(float) , cudaMemcpyDeviceToHost);
    if (!validate(matrix, h_matrixT, N, M)) {printf("fail!\n"); return 1;}
    cudaMemset(d_matrixT,0, N*M*sizeof(float));
    memset(h_matrixT, 0, N*M*sizeof(float));
    dim3 threads(TILE_DIM, BLOCK_ROWS);
    dim3 blocks(N/TILE_DIM, M/TILE_DIM);
    gettimeofday(&t1, NULL);
    transposeCoalesced<<<blocks, threads >>>(d_matrixT, d_matrix);
    cudaDeviceSynchronize();
    gettimeofday(&t2, NULL);
    cudaMemcpy(h_matrixT , d_matrixT , N * M * sizeof(float) , cudaMemcpyDeviceToHost);
    if (!validate(matrix, h_matrixT, N, M)) {printf("fail!\n"); return 1;}
    float et3 = (((t2.tv_sec*uS_PER_SEC)+t2.tv_usec) - ((t1.tv_sec*uS_PER_SEC)+t1.tv_usec))/(float)uS_PER_mS;
    printf("GPU kernel time = %fms\n", et3);

    memset(h_matrixT, 0, N*M*sizeof(float));
    gettimeofday(&t1, NULL);
    iptransposeCoalesced<<<blocks, threads >>>(d_matrix);
    cudaDeviceSynchronize();
    gettimeofday(&t2, NULL);
    cudaMemcpy(h_matrixT , d_matrix , N * M * sizeof(float) , cudaMemcpyDeviceToHost);
    if (!validate(matrix, h_matrixT, N, M)) {printf("fail!\n"); return 1;}
    float et4 = (((t2.tv_sec*uS_PER_SEC)+t2.tv_usec) - ((t1.tv_sec*uS_PER_SEC)+t1.tv_usec))/(float)uS_PER_mS;
    printf("GPU in-place kernel time = %fms\n", et4);

    cudaFree(d_matrix);
    cudaFree(d_matrixT);
    return 0;
}
$ nvcc -arch=sm_20 -o t469 t469.cu -lcublas
$ ./t469
CPU time = 450.095001ms
GPU Sgeam time = 1.937000ms
GPU kernel time = 1.694000ms
GPU in-place kernel time = 1.839000ms
$

For your 640x640 case, this would launch ~200 non-idle blocks, which is probably enough to keep most GPUs busy (~16 threadblocks per SM on K20, for example). You might get some additional benefit on newer GPUs (cc3.5 or newer) by launching the in-place transpose kernel concurrently on several matrices at a time.
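
As a rough illustration of that concurrent idea (untested sketch; d_mats is a placeholder pointer to 640 contiguous 640x640 matrices, and the stream count is an arbitrary choice):

const int n = 640, cnt = 640, nstreams = 8;
cudaStream_t streams[nstreams];
for (int s = 0; s < nstreams; s++) cudaStreamCreate(&streams[s]);

dim3 threads(TILE_DIM, BLOCK_ROWS);
dim3 blocks(n/TILE_DIM, n/TILE_DIM);
for (int k = 0; k < cnt; k++)            // each matrix is independent, so rotate through streams
  iptransposeCoalesced<<<blocks, threads, 0, streams[k % nstreams]>>>(
      d_mats + (size_t)k*n*n);

cudaDeviceSynchronize();
for (int s = 0; s < nstreams; s++) cudaStreamDestroy(streams[s]);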

txbob and GertJan,

Thanks, I will give these a try on both the GTX 780 Ti and the K20c.

Very good implementation, txbob. Here is the profiling output using a GTX 780 Ti on Windows 7, compute capability 3.5:

ConsoleApplication1.exe
CPU solution timing: 275
==900== NVPROF is profiling process 900, command: ConsoleApplication1.exe
GPU cublas solution timing: 1
GPU kernel out-of-place solution timing: 1
GPU kernel in-place solution timing: 1
==900== Profiling application: ConsoleApplication1.exe
==900== Profiling result:
   Start  Duration            Grid Size      Block Size     Regs*    SSMem*    DSMem*      Size  Throughput           Device   Context    Stream  Name
125.49ms  14.002ms                    -               -         -         -         -  67.109MB  4.7927GB/s  GeForce GTX 780         1         1  [CUDA memcpy HtoD]
358.72ms  1.4400us                    -               -         -         -         -      112B  77.778MB/s  GeForce GTX 780         1         1  [CUDA memcpy HtoD]
359.11ms  520.42us            (64 64 1)       (256 1 1)        39  8.3200KB        0B         -           -  GeForce GTX 780         1         1  void transpose_readWrite_alignment_kernel<float, int=1, bool=0, int=6, int=5, int=3>(cublasTransposeParams<float>, float const *, float*, float const *) [201]
360.91ms  19.576ms                    -               -         -         -         -  67.109MB  3.4282GB/s  GeForce GTX 780         1         1  [CUDA memcpy DtoH]
556.99ms  256.93us                    -               -         -         -         -  67.109MB  261.20GB/s  GeForce GTX 780         1         1  [CUDA memset]
557.25ms  550.78us          (128 128 1)        (32 8 1)        22  4.2240KB        0B         -           -  GeForce GTX 780         1         1  transposeCoalesced(float*, float const *) [229]
559.21ms  12.121ms                    -               -         -         -         -  67.109MB  5.5366GB/s  GeForce GTX 780         1         1  [CUDA memcpy DtoH]
747.59ms  513.50us          (128 128 1)        (32 8 1)        26  8.4480KB        0B         -           -  GeForce GTX 780         1         1  iptransposeCoalesced(float*) [234]
749.39ms  12.120ms                    -               -         -         -         -  67.109MB  5.5370GB/s  GeForce GTX 780         1         1  [CUDA memcpy DtoH]

Regs: Number of registers used per CUDA thread.
SSMem: Static shared memory allocated per CUDA block.
DSMem: Dynamic shared memory allocated per CUDA block.

So there is not too much difference between the three implementations for the transpose of a 4096x4096 matrix, but the in-place implementation is indeed the fastest at 513 us, with cuBLAS in second place at 520 us.

There is also a GTC 2014 presentation by Catanzaro (NVIDIA) on this topic, with the accompanying code at https://github.com/BryanCatanzaro/inplace

Hello, I just came across this topic while working on an in-place transpose. The forum posts above helped me, but I am also studying the full code from the link txbob mentioned: http://devblogs.nvidia.com/parallelforall/efficient-matrix-transpose-cuda-cc/

My question: if you noticed, for each transpose they first do a warm-up (calling the kernel once) and then run a for-loop of 100 iterations of the same kernel. Why the warm-up, and why the 100-iteration loop? I never understood this point; I hope someone can help me. Thanks!!

These are standard benchmarking techniques to determine “steady state” performance and reduce the noise level in measurements. A related technique, best-of-N, is used in the well-known STREAM benchmark used to measure memory bandwidth. By default it uses N=10.
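
The usual pattern looks roughly like this (a sketch with placeholder kernel and launch parameters, timed with CUDA events):

const int REPS = 100;
cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);

someKernel<<<blocks, threads>>>(d_data);      // warm-up launch absorbs one-time costs
cudaDeviceSynchronize();                      // (context setup, caching, clock ramp-up)

cudaEventRecord(start);
for (int r = 0; r < REPS; r++)
  someKernel<<<blocks, threads>>>(d_data);    // timed region: many repetitions
cudaEventRecord(stop);
cudaEventSynchronize(stop);

float ms = 0.0f;
cudaEventElapsedTime(&ms, start, stop);
printf("average kernel time = %f ms\n", ms/REPS);   // averaging over REPS reduces noise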

The thread is a bit old now, but I still need some advice on the in-place transposition kernel presented here.

__global__ void iptransposeCoalesced(float *data)

How could the kernel be extended so that rectangular matrices can be transposed, with the following characteristics: N != M, N = 2^n and M = 2^m + 1 (N is the number of rows and M the number of columns)?

It’s not a trivial matter to do an in-place rectangular transpose. The articles linked in comments 2 and 7 suggest the level of complexity.

A simpler approach, however, would be to drop the in-place requirement and just use cublasSgeam.
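
For a row-major N x M matrix (N rows, M columns) that would look roughly like the Sgeam call in the code above. Here is a minimal sketch, where d_in and d_out are assumed to be separate device buffers of N*M floats:

// out-of-place transpose of a row-major N x M matrix with cuBLAS (sketch)
const float alpha = 1.0f, beta = 0.0f;
cublasHandle_t handle;
cublasCreate(&handle);
// A row-major N x M input looks to cuBLAS like a column-major M x N matrix (lda = M),
// so transposing it produces a column-major N x M result, i.e. the row-major M x N
// transpose (ldc = N). Since beta = 0, the B argument is not actually read.
cublasSgeam(handle, CUBLAS_OP_T, CUBLAS_OP_N, N, M, &alpha, d_in, M,
            &beta, d_out, N, d_out, N);
cublasDestroy(handle);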

If the in-place transpose is a requirement, I would suggest starting with the code already mentioned in comment 7:

https://github.com/BryanCatanzaro/inplace