LCOV - code coverage report
Current view: top level - impl - baddbmm_grad_kernel_impl.h Hit Total Coverage
Test: Diff Coverage Lines: 0 75 0.0 %
Date: 2025-01-16 04:37:07

          Line data    Source code
       1             : /* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
       2             : 
       3             : Licensed under the Apache License, Version 2.0 (the "License");
       4             : you may not use this file except in compliance with the License.
       5             : You may obtain a copy of the License at
       6             : 
       7             :     http://www.apache.org/licenses/LICENSE-2.0
       8             : 
       9             : Unless required by applicable law or agreed to in writing, software
      10             : distributed under the License is distributed on an "AS IS" BASIS,
      11             : WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      12             : See the License for the specific language governing permissions and
      13             : limitations under the License. */
      14             : #pragma once
      15             : 
      16             : #include <type_traits>
      17             : 
      18             : #include "glog/logging.h"
      19             : 
      20             : #include "paddle/phi/common/amp_type_traits.h"
      21             : #include "paddle/phi/kernels/baddbmm_grad_kernel.h"
      22             : #include "paddle/phi/kernels/funcs/blas/blas.h"
      23             : #include "paddle/phi/kernels/funcs/eigen/common.h"
      24             : #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
      25             : #include "paddle/phi/kernels/funcs/for_range.h"
      26             : 
      27             : namespace phi {
      28             : 
      29             : template <typename T>
      30             : struct BCopyOrScaleFunctor {
      31             :   BCopyOrScaleFunctor(const float scale, const T* x, T* output, int64_t numel)
      32             :       : scale_(scale), x_(x), output_(output), numel_(numel) {}
      33             : 
      34             :   HOSTDEVICE void operator()(int64_t idx) const {
      35             :     using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
      36             :     const MPType mp_scale = static_cast<MPType>(scale_);
      37             :     const MPType mp_x = static_cast<MPType>(x_[idx]);
      38             :     output_[idx] = static_cast<T>(mp_scale * mp_x);
      39             :   }
      40             : 
      41             :  private:
      42             :   const float scale_;
      43             :   const T* x_;
      44             :   T* output_;
      45             :   int64_t numel_;
      46             : };
      47             : 
      48             : template <typename T,
      49             :           size_t D,
      50             :           int MajorType = Eigen::RowMajor,
      51             :           typename IndexType = Eigen::DenseIndex>
      52             : using PhiEigenTensor = EigenTensor<T, D, MajorType, IndexType>;
      53             : 
      54             : using Array1 = Eigen::DSizes<Eigen::DenseIndex, 1>;
      55             : using Array2 = Eigen::DSizes<Eigen::DenseIndex, 2>;
      56             : using Array3 = Eigen::DSizes<Eigen::DenseIndex, 3>;
      57             : 
      58             : template <typename T, typename Context>
      59           0 : void BaddbmmGradKernel(const Context& dev_ctx,
      60             :                        const DenseTensor& input,
      61             :                        const DenseTensor& x,
      62             :                        const DenseTensor& y,
      63             :                        const DenseTensor& out_grad,
      64             :                        float alpha,
      65             :                        float beta,
      66             :                        DenseTensor* input_grad,
      67             :                        DenseTensor* x_grad,
      68             :                        DenseTensor* y_grad) {
      69             :   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
      70           0 :   bool is_float16_or_bfloat16 = false;
      71             :   if (std::is_same<T, phi::dtype::float16>::value ||
      72             :       std::is_same<T, phi::dtype::bfloat16>::value) {
      73             :     is_float16_or_bfloat16 = true;
      74             :   }
      75             : 
      76           0 :   auto in_dims = input.dims();
      77           0 :   int total_elems = 0;
      78             : 
      79           0 :   VLOG(3) << "alpha: " << alpha << " beta: " << beta;
      80             : 
      81           0 :   if (input_grad != nullptr) {
      82           0 :     input_grad->set_lod(out_grad.lod());
      83             :   }
      84           0 :   if (x_grad != nullptr) {
      85           0 :     x_grad->set_lod(x.lod());
      86             :   }
      87           0 :   if (y_grad != nullptr) {
      88           0 :     y_grad->set_lod(y.lod());
      89             :   }
      90             : 
      91           0 :   auto blas = funcs::GetBlas<Context, T>(dev_ctx);
      92           0 :   auto mt_blas = funcs::GetBlas<Context, MPType>(dev_ctx);
      93           0 :   if (input_grad) {
      94           0 :     dev_ctx.template Alloc<T>(input_grad);
      95           0 :     total_elems = in_dims[0] * in_dims[1] * in_dims[2];
      96           0 :     auto& place = *dev_ctx.eigen_device();
      97           0 :     auto eigen_dout = PhiEigenTensor<T, 3>::From(out_grad);
      98           0 :     auto eigen_dinput = PhiEigenTensor<T, 3>::From(*input_grad);
      99             : 
     100           0 :     bool batch_compress = in_dims[0] != out_grad.dims()[0];
     101           0 :     bool row_compress = in_dims[1] != out_grad.dims()[1];
     102           0 :     bool col_compress = in_dims[2] != out_grad.dims()[2];
     103           0 :     auto eigen_dinput_shape = Array3(
     104           0 :         input_grad->dims()[0], input_grad->dims()[1], input_grad->dims()[2]);
     105             : 
     106           0 :     if (batch_compress && row_compress && col_compress) {
     107             :       if (!is_float16_or_bfloat16) {
     108           0 :         eigen_dinput.device(place) =
     109           0 :             eigen_dout.sum().eval().reshape(eigen_dinput_shape);
     110             :       } else {
     111             :         eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
     112             :                                          .sum()
     113             :                                          .eval()
     114             :                                          .reshape(eigen_dinput_shape)
     115             :                                          .template cast<T>();
     116             :       }
     117           0 :     } else if (batch_compress && row_compress) {
     118             :       if (!is_float16_or_bfloat16) {
     119           0 :         eigen_dinput.device(place) =
     120           0 :             eigen_dout.sum(Array2(0, 1)).eval().reshape(eigen_dinput_shape);
     121             :       } else {
     122             :         eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
     123             :                                          .sum(Array2(0, 1))
     124             :                                          .eval()
     125             :                                          .reshape(eigen_dinput_shape)
     126             :                                          .template cast<T>();
     127             :       }
     128           0 :     } else if (batch_compress && col_compress) {
     129             :       if (!is_float16_or_bfloat16) {
     130           0 :         eigen_dinput.device(place) =
     131           0 :             eigen_dout.sum(Array2(0, 2)).eval().reshape(eigen_dinput_shape);
     132             :       } else {
     133             :         eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
     134             :                                          .sum(Array2(0, 2))
     135             :                                          .eval()
     136             :                                          .reshape(eigen_dinput_shape)
     137             :                                          .template cast<T>();
     138             :       }
     139           0 :     } else if (row_compress && col_compress) {
     140             :       if (!is_float16_or_bfloat16) {
     141           0 :         eigen_dinput.device(place) =
     142           0 :             eigen_dout.sum(Array2(1, 2)).eval().reshape(eigen_dinput_shape);
     143             :       } else {
     144             :         eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
     145             :                                          .sum(Array2(1, 2))
     146             :                                          .eval()
     147             :                                          .reshape(eigen_dinput_shape)
     148             :                                          .template cast<T>();
     149             :       }
     150           0 :     } else if (batch_compress) {
     151             :       if (!is_float16_or_bfloat16) {
     152           0 :         eigen_dinput.device(place) =
     153           0 :             eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
     154             :       } else {
     155             :         eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
     156             :                                          .sum(Array1(0))
     157             :                                          .eval()
     158             :                                          .reshape(eigen_dinput_shape)
     159             :                                          .template cast<T>();
     160             :       }
     161           0 :     } else if (row_compress) {
     162             :       if (!is_float16_or_bfloat16) {
     163           0 :         eigen_dinput.device(place) =
     164           0 :             eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
     165             :       } else {
     166             :         eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
     167             :                                          .sum(Array1(1))
     168             :                                          .eval()
     169             :                                          .reshape(eigen_dinput_shape)
     170             :                                          .template cast<T>();
     171             :       }
     172           0 :     } else if (col_compress) {
     173             :       if (!is_float16_or_bfloat16) {
     174           0 :         eigen_dinput.device(place) =
     175           0 :             eigen_dout.sum(Array1(2)).eval().reshape(eigen_dinput_shape);
     176             :       } else {
     177             :         eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
     178             :                                          .sum(Array1(2))
     179             :                                          .eval()
     180             :                                          .reshape(eigen_dinput_shape)
     181             :                                          .template cast<T>();
     182             :       }
     183             :     } else {
     184             :       // The VCOPY does not support the float16, bfloat16
     185             :       if (!is_float16_or_bfloat16) {
     186           0 :         mt_blas.VCOPY(
     187             :             total_elems, out_grad.data<MPType>(), input_grad->data<MPType>());
     188             :       } else {
     189             :         phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
     190             :         BCopyOrScaleFunctor<T> functor(
     191             :             1, out_grad.data<T>(), input_grad->data<T>(), total_elems);
     192             :         for_range(functor);
     193             :       }
     194             :     }
     195             : 
     196             :     // The SCAL does not support the float16, bfloat16
     197             :     if (!is_float16_or_bfloat16) {
     198           0 :       mt_blas.SCAL(total_elems, beta, input_grad->data<MPType>());
     199             :     } else {
     200             :       phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
     201             :       BCopyOrScaleFunctor<T> functor(
     202             :           beta, input_grad->data<T>(), input_grad->data<T>(), total_elems);
     203             :       for_range(functor);
     204             :     }
     205             :   }
     206           0 :   if (x_grad) {
     207           0 :     dev_ctx.template Alloc<T>(x_grad);
     208           0 :     total_elems = x.dims()[0] * x.dims()[1] * x.dims()[2];
     209             :     // x_grad = out_grad * y'. x_grad: B x M x K, out_grad : B x M x N, y : B x
     210             :     // K x N
     211           0 :     for (int i = 0; i < x.dims()[0]; ++i) {
     212           0 :       auto out_grad_slice = out_grad.Slice(i, i + 1);
     213           0 :       auto y_slice = y.Slice(i, i + 1);
     214           0 :       auto x_grad_slice = x_grad->Slice(i, i + 1);
     215           0 :       auto x_grad_dims = x_grad_slice.dims();
     216             : 
     217           0 :       x_grad_slice.Resize({x_grad_dims[1], x_grad_dims[2]});
     218           0 :       y_slice.Resize({y_slice.dims()[1], y_slice.dims()[2]});
     219           0 :       out_grad_slice.Resize(
     220           0 :           {out_grad_slice.dims()[1], out_grad_slice.dims()[2]});
     221           0 :       blas.MatMul(out_grad_slice, false, y_slice, true, &x_grad_slice);
     222             :     }
     223             :     if (!is_float16_or_bfloat16) {
     224           0 :       mt_blas.SCAL(total_elems, alpha, x_grad->data<MPType>());
     225             :     } else {
     226             :       phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
     227             :       BCopyOrScaleFunctor<T> functor(
     228             :           alpha, x_grad->data<T>(), x_grad->data<T>(), total_elems);
     229             :       for_range(functor);
     230             :     }
     231             :   }
     232           0 :   if (y_grad) {
     233           0 :     dev_ctx.template Alloc<T>(y_grad);
     234           0 :     total_elems = y.dims()[0] * y.dims()[1] * y.dims()[2];
     235             :     // y_grad = x' * out_grad. y_grad: B x K x N, out_grad : B x M x N, x : B x
     236             :     // M x K
     237           0 :     for (int i = 0; i < x.dims()[0]; ++i) {
     238           0 :       auto out_grad_slice = out_grad.Slice(i, i + 1);
     239           0 :       auto x_slice = x.Slice(i, i + 1);
     240           0 :       auto y_grad_slice = y_grad->Slice(i, i + 1);
     241           0 :       out_grad_slice.Resize(
     242           0 :           {out_grad_slice.dims()[1], out_grad_slice.dims()[2]});
     243           0 :       x_slice.Resize({x_slice.dims()[1], x_slice.dims()[2]});
     244           0 :       y_grad_slice.Resize({y_grad_slice.dims()[1], y_grad_slice.dims()[2]});
     245           0 :       blas.MatMul(x_slice, true, out_grad_slice, false, &y_grad_slice);
     246             :     }
     247             :     if (!is_float16_or_bfloat16) {
     248           0 :       mt_blas.SCAL(total_elems, alpha, y_grad->data<MPType>());
     249             :     } else {
     250             :       phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
     251             :       BCopyOrScaleFunctor<T> functor(
     252             :           alpha, y_grad->data<T>(), y_grad->data<T>(), total_elems);
     253             :       for_range(functor);
     254             :     }
     255             :   }
     256           0 : }
     257             : 
     258             : }  // namespace phi

Generated by: LCOV version 1.16