/*
 * Copyright (c) 2017, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

/* clang-format off */

/* mmreal4.c -- F90 fast-/dgemm-like MATMUL intrinsics for real*4 type */

#include "stdioInterf.h"
#include "fioMacros.h"

/*
 * Size limits for the inline small-matrix path of MMUL_REAL4:
 *   SMALL_ROWSA bounds mra (rows of the result c),
 *   SMALL_ROWSB bounds kab (the shared inner dimension),
 *   SMALL_COLSB bounds ncb (columns of the result c).
 * Products within all three limits are computed directly in mmul_real4;
 * larger ones go to the ftn_m?ax?b_real4_ kernels.  The limits also size
 * the on-stack scratch buffers used by the inline path.
 */
#define SMALL_ROWSA 10
#define SMALL_ROWSB 10
#define SMALL_COLSB 10

/*
 * MMUL_REAL4: single-precision (real*4) GEMM-style kernel behind MATMUL.
 *
 * Computes  c = alpha * op(a) * op(b) + beta * c  on column-major storage,
 * where op(x) is x when its transpose flag is 0 and x**T when it is 1.
 *
 *   ta, tb - transpose flags for a and b (0 = as stored, 1 = transposed)
 *   mra    - rows of op(a) and of c              (m)
 *   ncb    - columns of op(b) and of c           (n)
 *   kab    - columns of op(a) == rows of op(b)   (k)
 *   alpha  - pointer to the scale applied to op(a) * op(b)
 *   a, lda - first element and leading dimension of a
 *   b, ldb - first element and leading dimension of b
 *   beta   - pointer to the scale applied to the incoming contents of c
 *   c, ldc - first element and leading dimension of the result c
 *
 * Dispatch:
 *   - tb == 0 and ncb == 1 (op(b) is one column): ftn_mvmul_real4_
 *   - ta == 0, mra == 1 and ldc == 1 (one result row): ftn_vmmul_real4_
 *   - mra <= SMALL_ROWSA, kab <= SMALL_ROWSB, ncb <= SMALL_COLSB: computed
 *     inline below.  When *beta == 0 the old contents of c are never read,
 *     so uninitialized or NaN values in c do not reach the result.
 *   - anything else: one of ftn_m{n,t}ax{n,t}b_real4_
 */
