Skip to content

Commit 4f16fdd

Browse files
committed
feat: 搭建 Attention 在各层的基本结构
Signed-off-by: YdrMaster <[email protected]>
1 parent 205dc1a commit 4f16fdd

File tree

7 files changed

+141
-1
lines changed

7 files changed

+141
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef KERNEL_ATTENTION_H
#define KERNEL_ATTENTION_H

#include "../collector.h"

namespace refactor::kernel {

    /// Kernel collector for the Attention operator: given the target
    /// backend (`_target` from InfoCollector), gathers the kernel
    /// implementations able to execute Attention.
    struct AttentionCollector final : public InfoCollector {
        // Maximum sequence length the attention supports; presumably used
        // to size KV-cache workspace — TODO(review): confirm once the
        // backend kernels are implemented (none are registered yet).
        dim_t maxSeqLen;

        AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept;

        // Returns candidate kernels that can handle the given input/output
        // tensors on the configured target.
        std::vector<KernelBox>
        filter(TensorRefs inputs, TensorRefs outputs) const final;
    };

}// namespace refactor::kernel

#endif// KERNEL_ATTENTION_H
+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include "kernel/collectors/attention.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"
// #include "../kernels/attention/cpu_kernel.hh"
// #include "../kernels/attention/cuda_kernel.hh"

namespace refactor::kernel {

    AttentionCollector::AttentionCollector(
        decltype(_target) target,
        decltype(maxSeqLen) maxSeqLen_) noexcept
        : InfoCollector(target),
          maxSeqLen(maxSeqLen_) {}

    /// Gathers the kernels able to run Attention on the configured target.
    /// No backend kernel is registered yet (see the commented-out includes
    /// above), so every recognized target currently yields an empty list.
    std::vector<KernelBox>
    AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
        std::vector<KernelBox> candidates;
        switch (_target) {
            case decltype(_target)::Cpu:
            case decltype(_target)::Nvidia:
            case decltype(_target)::Mlu:
                // placeholder: per-target kernels will be appended here
                break;
            default:
                UNREACHABLEX(void, "Unknown target");
        }
        return candidates;
    }

}// namespace refactor::kernel
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef COMPUTATION_ATTENTION_H
#define COMPUTATION_ATTENTION_H

#include "../operator.h"

namespace refactor::computation {

    /// Graph-level Attention operator carrying the maximum sequence
    /// length the runtime must support.
    struct Attention final : public Operator {
        // Maximum sequence length; presumably bounds KV-cache allocation —
        // TODO(review): confirm when kernels land.
        dim_t maxSeqLen;

        // `explicit` keeps this consistent with the project's other
        // operators (e.g. RmsNormalization) and prevents an accidental
        // implicit dim_t -> Attention conversion.
        constexpr explicit Attention(decltype(maxSeqLen) maxSeqLen_) noexcept
            : Operator(), maxSeqLen(maxSeqLen_) {}

        // Process-wide unique id for this operator type.
        static size_t typeId() noexcept;
        size_t opTypeId() const noexcept final;
        std::string_view name() const noexcept final;
    };

}// namespace refactor::computation

#endif// COMPUTATION_ATTENTION_H
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#include "computation/operators/attention.h"

namespace refactor::computation {

    // The address of a function-local static byte serves as a
    // process-wide unique id — no central registry needed.
    size_t Attention::typeId() noexcept {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    size_t Attention::opTypeId() const noexcept { return typeId(); }

    std::string_view Attention::name() const noexcept { return "Attention"; }

}// namespace refactor::computation
+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include "computation/operators/attention.h"
#include "attention.hh"
#include "common.h"

namespace refactor::llm {
    using Op = Attention;

    Op::Attention(decltype(maxSeqLen) maxSeqLen_)
        : Operator(), maxSeqLen(maxSeqLen_) {}

    /// Builds an Attention operator from frontend attributes.
    /// "max_seq_len" defaults to 0 when the attribute is absent.
    auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
        // max_seq_len is an integral length stored in dim_t; read it as an
        // integer instead of float_() so no silent float->int truncation
        // occurs. (The original read float_() — TODO(review): confirm the
        // exporter emits this attribute as an int.)
        auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).int_();
        return OpBox(std::make_unique<Op>(maxSeqLen));
    }
    auto Op::typeId() -> size_t {
        // Address of a function-local static acts as a unique type id.
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto Op::opTypeId() const -> size_t { return typeId(); }
    auto Op::opTypeName() const -> std::string_view { return "llm::Attention"; }

    auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
        TODO("");// shape inference not implemented yet (skeleton commit)
    }

    auto Op::lower(TensorRefs) const -> computation::OpBox {
        TODO("");// lowering to computation::Attention not implemented yet
    }

}// namespace refactor::llm
+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Fix: the include guard was LLM_RMS_ATTENTION_HH, copy-pasted from
// rms_normalization.hh; use a guard that names THIS header so the two
// files cannot shadow each other.
#ifndef LLM_ATTENTION_HH
#define LLM_ATTENTION_HH

#include "frontend/operator.h"

namespace refactor::llm {
    using namespace frontend;

    /// Frontend Attention operator for the LLM dialect.
    struct Attention final : public Operator {
        // Maximum sequence length, parsed from the "max_seq_len"
        // attribute (0 when absent).
        dim_t maxSeqLen;

        explicit Attention(decltype(maxSeqLen));

        // Factory used by the operator registry.
        static OpBox build(ModelContext const &, std::string_view, Attributes);
        // Process-wide unique id for this operator type.
        static size_t typeId();

        size_t opTypeId() const final;
        std::string_view opTypeName() const final;
        // Output-shape inference (not implemented yet in this commit).
        InferResult infer(TensorRefs, InferOptions const &) const final;
        // Lowering to the computation-level operator (not implemented yet).
        computation::OpBox lower(TensorRefs) const final;
    };

}// namespace refactor::llm

#endif// LLM_ATTENTION_HH

src/08-01llm/src/operators/rms_normalization.hh

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace refactor::llm {
99
struct RmsNormalization final : public Operator {
1010
float epsilon;
1111

12-
RmsNormalization(decltype(epsilon));
12+
explicit RmsNormalization(decltype(epsilon));
1313

1414
static OpBox build(ModelContext const &, std::string_view, Attributes);
1515
static size_t typeId();

0 commit comments

Comments
 (0)