LCOV - code coverage report
Current view: top level - cpu - cum_kernel.cc Hit Total Coverage
Test: Diff Coverage Lines: 1 3 33.3 %
Date: 2025-05-20 07:49:27

          Line data    Source code
       1             : // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
       2             : //
       3             : // Licensed under the Apache License, Version 2.0 (the "License");
       4             : // you may not use this file except in compliance with the License.
       5             : // You may obtain a copy of the License at
       6             : //
       7             : //     http://www.apache.org/licenses/LICENSE-2.0
       8             : //
       9             : // Unless required by applicable law or agreed to in writing, software
      10             : // distributed under the License is distributed on an "AS IS" BASIS,
      11             : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      12             : // See the License for the specific language governing permissions and
      13             : // limitations under the License.
      14             : 
      15             : #include "paddle/phi/kernels/cum_kernel.h"
      16             : 
      17             : #include "paddle/phi/backends/cpu/cpu_context.h"
      18             : #include "paddle/phi/core/kernel_registry.h"
      19             : #include "paddle/phi/kernels/funcs/eigen/common.h"
      20             : #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
      21             : 
      22             : namespace phi {
      23             : 
      24             : template <typename Device,
      25             :           typename Dim,
      26             :           typename X,
      27             :           typename Out,
      28             :           typename Reducer>
      29             : void ComputeImp(Device d,
      30             :                 const Dim& dims,
      31             :                 X x,
      32             :                 Out out,
      33             :                 int axis,
      34             :                 bool reverse,
      35             :                 bool exclusive,
      36             :                 Reducer reducer) {
      37             :   if (!reverse) {
      38             :     out.reshape(dims).device(d) =
      39             :         x.reshape(dims).scan(axis, reducer, exclusive);
      40             :   } else {
      41             :     std::array<bool, Dim::count> rev;
      42             :     rev.fill(false);
      43             :     rev[axis] = reverse;
      44             :     out.reshape(dims).device(d) = x.reshape(dims)
      45             :                                       .reverse(rev)
      46             :                                       .scan(axis, reducer, exclusive)
      47             :                                       .reverse(rev);
      48             :   }
      49             : }
      50             : 
      51             : template <typename T, typename Context, typename Reducer>
      52             : void ScanKernel(const Context& dev_ctx,
      53             :                 const DenseTensor& x,
      54             :                 int axis,
      55             :                 bool flatten UNUSED,
      56             :                 bool exclusive,
      57             :                 bool reverse,
      58             :                 Reducer reducer,
      59             :                 DenseTensor* out) {
      60          17 :   if (out && out->numel() == 0) {
      61           0 :     dev_ctx.template Alloc<T>(out);
      62           0 :     return;
      63             :   }
      64             :   dev_ctx.template Alloc<T>(out);
      65             : 
      66             :   if (x.numel() == 1) {
      67             :     auto raw_dims = out->dims();
      68             :     phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
      69             :     out->Resize(raw_dims);
      70             :     return;
      71             :   }
      72             :   auto out_dims = out->dims();
      73             : 
      74             :   PADDLE_ENFORCE_EQ(
      75             :       axis < out_dims.size() && axis >= (0 - out_dims.size()),
      76             :       true,
      77             :       common::errors::OutOfRange(
      78             :           "Attr(axis) is out of range, It's expected "
      79             :           "to be in range of [-%d, %d]. But received Attr(axis) = %d.",
      80             :           out_dims.size(),
      81             :           out_dims.size() - 1,
      82             :           axis));
      83             :   if (axis < 0) {
      84             :     axis += out_dims.size();
      85             :   }
      86             : 
      87             :   int pre = 1;
      88             :   int post = 1;
      89             :   int mid = static_cast<int>(out_dims[axis]);
      90             :   for (int i = 0; i < axis; ++i) {
      91             :     pre *= static_cast<int>(out_dims[i]);
      92             :   }
      93             :   for (int i = axis + 1; i < out_dims.size(); ++i) {
      94             :     post *= static_cast<int>(out_dims[i]);
      95             :   }
      96             : 
      97             :   auto x0 = EigenVector<T>::Flatten(x);
      98             :   auto out0 = EigenVector<T>::Flatten(*out);
      99             :   auto& place = *dev_ctx.eigen_device();
     100             : 
     101             :   using IndexT = Eigen::DenseIndex;
     102             :   if (pre == 1) {
     103             :     if (post == 1) {
     104             :       ComputeImp(place,
     105             :                  Eigen::DSizes<IndexT, 1>(mid),
     106             :                  x0,
     107             :                  out0,
     108             :                  /* axis= */ 0,
     109             :                  reverse,
     110             :                  exclusive,
     111             :                  reducer);
     112             :     } else {
     113             :       ComputeImp(place,
     114             :                  Eigen::DSizes<IndexT, 2>(mid, post),
     115             :                  x0,
     116             :                  out0,
     117             :                  /* axis= */ 0,
     118             :                  reverse,
     119             :                  exclusive,
     120             :                  reducer);
     121             :     }
     122             :   } else {
     123             :     if (post == 1) {
     124             :       ComputeImp(place,
     125             :                  Eigen::DSizes<IndexT, 2>(pre, mid),
     126             :                  x0,
     127             :                  out0,
     128             :                  /* axis= */ 1,
     129             :                  reverse,
     130             :                  exclusive,
     131             :                  reducer);
     132             :     } else {
     133             :       ComputeImp(place,
     134             :                  Eigen::DSizes<IndexT, 3>(pre, mid, post),
     135             :                  x0,
     136             :                  out0,
     137             :                  /* axis= */ 1,
     138             :                  reverse,
     139             :                  exclusive,
     140             :                  reducer);
     141             :     }
     142             :   }
     143             : }
     144             : 
     145             : template <typename T, typename Context>
     146             : void CumsumKernel(const Context& dev_ctx,
     147             :                   const DenseTensor& x,
     148             :                   const Scalar& axis,
     149             :                   bool flatten,
     150             :                   bool exclusive,
     151             :                   bool reverse,
     152             :                   DenseTensor* out) {
     153             :   using Reducer = Eigen::internal::SumReducer<T>;
     154             :   auto reducer = Reducer();
     155             :   ScanKernel<T, Context, Reducer>(
     156             :       dev_ctx, x, axis.to<int>(), flatten, exclusive, reverse, reducer, out);
     157             : }
     158             : 
     159             : template <typename T>
     160             : struct LogSumExp {
     161             :   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
     162             :                                                      const T& b) const {
     163             :     auto mi = Eigen::internal::scalar_min_op<T>()(a, b);
     164             :     auto ma = Eigen::internal::scalar_max_op<T>()(a, b);
     165             : 
     166             :     auto sub = Eigen::internal::scalar_difference_op<T>();
     167             :     auto add = Eigen::internal::scalar_sum_op<T>();
     168             :     auto exp = Eigen::internal::scalar_exp_op<T>();
     169             :     auto log1p = Eigen::internal::scalar_log1p_op<T>();
     170             :     auto cmp_lt =
     171             :         Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
     172             : 
     173             :     auto logsumexp = add(log1p(exp(sub(mi, ma))), ma);
     174             :     return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? ma : logsumexp;
     175             :   }
     176             :   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a,
     177             :                                                    const T& b) const {
     178             :     auto mi = Eigen::internal::pmin(a, b);
     179             :     auto ma = Eigen::internal::pmax(a, b);
     180             :     using Eigen::internal::padd;
     181             :     using Eigen::internal::pcmp_lt;
     182             :     using Eigen::internal::pexp;
     183             :     using Eigen::internal::plog1p;
     184             :     using Eigen::internal::pset1;
     185             :     using Eigen::internal::psub;
     186             : 
     187             :     auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma);
     188             :     return pselect(
     189             :         pcmp_lt(ma, pset1(Eigen::NumTraits<T>::lowest())), ma, logsumexp);
     190             :   }
     191             : };
     192             : 
     193             : template <typename T>
     194             : struct LogSumExpReducer {
     195             :   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
     196             :     LogSumExp<T> logsumexp;
     197             :     *accum = logsumexp(*accum, t);
     198             :   }
     199             : 
     200             :   template <typename Packet>
     201             :   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p,
     202             :                                                           Packet* accum) const {
     203             :     LogSumExp<T> logsumexp;
     204             :     *accum = logsumexp.packetOp(*accum, p);
     205             :   }
     206             : 
     207             :   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
     208             :     return Eigen::NumTraits<T>::lowest();
     209             :   }
     210             : 
     211             :   template <typename Packet>
     212             :   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
     213             :     return Eigen::internal::pset1(initialize());
     214             :   }
     215             : 
     216             :   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
     217             :     return accum;
     218             :   }
     219             : 
     220             :   template <typename Packet>
     221             :   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
     222             :   finalizePacket(const Packet& vaccum) const {
     223             :     return vaccum;
     224             :   }
     225             : 
     226             :   template <typename Packet>
     227             :   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T
     228             :   finalizeBoth(const T saccum, const Packet& vaccum) const {
     229             :     auto max_reducer = Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>();
     230             :     auto sum_reducer = Eigen::internal::SumReducer<T>();
     231             :     auto exp = Eigen::internal::scalar_exp_op<T>();
     232             :     auto cmp_lt =
     233             :         Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
     234             :     auto log = Eigen::internal::scalar_log_op<T>();
     235             :     auto add = Eigen::internal::scalar_sum_op<T>();
     236             : 
     237             :     using Eigen::internal::pexp;
     238             :     using Eigen::internal::psub;
     239             : 
     240             :     // `ma = max(x1, ..., xn)`
     241             :     // If the max of all of the `xi` is `-infinity` then the result is
     242             :     // -infinity. If the max is larger than `-infinity` then it's safe to use
     243             :     // for normalization even if the other elements are `-infinity`.
     244             :     //
     245             :     // `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))`
     246             :     auto ma = max_reducer.finalizeBoth(saccum, vaccum);
     247             :     auto logsumexp = add(log(sum_reducer.finalizeBoth(
     248             :                              exp(saccum - ma), pexp(psub(vaccum, pset1(ma))))),
     249             :                          ma);
     250             :     return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? initialize() : logsumexp;
     251             :   }
     252             : };
     253             : 
     254             : template <typename T, typename Context>
     255             : void LogcumsumexpKernel(const Context& dev_ctx,
     256             :                         const DenseTensor& x,
     257             :                         int axis,
     258             :                         bool flatten,
     259             :                         bool exclusive,
     260             :                         bool reverse,
     261             :                         DenseTensor* out) {
     262             :   using Reducer = LogSumExpReducer<T>;
     263             :   auto reducer = Reducer();
     264             :   ScanKernel<T, Context, Reducer>(
     265             :       dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
     266             : }
     267             : 
     268             : }  // namespace phi
     269             : 
     270             : PD_REGISTER_KERNEL(cumsum,
     271             :                    CPU,
     272             :                    ALL_LAYOUT,
     273             :                    phi::CumsumKernel,
     274             :                    float,
     275             :                    double,
     276             :                    int16_t,
     277             :                    int,
     278             :                    int64_t,
     279             :                    phi::dtype::complex<float>,
     280             :                    phi::dtype::complex<double>) {}
     281             : 
     282             : PD_REGISTER_KERNEL(
     283             :     logcumsumexp, CPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {}

Generated by: LCOV version 1.16