/*
 *             ATLAS Altivec matmul kernel
 *           (C) Copyright 2001 Nicholas A. Coult           
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *   1. Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *   2. Redistributions in binary form must reproduce the above copyright
 *      notice, this list of conditions, and the following disclaimer in the
 *      documentation and/or other materials provided with the distribution.
 *   3. The name of the University of Tennessee, the ATLAS group,
 *      or the names of its contributers may not be used to endorse
 *      or promote products derived from this software without specific
 *      written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE. 
 *
 */

#include "atlas_misc.h"

/* ----- Useful macros ----- */
/*
 * Vector of four -0.0f values. Every 32-bit lane starts as all ones and is
 * shifted left by 31 (vec_sl only uses the low 5 bits of each shift count),
 * leaving just the sign bit set: 0x80000000 == -0.0f.
 * -0.0f is used as the additive identity because x + (-0.0f) == x for every
 * x, including x == -0.0f.
 * The lanes are explicitly 32-bit ("unsigned int"); "vector unsigned long"
 * is deprecated in AltiVec and does not reliably mean 32-bit lanes.
 */
#define minus_zero_constant  ((vector float)vec_sl(vec_splat_u32(-1), vec_splat_u32(-1)))

/*
 * Fill all four lanes of theVector with the scalar theFloat: the scalar is
 * written into lane 0 through a float pointer and lane 0 is then replicated.
 * theVector must be a modifiable, addressable vector float lvalue.
 * Wrapped in do { } while (0) so the macro behaves as a single statement.
 */
#define vec_splat_float(theVector,theFloat) \
do { \
  (*((float *)&(theVector))) = (theFloat); \
  (theVector) = vec_splat((theVector), 0); \
} while (0)

/*
 * In-place transpose of the 4x4 float matrix whose rows are x0..x3.
 * After the macro, lane j of xi holds what lane i of xj held before.
 * Requires four scratch vectors named s0, s1, s2, s3 (same type as the
 * arguments) to be in scope at the expansion site; they are clobbered.
 * x0..x3 must be modifiable lvalues distinct from s0..s3.
 * Wrapped in do { } while (0) so the macro behaves as a single statement.
 */
#define transpose(x0,x1,x2,x3) \
do { \
  s0 = vec_mergeh((x0),(x2)); \
  s1 = vec_mergeh((x1),(x3)); \
  s2 = vec_mergel((x0),(x2)); \
  s3 = vec_mergel((x1),(x3)); \
  (x0) = vec_mergeh(s0,s1); \
  (x1) = vec_mergel(s0,s1); \
  (x2) = vec_mergeh(s2,s3); \
  (x3) = vec_mergel(s2,s3); \
} while (0)

/*
 * Store the 16-byte vector v (vector unsigned char) at the address `where`
 * (vector unsigned char *), which need not be 16-byte aligned.
 *
 * The two aligned blocks that the destination may straddle are loaded, the
 * data is rotated into position and merged with the bytes that must be
 * preserved, and both blocks are written back.
 *
 * Requires scratch variables low, high, p_vector and mask (all vector
 * unsigned char) in scope at the expansion site; they are clobbered.
 *
 * v is only read: it is copied into a block-local temporary, so callers may
 * pass any expression (e.g. a cast), not just a modifiable lvalue.
 *
 * The trailing block is addressed with offset 15 rather than 16. For a
 * misaligned address this is the same block; for an aligned address it is
 * the block at `where` itself, so memory after the destination is never
 * read or rewritten. The trailing block is stored first so that, in the
 * aligned case, the final store of `low` (which then holds all of v) wins.
 */
#define store_unaligned(v, where) \
do { \
  vector unsigned char su_rotated = (v); \
  low = vec_ld(0, (where)); \
  high = vec_ld(15, (where)); \
  p_vector = vec_lvsr(0, (int *)(where)); \
  mask  = vec_perm((vector unsigned char)(0), (vector unsigned char)(-1), p_vector); \
  su_rotated = vec_perm(su_rotated, su_rotated, p_vector); \
  low = vec_sel(low,  su_rotated, mask); \
  high = vec_sel(su_rotated, high, mask); \
  vec_st(high, 15, (where)); \
  vec_st(low,  0, (where)); \
} while (0)

