Skip to content

Commit 205dc1a

Browse files
committed
feat(llm): 注册 rms normalization
Signed-off-by: YdrMaster <[email protected]>
1 parent 2ae5902 commit 205dc1a

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

src/07onnx/src/operators.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@
4242
namespace refactor::onnx {
4343

4444
void register_() {
45+
#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("onnx::" #NAME)
4546
// clang-format off
46-
#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("onnx::" #NAME)
4747
REGISTER(BatchNormalization , BatchNormalization );
4848
REGISTER(Cast , Cast );
4949
REGISTER(Clip , Clip );
@@ -130,8 +130,8 @@ namespace refactor::onnx {
130130
REGISTER(Where , Where );
131131
REGISTER(HardSigmoid , HardSigmoid );
132132
REGISTER(Pad , Pad );
133-
#undef REGISTER
134133
// clang-format on
134+
#undef REGISTER
135135
}
136136

137137
}// namespace refactor::onnx

src/08-01llm/src/operators.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
#include "llm/operators.h"
22
#include "operators/mat_mul.hh"
3+
#include "operators/rms_normalization.hh"
34

45
namespace refactor::llm {
56
using namespace frontend;
67

78
void register_() {
9+
#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("llm::" #NAME)
810
// clang-format off
9-
#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("llm::" #NAME)
10-
REGISTER(MatMul, MatMul);
11-
#undef REGISTER
11+
REGISTER(MatMul , MatMul );
12+
REGISTER(RmsNormalization, RmsNormalization);
1213
// clang-format on
14+
#undef REGISTER
1315
}
1416

1517
}// namespace refactor::llm

src/08communication/src/operators.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@ namespace refactor::communication {
66
using namespace frontend;
77

88
void register_() {
9+
#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("onnx::" #NAME)
910
// clang-format off
10-
#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("onnx::" #NAME)
1111
REGISTER(AllReduceAvg , AllReduce);
1212
REGISTER(AllReduceSum , AllReduce);
1313
REGISTER(AllReduceMin , AllReduce);
1414
REGISTER(AllReduceMax , AllReduce);
1515
REGISTER(AllReduceProd, AllReduce);
1616
REGISTER(AllGather , AllGather);
17-
#undef REGISTER
1817
// clang-format on
18+
#undef REGISTER
1919
}
2020

2121
}// namespace refactor::communication

0 commit comments

Comments
 (0)