LCOV - code coverage report
Current view: top level - impl - baddbmm_kernel_impl.h Hit Total Coverage
Test: Diff Coverage Lines: 0 45 0.0 %
Date: 2025-01-16 04:37:07

          Line data    Source code
       1             : /* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
       2             : 
       3             : Licensed under the Apache License, Version 2.0 (the "License");
       4             : you may not use this file except in compliance with the License.
       5             : You may obtain a copy of the License at
       6             : 
       7             :     http://www.apache.org/licenses/LICENSE-2.0
       8             : 
       9             : Unless required by applicable law or agreed to in writing, software
      10             : distributed under the License is distributed on an "AS IS" BASIS,
      11             : WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      12             : See the License for the specific language governing permissions and
      13             : limitations under the License. */
      14             : 
      15             : #pragma once
      16             : 
#include <type_traits>
#include <typeinfo>

#include "glog/logging.h"

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/baddbmm_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
      26             : 
      27             : namespace phi {
      28             : 
      29             : template <typename T,
      30             :           size_t D,
      31             :           int MajorType = Eigen::RowMajor,
      32             :           typename IndexType = Eigen::DenseIndex>
      33             : using PhiEigenTensor = EigenTensor<T, D, MajorType, IndexType>;
      34             : 
      35             : using Array1 = Eigen::DSizes<Eigen::DenseIndex, 1>;
      36             : using Array2 = Eigen::DSizes<Eigen::DenseIndex, 2>;
      37             : using Array3 = Eigen::DSizes<Eigen::DenseIndex, 3>;
      38             : 
      39             : template <typename T, typename Context>
      40           0 : void BaddbmmKernel(const Context& dev_ctx,
      41             :                    const DenseTensor& input,
      42             :                    const DenseTensor& x,
      43             :                    const DenseTensor& y,
      44             :                    float beta,
      45             :                    float alpha,
      46             :                    DenseTensor* out) {
      47           0 :   auto input_dims = input.dims();
      48           0 :   auto x_dims = x.dims();
      49           0 :   auto y_dims = y.dims();
      50             : 
      51           0 :   DenseTensor input_3d(input);
      52           0 :   if (input.dims().size() == 2) {
      53           0 :     input_dims = {1, input.dims()[0], input.dims()[1]};
      54           0 :     input_3d.Resize(input_dims);
      55             :   }
      56             : 
      57             :   // broadcast mode check
      58           0 :   if (x_dims[0] != input_dims[0]) {
      59           0 :     PADDLE_ENFORCE_EQ(input_dims[0],
      60             :                       1,
      61             :                       errors::InvalidArgument(
      62             :                           "When x_dims[0] is not equal with input_dims[0], "
      63             :                           "input_dims[0] must be 1 but got %s",
      64             :                           input_dims[0]));
      65           0 :     PADDLE_ENFORCE_EQ(y_dims[2] == input_dims[2] || input_dims[2] == 1,
      66             :                       true,
      67             :                       errors::InvalidArgument(
      68             :                           "The input tensor shape mismatch, input shape=[%s], "
      69             :                           "x shape=[%s], y shape=[%s]",
      70             :                           input_dims,
      71             :                           x_dims,
      72             :                           y_dims));
      73             :   }
      74           0 :   if (y_dims[2] != input_dims[2]) {
      75           0 :     PADDLE_ENFORCE_EQ(input_dims[2],
      76             :                       1,
      77             :                       errors::InvalidArgument(
      78             :                           "When y_dims[2] is not equal with input_dims[2], "
      79             :                           "input_dims[2] must be 1 but got %s",
      80             :                           input_dims[2]));
      81           0 :     PADDLE_ENFORCE_EQ(x_dims[0] == input_dims[0] || input_dims[0] == 1,
      82             :                       true,
      83             :                       errors::InvalidArgument(
      84             :                           "The input tensor shape mismatch, input shape=[%s], "
      85             :                           "x shape=[%s], y shape=[%s]",
      86             :                           input_dims,
      87             :                           x_dims,
      88             :                           y_dims));
      89             :   }
      90           0 :   PADDLE_ENFORCE_EQ(
      91             :       x_dims[2],
      92             :       y_dims[1],
      93             :       errors::InvalidArgument(
      94             :           "The input tensor X's width must be equal with matrix Y' height. "
      95             :           "But received X's shape = [%s], Y's shape = [%s].",
      96             :           x_dims[2],
      97             :           y_dims[1]));
      98             : 
      99           0 :   dev_ctx.template Alloc<T>(out);
     100           0 :   auto blas = funcs::GetBlas<Context, T>(dev_ctx);
     101             : 
     102             :   // calc broadcast dim
     103           0 :   Array3 bcast_dims;
     104           0 :   bcast_dims[0] = x_dims[0] / input_dims[0];
     105           0 :   bcast_dims[1] = x_dims[1] / input_dims[1];
     106           0 :   bcast_dims[2] = y_dims[2] / input_dims[2];
     107           0 :   VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << ","
     108           0 :           << bcast_dims[2] << "]";
     109             : 
     110             :   // broadcast using eigen
     111           0 :   const DenseTensor& const_ref_input = input_3d;
     112           0 :   auto eigen_input = PhiEigenTensor<T, 3>::From(const_ref_input);
     113           0 :   auto eigen_out = PhiEigenTensor<T, 3>::From(*out);
     114           0 :   auto& place = *dev_ctx.eigen_device();
     115           0 :   funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 3>::Eval(
     116             :       place, eigen_out, eigen_input, bcast_dims);
     117             : 
     118             :   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
     119             : 
     120             :   // special case for MPType
     121             :   if constexpr (std::is_same_v<MPType, float>) {
     122           0 :     VLOG(4) << "Function: baddbmm, Type of T: " << typeid(T).name();
     123           0 :     VLOG(4) << "Function: baddbmm, Type of MPType: " << typeid(MPType).name();
     124           0 :     float t_alpha = alpha;
     125           0 :     float t_beta = beta;
     126           0 :     if (x_dims[0] == 1) {
     127           0 :       blas.GEMM(CblasNoTrans,
     128             :                 CblasNoTrans,
     129             :                 x_dims[1],
     130             :                 y_dims[2],
     131             :                 x_dims[2],
     132             :                 t_alpha,
     133             :                 x.data<T>(),
     134             :                 y.data<T>(),
     135             :                 t_beta,
     136             :                 out->data<T>());
     137             :     } else {
     138           0 :       blas.BatchedGEMM(CblasNoTrans,
     139             :                        CblasNoTrans,
     140             :                        x_dims[1],
     141             :                        y_dims[2],
     142             :                        x_dims[2],
     143             :                        t_alpha,
     144             :                        x.data<T>(),
     145             :                        y.data<T>(),
     146             :                        t_beta,
     147             :                        out->data<T>(),
     148             :                        x_dims[0],
     149           0 :                        x_dims[1] * x_dims[2],
     150           0 :                        x_dims[2] * y_dims[2]);
     151             :     }
     152             :   } else {
     153           0 :     T t_alpha = static_cast<T>(alpha);
     154           0 :     T t_beta = static_cast<T>(beta);
     155           0 :     if (x_dims[0] == 1) {
     156           0 :       blas.GEMM(CblasNoTrans,
     157             :                 CblasNoTrans,
     158             :                 x_dims[1],
     159             :                 y_dims[2],
     160             :                 x_dims[2],
     161             :                 t_alpha,
     162             :                 x.data<T>(),
     163             :                 y.data<T>(),
     164             :                 t_beta,
     165             :                 out->data<T>());
     166             :     } else {
     167           0 :       blas.BatchedGEMM(CblasNoTrans,
     168             :                        CblasNoTrans,
     169             :                        x_dims[1],
     170             :                        y_dims[2],
     171             :                        x_dims[2],
     172             :                        t_alpha,
     173             :                        x.data<T>(),
     174             :                        y.data<T>(),
     175             :                        t_beta,
     176             :                        out->data<T>(),
     177             :                        x_dims[0],
     178           0 :                        x_dims[1] * x_dims[2],
     179           0 :                        x_dims[2] * y_dims[2]);
     180             :       // x_dims[2] == y_dims[1]
     181             :     }
     182             :   }
     183           0 : }
     184             : 
     185             : }  // namespace phi

Generated by: LCOV version 1.16