Actual source code: mpimatmatmatmult.c

petsc-3.13.1 2020-05-02
Report Typos and Errors
  1: /*
  2:   Defines matrix-matrix-matrix product routines for MPIAIJ matrices
  3:           D = A * B * C
  4: */
  5:  #include <../src/mat/impls/aij/mpi/mpiaij.h>

  7: #if defined(PETSC_HAVE_HYPRE)
  8: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,PetscReal,Mat);
  9: PETSC_INTERN PetscErrorCode MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Mat,Mat,Mat,Mat);

 11: PETSC_INTERN PetscErrorCode MatProductNumeric_ABC_Transpose_AIJ_AIJ(Mat RAP)
 12: {
 14:   Mat_Product    *product = RAP->product;
 15:   Mat            Rt,R=product->A,A=product->B,P=product->C;

 18:   MatTransposeGetMat(R,&Rt);
 19:   MatTransposeMatMatMultNumeric_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,RAP);
 20:   return(0);
 21: }

 23: PETSC_INTERN PetscErrorCode MatProductSymbolic_ABC_Transpose_AIJ_AIJ(Mat RAP)
 24: {
 26:   Mat_Product    *product = RAP->product;
 27:   Mat            Rt,R=product->A,A=product->B,P=product->C;
 28:   PetscBool      flg;

 31:   /* local sizes of matrices will be checked by the calling subroutines */
 32:   MatTransposeGetMat(R,&Rt);
 33:   PetscObjectTypeCompareAny((PetscObject)Rt,&flg,MATSEQAIJ,MATSEQAIJMKL,MATMPIAIJ,NULL);
 34:   if (!flg) SETERRQ1(PetscObjectComm((PetscObject)Rt),PETSC_ERR_SUP,"Not for matrix type %s",((PetscObject)Rt)->type_name);
 35:   MatTransposeMatMatMultSymbolic_AIJ_AIJ_AIJ_wHYPRE(Rt,A,P,product->fill,RAP);
 36:   RAP->ops->productnumeric = MatProductNumeric_ABC_Transpose_AIJ_AIJ;
 37:   return(0);
 38: }

 40: PETSC_INTERN PetscErrorCode MatProductSetFromOptions_Transpose_AIJ_AIJ(Mat C)
 41: {
 43:   Mat_Product    *product = C->product;

 46:   MatSetType(C,MATAIJ);
 47:   if (product->type == MATPRODUCT_ABC) {
 48:     C->ops->productsymbolic = MatProductSymbolic_ABC_Transpose_AIJ_AIJ;
 49:   } else SETERRQ1(PetscObjectComm((PetscObject)C),PETSC_ERR_SUP,"MatProduct type %s is not supported for Transpose, AIJ and AIJ matrices",MatProductTypes[product->type]);
 50:   return(0);
 51: }
 52: #endif

 54: PetscErrorCode MatFreeIntermediateDataStructures_MPIAIJ_BC(Mat ABC)
 55: {
 56:   Mat_MPIAIJ        *a = (Mat_MPIAIJ*)ABC->data;
 57:   Mat_MatMatMatMult *matmatmatmult = a->matmatmatmult;
 58:   PetscErrorCode    ierr;

 61:   if (!matmatmatmult) return(0);

 63:   MatDestroy(&matmatmatmult->BC);
 64:   ABC->ops->destroy = matmatmatmult->destroy;
 65:   PetscFree(a->matmatmatmult);
 66:   return(0);
 67: }

 69: PetscErrorCode MatDestroy_MPIAIJ_MatMatMatMult(Mat A)
 70: {
 71:   PetscErrorCode    ierr;

 74:   (*A->ops->freeintermediatedatastructures)(A);
 75:   (*A->ops->destroy)(A);
 76:   return(0);
 77: }

 79: PetscErrorCode MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,PetscReal fill,Mat D)
 80: {
 82:   Mat            BC;
 83:   PetscBool      scalable;
 84:   Mat_Product    *product = D->product;

 87:   MatCreate(PetscObjectComm((PetscObject)A),&BC);
 88:   if (product) {
 89:     PetscStrcmp(product->alg,"scalable",&scalable);
 90:   } else SETERRQ(PetscObjectComm((PetscObject)D),PETSC_ERR_ARG_NULL,"Call MatProductCreate() first");

 92:   if (scalable) {
 93:     MatMatMultSymbolic_MPIAIJ_MPIAIJ(B,C,fill,BC);
 94:     MatZeroEntries(BC); /* initialize value entries of BC */
 95:     MatMatMultSymbolic_MPIAIJ_MPIAIJ(A,BC,fill,D);
 96:   } else {
 97:     MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(B,C,fill,BC);
 98:     MatZeroEntries(BC); /* initialize value entries of BC */
 99:     MatMatMultSymbolic_MPIAIJ_MPIAIJ_nonscalable(A,BC,fill,D);
100:   }
101:   product->Dwork = BC;

103:   D->ops->matmatmultnumeric = MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ;
104:   D->ops->freeintermediatedatastructures = MatFreeIntermediateDataStructures_MPIAIJ_BC;
105:   return(0);
106: }

