Line data Source code
1 : // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 : //
3 : // Licensed under the Apache License, Version 2.0 (the "License");
4 : // you may not use this file except in compliance with the License.
5 : // You may obtain a copy of the License at
6 : //
7 : // http://www.apache.org/licenses/LICENSE-2.0
8 : //
9 : // Unless required by applicable law or agreed to in writing, software
10 : // distributed under the License is distributed on an "AS IS" BASIS,
11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : // See the License for the specific language governing permissions and
13 : // limitations under the License.
14 :
15 : #include "paddle/phi/kernels/cum_kernel.h"
16 :
17 : #include "paddle/phi/backends/cpu/cpu_context.h"
18 : #include "paddle/phi/core/kernel_registry.h"
19 : #include "paddle/phi/kernels/funcs/eigen/common.h"
20 : #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
21 :
22 : namespace phi {
23 :
24 : template <typename Device,
25 : typename Dim,
26 : typename X,
27 : typename Out,
28 : typename Reducer>
29 : void ComputeImp(Device d,
30 : const Dim& dims,
31 : X x,
32 : Out out,
33 : int axis,
34 : bool reverse,
35 : bool exclusive,
36 : Reducer reducer) {
37 : if (!reverse) {
38 : out.reshape(dims).device(d) =
39 : x.reshape(dims).scan(axis, reducer, exclusive);
40 : } else {
41 : std::array<bool, Dim::count> rev;
42 : rev.fill(false);
43 : rev[axis] = reverse;
44 : out.reshape(dims).device(d) = x.reshape(dims)
45 : .reverse(rev)
46 : .scan(axis, reducer, exclusive)
47 : .reverse(rev);
48 : }
49 : }
50 :
51 : template <typename T, typename Context, typename Reducer>
52 : void ScanKernel(const Context& dev_ctx,
53 : const DenseTensor& x,
54 : int axis,
55 : bool flatten UNUSED,
56 : bool exclusive,
57 : bool reverse,
58 : Reducer reducer,
59 : DenseTensor* out) {
60 17 : if (out && out->numel() == 0) {
61 0 : dev_ctx.template Alloc<T>(out);
62 0 : return;
63 : }
64 : dev_ctx.template Alloc<T>(out);
65 :
66 : if (x.numel() == 1) {
67 : auto raw_dims = out->dims();
68 : phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
69 : out->Resize(raw_dims);
70 : return;
71 : }
72 : auto out_dims = out->dims();
73 :
74 : PADDLE_ENFORCE_EQ(
75 : axis < out_dims.size() && axis >= (0 - out_dims.size()),
76 : true,
77 : common::errors::OutOfRange(
78 : "Attr(axis) is out of range, It's expected "
79 : "to be in range of [-%d, %d]. But received Attr(axis) = %d.",
80 : out_dims.size(),
81 : out_dims.size() - 1,
82 : axis));
83 : if (axis < 0) {
84 : axis += out_dims.size();
85 : }
86 :
87 : int pre = 1;
88 : int post = 1;
89 : int mid = static_cast<int>(out_dims[axis]);
90 : for (int i = 0; i < axis; ++i) {
91 : pre *= static_cast<int>(out_dims[i]);
92 : }
93 : for (int i = axis + 1; i < out_dims.size(); ++i) {
94 : post *= static_cast<int>(out_dims[i]);
95 : }
96 :
97 : auto x0 = EigenVector<T>::Flatten(x);
98 : auto out0 = EigenVector<T>::Flatten(*out);
99 : auto& place = *dev_ctx.eigen_device();
100 :
101 : using IndexT = Eigen::DenseIndex;
102 : if (pre == 1) {
103 : if (post == 1) {
104 : ComputeImp(place,
105 : Eigen::DSizes<IndexT, 1>(mid),
106 : x0,
107 : out0,
108 : /* axis= */ 0,
109 : reverse,
110 : exclusive,
111 : reducer);
112 : } else {
113 : ComputeImp(place,
114 : Eigen::DSizes<IndexT, 2>(mid, post),
115 : x0,
116 : out0,
117 : /* axis= */ 0,
118 : reverse,
119 : exclusive,
120 : reducer);
121 : }
122 : } else {
123 : if (post == 1) {
124 : ComputeImp(place,
125 : Eigen::DSizes<IndexT, 2>(pre, mid),
126 : x0,
127 : out0,
128 : /* axis= */ 1,
129 : reverse,
130 : exclusive,
131 : reducer);
132 : } else {
133 : ComputeImp(place,
134 : Eigen::DSizes<IndexT, 3>(pre, mid, post),
135 : x0,
136 : out0,
137 : /* axis= */ 1,
138 : reverse,
139 : exclusive,
140 : reducer);
141 : }
142 : }
143 : }
144 :
145 : template <typename T, typename Context>
146 : void CumsumKernel(const Context& dev_ctx,
147 : const DenseTensor& x,
148 : const Scalar& axis,
149 : bool flatten,
150 : bool exclusive,
151 : bool reverse,
152 : DenseTensor* out) {
153 : using Reducer = Eigen::internal::SumReducer<T>;
154 : auto reducer = Reducer();
155 : ScanKernel<T, Context, Reducer>(
156 : dev_ctx, x, axis.to<int>(), flatten, exclusive, reverse, reducer, out);
157 : }
158 :
159 : template <typename T>
160 : struct LogSumExp {
161 : EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
162 : const T& b) const {
163 : auto mi = Eigen::internal::scalar_min_op<T>()(a, b);
164 : auto ma = Eigen::internal::scalar_max_op<T>()(a, b);
165 :
166 : auto sub = Eigen::internal::scalar_difference_op<T>();
167 : auto add = Eigen::internal::scalar_sum_op<T>();
168 : auto exp = Eigen::internal::scalar_exp_op<T>();
169 : auto log1p = Eigen::internal::scalar_log1p_op<T>();
170 : auto cmp_lt =
171 : Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
172 :
173 : auto logsumexp = add(log1p(exp(sub(mi, ma))), ma);
174 : return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? ma : logsumexp;
175 : }
176 : EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a,
177 : const T& b) const {
178 : auto mi = Eigen::internal::pmin(a, b);
179 : auto ma = Eigen::internal::pmax(a, b);
180 : using Eigen::internal::padd;
181 : using Eigen::internal::pcmp_lt;
182 : using Eigen::internal::pexp;
183 : using Eigen::internal::plog1p;
184 : using Eigen::internal::pset1;
185 : using Eigen::internal::psub;
186 :
187 : auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma);
188 : return pselect(
189 : pcmp_lt(ma, pset1(Eigen::NumTraits<T>::lowest())), ma, logsumexp);
190 : }
191 : };
192 :
193 : template <typename T>
194 : struct LogSumExpReducer {
195 : EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
196 : LogSumExp<T> logsumexp;
197 : *accum = logsumexp(*accum, t);
198 : }
199 :
200 : template <typename Packet>
201 : EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p,
202 : Packet* accum) const {
203 : LogSumExp<T> logsumexp;
204 : *accum = logsumexp.packetOp(*accum, p);
205 : }
206 :
207 : EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
208 : return Eigen::NumTraits<T>::lowest();
209 : }
210 :
211 : template <typename Packet>
212 : EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
213 : return Eigen::internal::pset1(initialize());
214 : }
215 :
216 : EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
217 : return accum;
218 : }
219 :
220 : template <typename Packet>
221 : EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
222 : finalizePacket(const Packet& vaccum) const {
223 : return vaccum;
224 : }
225 :
226 : template <typename Packet>
227 : EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T
228 : finalizeBoth(const T saccum, const Packet& vaccum) const {
229 : auto max_reducer = Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>();
230 : auto sum_reducer = Eigen::internal::SumReducer<T>();
231 : auto exp = Eigen::internal::scalar_exp_op<T>();
232 : auto cmp_lt =
233 : Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
234 : auto log = Eigen::internal::scalar_log_op<T>();
235 : auto add = Eigen::internal::scalar_sum_op<T>();
236 :
237 : using Eigen::internal::pexp;
238 : using Eigen::internal::psub;
239 :
240 : // `ma = max(x1, ..., xn)`
241 : // If the max of all of the `xi` is `-infinity` then the result is
242 : // -infinity. If the max is larger than `-infinity` then it's safe to use
243 : // for normalization even if the other elements are `-infinity`.
244 : //
245 : // `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))`
246 : auto ma = max_reducer.finalizeBoth(saccum, vaccum);
247 : auto logsumexp = add(log(sum_reducer.finalizeBoth(
248 : exp(saccum - ma), pexp(psub(vaccum, pset1(ma))))),
249 : ma);
250 : return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? initialize() : logsumexp;
251 : }
252 : };
253 :
254 : template <typename T, typename Context>
255 : void LogcumsumexpKernel(const Context& dev_ctx,
256 : const DenseTensor& x,
257 : int axis,
258 : bool flatten,
259 : bool exclusive,
260 : bool reverse,
261 : DenseTensor* out) {
262 : using Reducer = LogSumExpReducer<T>;
263 : auto reducer = Reducer();
264 : ScanKernel<T, Context, Reducer>(
265 : dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
266 : }
267 :
268 : } // namespace phi
269 :
270 : PD_REGISTER_KERNEL(cumsum,
271 : CPU,
272 : ALL_LAYOUT,
273 : phi::CumsumKernel,
274 : float,
275 : double,
276 : int16_t,
277 : int,
278 : int64_t,
279 : phi::dtype::complex<float>,
280 : phi::dtype::complex<double>) {}
281 :
282 : PD_REGISTER_KERNEL(
283 : logcumsumexp, CPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {}
|