Line data Source code
1 : /* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2 :
3 : Licensed under the Apache License, Version 2.0 (the "License");
4 : you may not use this file except in compliance with the License.
5 : You may obtain a copy of the License at
6 :
7 : http://www.apache.org/licenses/LICENSE-2.0
8 :
9 : Unless required by applicable law or agreed to in writing, software
10 : distributed under the License is distributed on an "AS IS" BASIS,
11 : WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : See the License for the specific language governing permissions and
13 : limitations under the License. */
14 :
15 : #pragma once
16 :
17 : #include <type_traits>
18 :
19 : #include "glog/logging.h"
20 :
21 : #include "paddle/phi/common/amp_type_traits.h"
22 : #include "paddle/phi/kernels/baddbmm_kernel.h"
23 : #include "paddle/phi/kernels/funcs/blas/blas.h"
24 : #include "paddle/phi/kernels/funcs/eigen/common.h"
25 : #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
26 :
27 : namespace phi {
28 :
// Rank-D Eigen tensor view over a DenseTensor (row-major by default);
// used below to run Eigen's broadcast kernel on phi tensors.
template <typename T,
          size_t D,
          int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using PhiEigenTensor = EigenTensor<T, D, MajorType, IndexType>;

// Fixed-rank Eigen shape/extent aliases (Array3 holds the broadcast factors).
using Array1 = Eigen::DSizes<Eigen::DenseIndex, 1>;
using Array2 = Eigen::DSizes<Eigen::DenseIndex, 2>;
using Array3 = Eigen::DSizes<Eigen::DenseIndex, 3>;
38 :
39 : template <typename T, typename Context>
40 0 : void BaddbmmKernel(const Context& dev_ctx,
41 : const DenseTensor& input,
42 : const DenseTensor& x,
43 : const DenseTensor& y,
44 : float beta,
45 : float alpha,
46 : DenseTensor* out) {
47 0 : auto input_dims = input.dims();
48 0 : auto x_dims = x.dims();
49 0 : auto y_dims = y.dims();
50 :
51 0 : DenseTensor input_3d(input);
52 0 : if (input.dims().size() == 2) {
53 0 : input_dims = {1, input.dims()[0], input.dims()[1]};
54 0 : input_3d.Resize(input_dims);
55 : }
56 :
57 : // broadcast mode check
58 0 : if (x_dims[0] != input_dims[0]) {
59 0 : PADDLE_ENFORCE_EQ(input_dims[0],
60 : 1,
61 : errors::InvalidArgument(
62 : "When x_dims[0] is not equal with input_dims[0], "
63 : "input_dims[0] must be 1 but got %s",
64 : input_dims[0]));
65 0 : PADDLE_ENFORCE_EQ(y_dims[2] == input_dims[2] || input_dims[2] == 1,
66 : true,
67 : errors::InvalidArgument(
68 : "The input tensor shape mismatch, input shape=[%s], "
69 : "x shape=[%s], y shape=[%s]",
70 : input_dims,
71 : x_dims,
72 : y_dims));
73 : }
74 0 : if (y_dims[2] != input_dims[2]) {
75 0 : PADDLE_ENFORCE_EQ(input_dims[2],
76 : 1,
77 : errors::InvalidArgument(
78 : "When y_dims[2] is not equal with input_dims[2], "
79 : "input_dims[2] must be 1 but got %s",
80 : input_dims[2]));
81 0 : PADDLE_ENFORCE_EQ(x_dims[0] == input_dims[0] || input_dims[0] == 1,
82 : true,
83 : errors::InvalidArgument(
84 : "The input tensor shape mismatch, input shape=[%s], "
85 : "x shape=[%s], y shape=[%s]",
86 : input_dims,
87 : x_dims,
88 : y_dims));
89 : }
90 0 : PADDLE_ENFORCE_EQ(
91 : x_dims[2],
92 : y_dims[1],
93 : errors::InvalidArgument(
94 : "The input tensor X's width must be equal with matrix Y' height. "
95 : "But received X's shape = [%s], Y's shape = [%s].",
96 : x_dims[2],
97 : y_dims[1]));
98 :
99 0 : dev_ctx.template Alloc<T>(out);
100 0 : auto blas = funcs::GetBlas<Context, T>(dev_ctx);
101 :
102 : // calc broadcast dim
103 0 : Array3 bcast_dims;
104 0 : bcast_dims[0] = x_dims[0] / input_dims[0];
105 0 : bcast_dims[1] = x_dims[1] / input_dims[1];
106 0 : bcast_dims[2] = y_dims[2] / input_dims[2];
107 0 : VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << ","
108 0 : << bcast_dims[2] << "]";
109 :
110 : // broadcast using eigen
111 0 : const DenseTensor& const_ref_input = input_3d;
112 0 : auto eigen_input = PhiEigenTensor<T, 3>::From(const_ref_input);
113 0 : auto eigen_out = PhiEigenTensor<T, 3>::From(*out);
114 0 : auto& place = *dev_ctx.eigen_device();
115 0 : funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 3>::Eval(
116 : place, eigen_out, eigen_input, bcast_dims);
117 :
118 : using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
119 :
120 : // special case for MPType
121 : if constexpr (std::is_same_v<MPType, float>) {
122 0 : VLOG(4) << "Function: baddbmm, Type of T: " << typeid(T).name();
123 0 : VLOG(4) << "Function: baddbmm, Type of MPType: " << typeid(MPType).name();
124 0 : float t_alpha = alpha;
125 0 : float t_beta = beta;
126 0 : if (x_dims[0] == 1) {
127 0 : blas.GEMM(CblasNoTrans,
128 : CblasNoTrans,
129 : x_dims[1],
130 : y_dims[2],
131 : x_dims[2],
132 : t_alpha,
133 : x.data<T>(),
134 : y.data<T>(),
135 : t_beta,
136 : out->data<T>());
137 : } else {
138 0 : blas.BatchedGEMM(CblasNoTrans,
139 : CblasNoTrans,
140 : x_dims[1],
141 : y_dims[2],
142 : x_dims[2],
143 : t_alpha,
144 : x.data<T>(),
145 : y.data<T>(),
146 : t_beta,
147 : out->data<T>(),
148 : x_dims[0],
149 0 : x_dims[1] * x_dims[2],
150 0 : x_dims[2] * y_dims[2]);
151 : }
152 : } else {
153 0 : T t_alpha = static_cast<T>(alpha);
154 0 : T t_beta = static_cast<T>(beta);
155 0 : if (x_dims[0] == 1) {
156 0 : blas.GEMM(CblasNoTrans,
157 : CblasNoTrans,
158 : x_dims[1],
159 : y_dims[2],
160 : x_dims[2],
161 : t_alpha,
162 : x.data<T>(),
163 : y.data<T>(),
164 : t_beta,
165 : out->data<T>());
166 : } else {
167 0 : blas.BatchedGEMM(CblasNoTrans,
168 : CblasNoTrans,
169 : x_dims[1],
170 : y_dims[2],
171 : x_dims[2],
172 : t_alpha,
173 : x.data<T>(),
174 : y.data<T>(),
175 : t_beta,
176 : out->data<T>(),
177 : x_dims[0],
178 0 : x_dims[1] * x_dims[2],
179 0 : x_dims[2] * y_dims[2]);
180 : // x_dims[2] == y_dims[1]
181 : }
182 : }
183 0 : }
184 :
185 : } // namespace phi
|