/*
 * Load 16 bytes from the address v (which need not be 16-byte aligned) into
 * the vector float lvalue u.
 *
 * The two aligned blocks that the source may straddle are loaded and the
 * wanted bytes are extracted with a permute driven by vec_lvsl.
 *
 * Requires scratch variables low, high and p_vector (all vector unsigned
 * char) in scope at the expansion site; they are clobbered.
 *
 * The trailing block is addressed with offset 15 rather than 16. For a
 * misaligned address this is the same block; for an aligned address it is
 * the block at v itself (the permute then takes every byte from `low`), so
 * the 16 bytes following the data are never touched.
 */
#define load_unaligned(u,v) \
do { \
  p_vector = (vector unsigned char)vec_lvsl(0, (int*)(v)); \
  low = (vector unsigned char)vec_ld(0, (v)); \
  high = (vector unsigned char)vec_ld(15, (v)); \
  (u) = (vector float)vec_perm(low, high, p_vector); \
} while (0)

/* ----- Altivec L1 matmul cleanup kernel ------ */

/*
 * Single-precision AltiVec matmul kernel working on 4x4 tiles of C.
 *
 * For every tile it accumulates the dot products of four "rows of A^T"
 * (A[k + i*lda], contiguous in k) with four columns of B (B[k + j*ldb],
 * contiguous in k) and combines them with C (C[i + j*ldc]):
 *
 *   BETA0 defined : C  = A^T * B             (C is written, never read)
 *   otherwise     : C  = beta * C + A^T * B
 *
 * M, N  - extent of C in the i and j directions
 * K     - length of the dot products
 * alpha - never read by this kernel
 *         NOTE(review): presumably callers guarantee alpha == 1 - confirm.
 * A,lda - left operand and its leading dimension (in floats)
 * B,ldb - right operand and its leading dimension (in floats)
 * beta  - scale factor for C (ignored when BETA0 is defined)
 * C,ldc - result matrix and its leading dimension (in floats)
 *
 * All three loops advance in steps of 4 with no remainder handling, so each
 * of M, N and K is processed as if rounded up to a multiple of 4.
 * NOTE(review): presumably callers only pass multiples of 4 - confirm.
 *
 * A, B and C need not be 16-byte aligned; all accesses go through
 * load_unaligned()/store_unaligned().
 */
