Line data Source code
1 : /* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2 :
3 : Licensed under the Apache License, Version 2.0 (the "License");
4 : you may not use this file except in compliance with the License.
5 : You may obtain a copy of the License at
6 :
7 : http://www.apache.org/licenses/LICENSE-2.0
8 :
9 : Unless required by applicable law or agreed to in writing, software
10 : distributed under the License is distributed on an "AS IS" BASIS,
11 : WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : See the License for the specific language governing permissions and
13 : limitations under the License. */
14 : #pragma once
15 :
16 : #include <type_traits>
17 :
18 : #include "glog/logging.h"
19 :
20 : #include "paddle/phi/common/amp_type_traits.h"
21 : #include "paddle/phi/kernels/baddbmm_grad_kernel.h"
22 : #include "paddle/phi/kernels/funcs/blas/blas.h"
23 : #include "paddle/phi/kernels/funcs/eigen/common.h"
24 : #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
25 : #include "paddle/phi/kernels/funcs/for_range.h"
26 :
27 : namespace phi {
28 :
29 : template <typename T>
30 : struct BCopyOrScaleFunctor {
31 : BCopyOrScaleFunctor(const float scale, const T* x, T* output, int64_t numel)
32 : : scale_(scale), x_(x), output_(output), numel_(numel) {}
33 :
34 : HOSTDEVICE void operator()(int64_t idx) const {
35 : using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
36 : const MPType mp_scale = static_cast<MPType>(scale_);
37 : const MPType mp_x = static_cast<MPType>(x_[idx]);
38 : output_[idx] = static_cast<T>(mp_scale * mp_x);
39 : }
40 :
41 : private:
42 : const float scale_;
43 : const T* x_;
44 : T* output_;
45 : int64_t numel_;
46 : };
47 :
48 : template <typename T,
49 : size_t D,
50 : int MajorType = Eigen::RowMajor,
51 : typename IndexType = Eigen::DenseIndex>
52 : using PhiEigenTensor = EigenTensor<T, D, MajorType, IndexType>;
53 :
54 : using Array1 = Eigen::DSizes<Eigen::DenseIndex, 1>;
55 : using Array2 = Eigen::DSizes<Eigen::DenseIndex, 2>;
56 : using Array3 = Eigen::DSizes<Eigen::DenseIndex, 3>;
57 :
58 : template <typename T, typename Context>
59 0 : void BaddbmmGradKernel(const Context& dev_ctx,
60 : const DenseTensor& input,
61 : const DenseTensor& x,
62 : const DenseTensor& y,
63 : const DenseTensor& out_grad,
64 : float alpha,
65 : float beta,
66 : DenseTensor* input_grad,
67 : DenseTensor* x_grad,
68 : DenseTensor* y_grad) {
69 : using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
70 0 : bool is_float16_or_bfloat16 = false;
71 : if (std::is_same<T, phi::dtype::float16>::value ||
72 : std::is_same<T, phi::dtype::bfloat16>::value) {
73 : is_float16_or_bfloat16 = true;
74 : }
75 :
76 0 : auto in_dims = input.dims();
77 0 : int total_elems = 0;
78 :
79 0 : VLOG(3) << "alpha: " << alpha << " beta: " << beta;
80 :
81 0 : if (input_grad != nullptr) {
82 0 : input_grad->set_lod(out_grad.lod());
83 : }
84 0 : if (x_grad != nullptr) {
85 0 : x_grad->set_lod(x.lod());
86 : }
87 0 : if (y_grad != nullptr) {
88 0 : y_grad->set_lod(y.lod());
89 : }
90 :
91 0 : auto blas = funcs::GetBlas<Context, T>(dev_ctx);
92 0 : auto mt_blas = funcs::GetBlas<Context, MPType>(dev_ctx);
93 0 : if (input_grad) {
94 0 : dev_ctx.template Alloc<T>(input_grad);
95 0 : total_elems = in_dims[0] * in_dims[1] * in_dims[2];
96 0 : auto& place = *dev_ctx.eigen_device();
97 0 : auto eigen_dout = PhiEigenTensor<T, 3>::From(out_grad);
98 0 : auto eigen_dinput = PhiEigenTensor<T, 3>::From(*input_grad);
99 :
100 0 : bool batch_compress = in_dims[0] != out_grad.dims()[0];
101 0 : bool row_compress = in_dims[1] != out_grad.dims()[1];
102 0 : bool col_compress = in_dims[2] != out_grad.dims()[2];
103 0 : auto eigen_dinput_shape = Array3(
104 0 : input_grad->dims()[0], input_grad->dims()[1], input_grad->dims()[2]);
105 :
106 0 : if (batch_compress && row_compress && col_compress) {
107 : if (!is_float16_or_bfloat16) {
108 0 : eigen_dinput.device(place) =
109 0 : eigen_dout.sum().eval().reshape(eigen_dinput_shape);
110 : } else {
111 : eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
112 : .sum()
113 : .eval()
114 : .reshape(eigen_dinput_shape)
115 : .template cast<T>();
116 : }
117 0 : } else if (batch_compress && row_compress) {
118 : if (!is_float16_or_bfloat16) {
119 0 : eigen_dinput.device(place) =
120 0 : eigen_dout.sum(Array2(0, 1)).eval().reshape(eigen_dinput_shape);
121 : } else {
122 : eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
123 : .sum(Array2(0, 1))
124 : .eval()
125 : .reshape(eigen_dinput_shape)
126 : .template cast<T>();
127 : }
128 0 : } else if (batch_compress && col_compress) {
129 : if (!is_float16_or_bfloat16) {
130 0 : eigen_dinput.device(place) =
131 0 : eigen_dout.sum(Array2(0, 2)).eval().reshape(eigen_dinput_shape);
132 : } else {
133 : eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
134 : .sum(Array2(0, 2))
135 : .eval()
136 : .reshape(eigen_dinput_shape)
137 : .template cast<T>();
138 : }
139 0 : } else if (row_compress && col_compress) {
140 : if (!is_float16_or_bfloat16) {
141 0 : eigen_dinput.device(place) =
142 0 : eigen_dout.sum(Array2(1, 2)).eval().reshape(eigen_dinput_shape);
143 : } else {
144 : eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
145 : .sum(Array2(1, 2))
146 : .eval()
147 : .reshape(eigen_dinput_shape)
148 : .template cast<T>();
149 : }
150 0 : } else if (batch_compress) {
151 : if (!is_float16_or_bfloat16) {
152 0 : eigen_dinput.device(place) =
153 0 : eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape);
154 : } else {
155 : eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
156 : .sum(Array1(0))
157 : .eval()
158 : .reshape(eigen_dinput_shape)
159 : .template cast<T>();
160 : }
161 0 : } else if (row_compress) {
162 : if (!is_float16_or_bfloat16) {
163 0 : eigen_dinput.device(place) =
164 0 : eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape);
165 : } else {
166 : eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
167 : .sum(Array1(1))
168 : .eval()
169 : .reshape(eigen_dinput_shape)
170 : .template cast<T>();
171 : }
172 0 : } else if (col_compress) {
173 : if (!is_float16_or_bfloat16) {
174 0 : eigen_dinput.device(place) =
175 0 : eigen_dout.sum(Array1(2)).eval().reshape(eigen_dinput_shape);
176 : } else {
177 : eigen_dinput.device(place) = eigen_dout.template cast<MPType>()
178 : .sum(Array1(2))
179 : .eval()
180 : .reshape(eigen_dinput_shape)
181 : .template cast<T>();
182 : }
183 : } else {
184 : // The VCOPY does not support the float16, bfloat16
185 : if (!is_float16_or_bfloat16) {
186 0 : mt_blas.VCOPY(
187 : total_elems, out_grad.data<MPType>(), input_grad->data<MPType>());
188 : } else {
189 : phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
190 : BCopyOrScaleFunctor<T> functor(
191 : 1, out_grad.data<T>(), input_grad->data<T>(), total_elems);
192 : for_range(functor);
193 : }
194 : }
195 :
196 : // The SCAL does not support the float16, bfloat16
197 : if (!is_float16_or_bfloat16) {
198 0 : mt_blas.SCAL(total_elems, beta, input_grad->data<MPType>());
199 : } else {
200 : phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
201 : BCopyOrScaleFunctor<T> functor(
202 : beta, input_grad->data<T>(), input_grad->data<T>(), total_elems);
203 : for_range(functor);
204 : }
205 : }
206 0 : if (x_grad) {
207 0 : dev_ctx.template Alloc<T>(x_grad);
208 0 : total_elems = x.dims()[0] * x.dims()[1] * x.dims()[2];
209 : // x_grad = out_grad * y'. x_grad: B x M x K, out_grad : B x M x N, y : B x
210 : // K x N
211 0 : for (int i = 0; i < x.dims()[0]; ++i) {
212 0 : auto out_grad_slice = out_grad.Slice(i, i + 1);
213 0 : auto y_slice = y.Slice(i, i + 1);
214 0 : auto x_grad_slice = x_grad->Slice(i, i + 1);
215 0 : auto x_grad_dims = x_grad_slice.dims();
216 :
217 0 : x_grad_slice.Resize({x_grad_dims[1], x_grad_dims[2]});
218 0 : y_slice.Resize({y_slice.dims()[1], y_slice.dims()[2]});
219 0 : out_grad_slice.Resize(
220 0 : {out_grad_slice.dims()[1], out_grad_slice.dims()[2]});
221 0 : blas.MatMul(out_grad_slice, false, y_slice, true, &x_grad_slice);
222 : }
223 : if (!is_float16_or_bfloat16) {
224 0 : mt_blas.SCAL(total_elems, alpha, x_grad->data<MPType>());
225 : } else {
226 : phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
227 : BCopyOrScaleFunctor<T> functor(
228 : alpha, x_grad->data<T>(), x_grad->data<T>(), total_elems);
229 : for_range(functor);
230 : }
231 : }
232 0 : if (y_grad) {
233 0 : dev_ctx.template Alloc<T>(y_grad);
234 0 : total_elems = y.dims()[0] * y.dims()[1] * y.dims()[2];
235 : // y_grad = x' * out_grad. y_grad: B x K x N, out_grad : B x M x N, x : B x
236 : // M x K
237 0 : for (int i = 0; i < x.dims()[0]; ++i) {
238 0 : auto out_grad_slice = out_grad.Slice(i, i + 1);
239 0 : auto x_slice = x.Slice(i, i + 1);
240 0 : auto y_grad_slice = y_grad->Slice(i, i + 1);
241 0 : out_grad_slice.Resize(
242 0 : {out_grad_slice.dims()[1], out_grad_slice.dims()[2]});
243 0 : x_slice.Resize({x_slice.dims()[1], x_slice.dims()[2]});
244 0 : y_grad_slice.Resize({y_grad_slice.dims()[1], y_grad_slice.dims()[2]});
245 0 : blas.MatMul(x_slice, true, out_grad_slice, false, &y_grad_slice);
246 : }
247 : if (!is_float16_or_bfloat16) {
248 0 : mt_blas.SCAL(total_elems, alpha, y_grad->data<MPType>());
249 : } else {
250 : phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
251 : BCopyOrScaleFunctor<T> functor(
252 : alpha, y_grad->data<T>(), y_grad->data<T>(), total_elems);
253 : for_range(functor);
254 : }
255 : }
256 0 : }
257 :
258 : } // namespace phi
|