void ENTF90(MMUL_REAL4, mmul_real4)(int ta, int tb, __POINT_T mra,
                                    __POINT_T ncb, __POINT_T kab, float *alpha,
                                    float a[], __POINT_T lda, float b[],
                                    __POINT_T ldb, float *beta, float c[],
                                    __POINT_T ldc)
{
  /* Matrix dimensions.  Kept in __POINT_T (not int) so an extent that does
   * not fit in an int cannot be truncated into passing the small-size test. */
  __POINT_T rowsa, colsa, rowsb, colsb;
  /* Element offsets into a, b, c and the scratch buffers.  They advance by
   * lda/ldb/ldc, so they share the leading dimensions' type to avoid int
   * overflow. */
  __POINT_T bstrt, cstrt, andx, bndx, cndx, indx;
  int i, j, k;
  /*
   * tindex encodes the transpose combination:
   *   ta == 0, tb == 0: 0      ta == 1, tb == 0: 1
   *   ta == 0, tb == 1: 2      ta == 1, tb == 1: 3
   */
  int tindex = 0;
  float buffera[SMALL_ROWSA * SMALL_ROWSB]; /* one gathered row of a */
  float bufferb[SMALL_COLSB * SMALL_ROWSB]; /* alpha * b**T, column-major */
  float temp;
  float calpha, cbeta;
  void ftn_mvmul_real4_(), ftn_vmmul_real4_();
  void ftn_mnaxnb_real4_(), ftn_mnaxtb_real4_();
  void ftn_mtaxnb_real4_(), ftn_mtaxtb_real4_();

  if ((tb == 0) && (ncb == 1)) {
    /* matrix vector multiply */
    ftn_mvmul_real4_(&ta, &mra, &kab, alpha, a, &lda, b, beta, c);
    return;
  }
  if ((ta == 0) && (mra == 1) && (ldc == 1)) {
    /* vector matrix multiply */
    ftn_vmmul_real4_(&tb, &ncb, &kab, alpha, a, b, &ldb, beta, c);
    return;
  }

  calpha = *alpha;
  cbeta = *beta;
  rowsa = mra;
  colsa = kab;
  rowsb = kab;
  colsb = ncb;
  if (ta == 1)
    tindex = 1;
  if (tb == 1)
    tindex += 2;

  if ((colsb > SMALL_COLSB) || (rowsa > SMALL_ROWSA) ||
      (rowsb > SMALL_ROWSB)) {
    /* Too large for the inline code: hand off to the general kernels. */
    switch (tindex) {
    case 0:
      ftn_mnaxnb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
                        &ldc);
      break;
    case 1:
      ftn_mtaxnb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
                        &ldc);
      break;
    case 2:
      ftn_mnaxtb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
                        &ldc);
      break;
    case 3:
      ftn_mtaxtb_real4_(&mra, &ncb, &kab, alpha, a, &lda, b, &ldb, beta, c,
                        &ldc);
      break;
    }
    return;
  }

  /*
   * Small-matrix path.
   *
   * When b is transposed, first copy alpha * b**T into bufferb as a
   * rowsb x colsb column-major matrix (leading dimension rowsb), so each
   * column of op(b) is contiguous.  Element (k, j) of b**T is b[j + k*ldb].
   */
  if (tindex >= 2) {
    bstrt = 0;
    for (k = 0; k < rowsb; k++) {
      bndx = bstrt;
      indx = k;
      for (j = 0; j < colsb; j++) {
        bufferb[indx] = calpha * b[bndx++];
        indx += rowsb;
      }
      bstrt += ldb;
    }
  }

  /*
   * Every case stores its dot product with the same conditional:
   *   beta == 0 : c = result               (c is not read)
   *   otherwise : c = result + beta * c
   * One loop nest per case, rather than separate beta == 0 / beta != 0
   * copies, keeps the two paths from drifting apart.
   */
  switch (tindex) {
  case 0: /* c = (alpha * a) * b */
    for (i = 0; i < rowsa; i++) {
      /* Gather row i of a, scaled by alpha, into contiguous storage. */
      andx = i;
      for (k = 0; k < colsa; k++) {
        buffera[k] = calpha * a[andx];
        andx += lda;
      }
      bstrt = 0;
      cndx = i;
      for (j = 0; j < colsb; j++) {
        temp = 0.0f;
        for (k = 0; k < rowsb; k++)
          temp += buffera[k] * b[bstrt + k];
        c[cndx] = (cbeta == 0.0f) ? temp : temp + cbeta * c[cndx];
        bstrt += ldb;  /* next column of b */
        cndx += ldc;   /* next element of row i of c */
      }
    }
    break;

  case 1: /* c = alpha * (a**T * b) */
    /* Row i of a**T is column i of a, which is already contiguous; alpha is
     * applied to each finished dot product. */
    bstrt = 0;
    cstrt = 0;
    for (j = 0; j < colsb; j++) {
      andx = 0;
      cndx = cstrt;
      for (i = 0; i < rowsa; i++) {
        temp = 0.0f;
        for (k = 0; k < rowsb; k++)
          temp += a[andx + k] * b[bstrt + k];
        c[cndx] = (cbeta == 0.0f) ? calpha * temp
                                  : calpha * temp + cbeta * c[cndx];
        andx += lda;   /* next column of a */
        cndx++;
      }
      bstrt += ldb;    /* next column of b */
      cstrt += ldc;    /* next column of c */
    }
    break;

  case 2: /* c = a * (alpha * b**T) */
    for (i = 0; i < rowsa; i++) {
      /* Gather row i of a (unscaled: alpha is already folded into bufferb). */
      andx = i;
      for (k = 0; k < colsa; k++) {
        buffera[k] = a[andx];
        andx += lda;
      }
      bndx = 0;
      cndx = i;
      for (j = 0; j < colsb; j++) {
        temp = 0.0f;
        for (k = 0; k < rowsb; k++)
          temp += buffera[k] * bufferb[bndx + k];
        c[cndx] = (cbeta == 0.0f) ? temp : temp + cbeta * c[cndx];
        bndx += rowsb; /* next column of b**T */
        cndx += ldc;   /* next element of row i of c */
      }
    }
    break;

  case 3: /* c = a**T * (alpha * b**T) */
    bndx = 0;
    cstrt = 0;
    for (j = 0; j < colsb; j++) {
      andx = 0;
      cndx = cstrt;
      for (i = 0; i < rowsa; i++) {
        temp = 0.0f;
        for (k = 0; k < rowsb; k++)
          temp += a[andx + k] * bufferb[bndx + k];
        c[cndx] = (cbeta == 0.0f) ? temp : temp + cbeta * c[cndx];
        andx += lda;   /* next column of a */
        cndx++;
      }
      bndx += rowsb;   /* next column of b**T */
      cstrt += ldc;    /* next column of c (needed for every beta) */
    }
    break;
  }
}