void ATL_USERMM
(const int M, const int N, const int K, const float alpha, const float *A, const int lda, const float *B, const int ldb, const float beta, float *C, const int ldc)
{
  /* Scratch vectors clobbered by load_unaligned()/store_unaligned() */
  vector unsigned char low, high;
  vector unsigned char p_vector;
  vector unsigned char mask;

  /* Byte view of the C column being written. store_unaligned() is handed a
     real variable instead of a cast expression, because a cast is not a
     modifiable lvalue in standard C. */
  vector unsigned char c_bytes;

#ifndef BETA0
  /* beta replicated into all four lanes */
  vector float v_beta;
#endif

  /* Rows of A^T, columns of B, rows of C */
  vector float ar0, ar1, ar2, ar3;
  vector float bc0, bc1, bc2, bc3;
  vector float cr0, cr1, cr2, cr3;

  /* sXY accumulates the lane-wise products of row X of A^T with column Y
     of B; the four lanes are summed after the k loop */
  vector float s00, s01, s02, s03;
  vector float s10, s11, s12, s13;
  vector float s20, s21, s22, s23;
  vector float s30, s31, s32, s33;

  /* Scratch vectors clobbered by transpose() */
  vector float s0, s1, s2, s3;

  int i, j, k;

  /* -0.0f in every lane: adding it leaves any value unchanged */
  vector float minus_zero = minus_zero_constant;

#ifndef BETA0
  vec_splat_float(v_beta, beta);
#endif

  for (j = 0; j < N; j += 4) /* Loop over columns of B, columns of C */
    {
      for (i = 0; i < M; i += 4) /* Loop over rows of A^T, rows of C */
        {
#ifndef BETA0
          /* Load four columns of the C tile ... */
          load_unaligned(cr0, (vector unsigned char *)&(C[i+j*ldc]));
          load_unaligned(cr1, (vector unsigned char *)&(C[i+(j+1)*ldc]));
          load_unaligned(cr2, (vector unsigned char *)&(C[i+(j+2)*ldc]));
          load_unaligned(cr3, (vector unsigned char *)&(C[i+(j+3)*ldc]));

          /* ... and transpose so that cr0..cr3 hold the _rows_ of C */
          transpose(cr0, cr1, cr2, cr3);
#endif
          /* With BETA0 the old contents of C are not needed: cr0..cr3 are
             assigned from the accumulators after the k loop. */

          s00 = s01 = s02 = s03 =
            s10 = s11 = s12 = s13 =
            s20 = s21 = s22 = s23 =
            s30 = s31 = s32 = s33 = minus_zero;

          for (k = 0; k < K; k += 4) /* Loop over columns of A^T, rows of B */
            {
              /* Load rows of A^T */
              load_unaligned(ar0, (vector unsigned char *)&(A[k+i*lda]));
              load_unaligned(ar1, (vector unsigned char *)&(A[k+(i+1)*lda]));
              load_unaligned(ar2, (vector unsigned char *)&(A[k+(i+2)*lda]));
              load_unaligned(ar3, (vector unsigned char *)&(A[k+(i+3)*lda]));

              /* Load columns of B */
              load_unaligned(bc0, (vector unsigned char *)&(B[k+j*ldb]));
              load_unaligned(bc1, (vector unsigned char *)&(B[k+(j+1)*ldb]));
              load_unaligned(bc2, (vector unsigned char *)&(B[k+(j+2)*ldb]));
              load_unaligned(bc3, (vector unsigned char *)&(B[k+(j+3)*ldb]));

              /* Accumulate lane-wise products of rows and columns;
                 the lanes are summed into dot products below */
              s00 = vec_madd(ar0, bc0, s00);
              s01 = vec_madd(ar0, bc1, s01);
              s02 = vec_madd(ar0, bc2, s02);
              s03 = vec_madd(ar0, bc3, s03);

              s10 = vec_madd(ar1, bc0, s10);
              s11 = vec_madd(ar1, bc1, s11);
              s12 = vec_madd(ar1, bc2, s12);
              s13 = vec_madd(ar1, bc3, s13);

              s20 = vec_madd(ar2, bc0, s20);
              s21 = vec_madd(ar2, bc1, s21);
              s22 = vec_madd(ar2, bc2, s22);
              s23 = vec_madd(ar2, bc3, s23);

              s30 = vec_madd(ar3, bc0, s30);
              s31 = vec_madd(ar3, bc1, s31);
              s32 = vec_madd(ar3, bc2, s32);
              s33 = vec_madd(ar3, bc3, s33);
            }

          /* Transpose each group of accumulators so that the four partial
             sums of one dot product sit in the same lane of sX0..sX3 */
          transpose(s00, s01, s02, s03);
          transpose(s10, s11, s12, s13);
          transpose(s20, s21, s22, s23);
          transpose(s30, s31, s32, s33);

          /* Pairwise sums; the final add below completes the dot products
             of rows of A^T with columns of B */
          s00 = vec_add(s00, s01);
          s02 = vec_add(s02, s03);

          s10 = vec_add(s10, s11);
          s12 = vec_add(s12, s13);

          s20 = vec_add(s20, s21);
          s22 = vec_add(s22, s23);

          s30 = vec_add(s30, s31);
          s32 = vec_add(s32, s33);

#ifdef BETA0
          cr0 = vec_add(s00, s02);
          cr1 = vec_add(s10, s12);
          cr2 = vec_add(s20, s22);
          cr3 = vec_add(s30, s32);
#else
          /* beta*C plus the first half of each dot product ... */
          cr0 = vec_madd(cr0, v_beta, s00);
          cr1 = vec_madd(cr1, v_beta, s10);
          cr2 = vec_madd(cr2, v_beta, s20);
          cr3 = vec_madd(cr3, v_beta, s30);

          /* ... plus the second half */
          cr0 = vec_add(cr0, s02);
          cr1 = vec_add(cr1, s12);
          cr2 = vec_add(cr2, s22);
          cr3 = vec_add(cr3, s32);
#endif
          /* Transpose back so that cr0..cr3 hold _columns_ of C */
          transpose(cr0, cr1, cr2, cr3);

          /* Store the four columns of the C tile */
          c_bytes = (vector unsigned char)cr0;
          store_unaligned(c_bytes, (vector unsigned char *)&(C[i+j*ldc]));
          c_bytes = (vector unsigned char)cr1;
          store_unaligned(c_bytes, (vector unsigned char *)&(C[i+(j+1)*ldc]));
          c_bytes = (vector unsigned char)cr2;
          store_unaligned(c_bytes, (vector unsigned char *)&(C[i+(j+2)*ldc]));
          c_bytes = (vector unsigned char)cr3;
          store_unaligned(c_bytes, (vector unsigned char *)&(C[i+(j+3)*ldc]));
        }
    }
}