108: PetscErrorCode MatMatMatMultNumeric_MPIAIJ_MPIAIJ_MPIAIJ(Mat A,Mat B,Mat C,Mat D)
109: {
111:   Mat_Product    *product = D->product;
112:   Mat            BC = product->Dwork;

115:   (BC->ops->matmultnumeric)(B,C,BC);
116:   (D->ops->matmultnumeric)(A,BC,D);
117:   return(0);
118: }

120: /* ----------------------------------------------------- */
121: PetscErrorCode MatDestroy_MPIAIJ_RARt(Mat C)
122: {
124:   Mat_MPIAIJ     *c    = (Mat_MPIAIJ*)C->data;
125:   Mat_RARt       *rart = c->rart;

128:   MatDestroy(&rart->Rt);

130:   C->ops->destroy = rart->destroy;
131:   if (C->ops->destroy) {
132:     (*C->ops->destroy)(C);
133:   }
134:   PetscFree(rart);
135:   return(0);
136: }

138: PetscErrorCode MatProductNumeric_RARt_MPIAIJ_MPIAIJ(Mat C)
139: {
141:   Mat_MPIAIJ     *c = (Mat_MPIAIJ*)C->data;
142:   Mat_RARt       *rart = c->rart;
143:   Mat_Product    *product = C->product;
144:   Mat            A=product->A,R=product->B,Rt=rart->Rt;

147:   MatTranspose(R,MAT_REUSE_MATRIX,&Rt);
148:   (C->ops->matmatmultnumeric)(R,A,Rt,C);
149:   return(0);
150: }

152: PetscErrorCode MatProductSymbolic_RARt_MPIAIJ_MPIAIJ(Mat C)
153: {
154:   PetscErrorCode      ierr;
155:   Mat_Product         *product = C->product;
156:   Mat                 A=product->A,R=product->B,Rt;
157:   PetscReal           fill=product->fill;
158:   Mat_RARt            *rart;
159:   Mat_MPIAIJ          *c;

162:   MatTranspose(R,MAT_INITIAL_MATRIX,&Rt);
163:   /* product->Dwork is used to store A*Rt in MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ() */
164:   MatMatMatMultSymbolic_MPIAIJ_MPIAIJ_MPIAIJ(R,A,Rt,fill,C);
165:   C->ops->productnumeric = MatProductNumeric_RARt_MPIAIJ_MPIAIJ;

167:   /* create a supporting struct */
168:   PetscNew(&rart);
169:   c        = (Mat_MPIAIJ*)C->data;
170:   c->rart  = rart;
171:   rart->Rt = Rt;
172:   rart->destroy   = C->ops->destroy;
173:   C->ops->destroy = MatDestroy_MPIAIJ_RARt;
174:   return(0);
175: }