7 files changed: +141 -1 lines changed

File tree: include/kernel/collectors, include/computation/operators

New file: include/kernel/collectors/attention.h
+#ifndef KERNEL_ATTENTION_H
+#define KERNEL_ATTENTION_H
+
+#include "../collector.h"
+
+namespace refactor::kernel {
+
+    struct AttentionCollector final : public InfoCollector {
+        dim_t maxSeqLen;
+
+        AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept;
+
+        std::vector<KernelBox>
+        filter(TensorRefs inputs, TensorRefs outputs) const final;
+    };
+
+}// namespace refactor::kernel
+
+#endif // KERNEL_ATTENTION_H
New file: AttentionCollector implementation

+#include "kernel/collectors/attention.h"
+#include "kernel/kernel.h"
+#include "kernel/tensor.h"
+// #include "../kernels/attention/cpu_kernel.hh"
+// #include "../kernels/attention/cuda_kernel.hh"
+
+namespace refactor::kernel {
+
+    AttentionCollector::AttentionCollector(
+        decltype(_target) target,
+        decltype(maxSeqLen) maxSeqLen_) noexcept
+        : InfoCollector(target),
+          maxSeqLen(maxSeqLen_) {}
+
+    std::vector<KernelBox>
+    AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
+        std::vector<KernelBox> ans;
+        switch (_target) {
+            case decltype(_target)::Cpu:
+                break;
+            case decltype(_target)::Nvidia:
+                break;
+            case decltype(_target)::Mlu:
+                break;
+            default:
+                UNREACHABLEX(void, "Unknown target");
+        }
+        return ans;
+    }
+
+}// namespace refactor::kernel
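
The filter() above returns an empty list for every target for now; the commented-out kernel headers suggest backend-specific kernels will be appended per case in a follow-up. For orientation, here is a minimal, self-contained sketch of that collector pattern (switch on the target, let each candidate kernel build itself, keep the ones that succeed). All names in the sketch are illustrative, not the project's API:

    #include <cstdio>
    #include <memory>
    #include <vector>

    enum class Target { Cpu, Nvidia };

    struct Kernel {
        virtual ~Kernel() = default;
        virtual const char *name() const = 0;
    };
    using KernelBox = std::unique_ptr<Kernel>;

    struct CpuAttention final : Kernel {
        const char *name() const override { return "cpu attention"; }
        // A real build() would inspect the input/output tensors and return
        // nullptr when the kernel cannot handle them; here it always succeeds.
        static KernelBox build() { return std::make_unique<CpuAttention>(); }
    };

    // Collector: gather every kernel box that applies to the target.
    std::vector<KernelBox> filter(Target target) {
        std::vector<KernelBox> ans;
        switch (target) {
            case Target::Cpu:
                if (auto ptr = CpuAttention::build(); ptr) {
                    ans.emplace_back(std::move(ptr));
                }
                break;
            case Target::Nvidia:
                break;// no CUDA kernel in this sketch
        }
        return ans;
    }

    int main() {
        for (auto const &k : filter(Target::Cpu)) {
            std::printf("collected: %s\n", k->name());
        }
    }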
New file: include/computation/operators/attention.h

+#ifndef COMPUTATION_ATTENTION_H
+#define COMPUTATION_ATTENTION_H
+
+#include "../operator.h"
+
+namespace refactor::computation {
+
+    struct Attention final : public Operator {
+        dim_t maxSeqLen;
+
+        constexpr Attention(decltype(maxSeqLen) maxSeqLen_) noexcept
+            : Operator(), maxSeqLen(maxSeqLen_) {}
+
+        static size_t typeId() noexcept;
+        size_t opTypeId() const noexcept final;
+        std::string_view name() const noexcept final;
+    };
+
+}// namespace refactor::computation
+
+#endif // COMPUTATION_ATTENTION_H
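
As in the rest of the PR, the constructor above types its parameter as decltype(maxSeqLen) rather than repeating dim_t, so the signature follows the member if its type ever changes. A small standalone illustration of the idiom (the dim_t alias below is an assumed stand-in, not the project's definition):

    #include <cstdint>
    #include <cstdio>

    using dim_t = std::int64_t;// stand-in for the project's dim_t (assumed)

    struct Attention {
        dim_t maxSeqLen;

        // The parameter type is tied to the member: change dim_t above and
        // this constructor signature follows without further edits.
        constexpr explicit Attention(decltype(maxSeqLen) maxSeqLen_) noexcept
            : maxSeqLen(maxSeqLen_) {}
    };

    int main() {
        constexpr Attention attn{2048};
        std::printf("maxSeqLen = %lld\n", static_cast<long long>(attn.maxSeqLen));
    }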
New file: computation Attention operator implementation

+#include "computation/operators/attention.h"
+
+namespace refactor::computation {
+    using Op = Attention;
+
+    auto Op::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+    auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
+    auto Op::name() const noexcept -> std::string_view { return "Attention"; }
+
+}// namespace refactor::computation
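
typeId() above returns the address of a function-local static cast to size_t. Each typeId() definition owns its own static byte, so the address is unique per operator type and stable for the lifetime of the process, giving a cheap type tag without RTTI. A standalone sketch of the idiom (the type names here are illustrative):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Each *Type::typeId() owns its own local static, so the returned
    // addresses differ between types and repeat within a type.
    struct FooType {
        static std::size_t typeId() noexcept {
            static uint8_t ID = 1;
            return reinterpret_cast<std::size_t>(&ID);
        }
    };
    struct BarType {
        static std::size_t typeId() noexcept {
            static uint8_t ID = 1;
            return reinterpret_cast<std::size_t>(&ID);
        }
    };

    int main() {
        // Same type -> same id; different types -> different ids.
        std::printf("Foo == Foo: %d\n", FooType::typeId() == FooType::typeId());
        std::printf("Foo == Bar: %d\n", FooType::typeId() == BarType::typeId());
    }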
New file: llm frontend Attention operator implementation

+#include "computation/operators/attention.h"
+#include "attention.hh"
+#include "common.h"
+
+namespace refactor::llm {
+    using Op = Attention;
+
+    Op::Attention(decltype(maxSeqLen) maxSeqLen_)
+        : Operator(), maxSeqLen(maxSeqLen_) {}
+
+    auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
+        auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).float_();
+        return OpBox(std::make_unique<Op>(maxSeqLen));
+    }
+    auto Op::typeId() -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto Op::opTypeId() const -> size_t { return typeId(); }
+    auto Op::opTypeName() const -> std::string_view { return "llm::Attention"; }
+
+    auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
+        TODO("");
+    }
+
+    auto Op::lower(TensorRefs) const -> computation::OpBox {
+        TODO("");
+    }
+
+}// namespace refactor::llm
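
build() reads max_seq_len through Attributes::getOrInsert with a default of {0} and converts it with float_(), so a model that omits the attribute still constructs an operator. The project's Attributes API is not shown in this diff; as a rough, self-contained analogue only, a defaulted-attribute lookup can be sketched like this:

    #include <cstdint>
    #include <cstdio>
    #include <map>
    #include <string>
    #include <variant>

    // Toy attribute bag: each value is either a float or an int.
    using AttrValue = std::variant<float, std::int64_t>;
    using Attributes = std::map<std::string, AttrValue>;

    // Return the stored value, inserting `fallback` first if the key is
    // absent; roughly what a getOrInsert-style accessor does.
    AttrValue &getOrInsert(Attributes &attrs, std::string const &key, AttrValue fallback) {
        return attrs.try_emplace(key, fallback).first->second;
    }

    int main() {
        Attributes attrs;// no "max_seq_len" recorded
        auto maxSeqLen = std::get<float>(getOrInsert(attrs, "max_seq_len", {0.0f}));
        std::printf("max_seq_len defaults to %g\n", maxSeqLen);

        attrs["max_seq_len"] = 2048.0f;// an explicit value wins next time
        maxSeqLen = std::get<float>(getOrInsert(attrs, "max_seq_len", {0.0f}));
        std::printf("max_seq_len is now %g\n", maxSeqLen);
    }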
New file: attention.hh, llm frontend Attention operator declaration

+#ifndef LLM_RMS_ATTENTION_HH
+#define LLM_RMS_ATTENTION_HH
+
+#include "frontend/operator.h"
+
+namespace refactor::llm {
+    using namespace frontend;
+
+    struct Attention final : public Operator {
+        dim_t maxSeqLen;
+
+        explicit Attention(decltype(maxSeqLen));
+
+        static OpBox build(ModelContext const &, std::string_view, Attributes);
+        static size_t typeId();
+
+        size_t opTypeId() const final;
+        std::string_view opTypeName() const final;
+        InferResult infer(TensorRefs, InferOptions const &) const final;
+        computation::OpBox lower(TensorRefs) const final;
+    };
+
+}// namespace refactor::llm
+
+#endif // LLM_RMS_ATTENTION_HH
Modified file: RmsNormalization frontend operator header

@@ -9,7 +9,7 @@ namespace refactor::llm {
     struct RmsNormalization final : public Operator {
         float epsilon;

-        RmsNormalization(decltype(epsilon));
+        explicit RmsNormalization(decltype(epsilon));

         static OpBox build(ModelContext const &, std::string_view, Attributes);
         static size_t typeId();
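
The only change to an existing file marks the single-argument RmsNormalization constructor explicit, which stops a bare float from silently converting into the operator at call sites. A tiny standalone example of what explicit rules out (the types here are illustrative):

    #include <cstdio>

    struct Implicit {
        float epsilon;
        Implicit(float e) : epsilon(e) {}          // converting constructor
    };

    struct Explicit {
        float epsilon;
        explicit Explicit(float e) : epsilon(e) {} // conversion must be spelled out
    };

    void useImplicit(Implicit) {}
    void useExplicit(Explicit) {}

    int main() {
        useImplicit(1e-5f);          // OK: float silently becomes Implicit
        // useExplicit(1e-5f);       // error: no implicit conversion allowed
        useExplicit(Explicit{1e-5f});// OK: intent is written out
        std::printf("both structs built\n");
    }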