Skip to content

Commit fad85aa

Browse files
committed
Merge branch 'dev' into dev-hardware
2 parents 31478cc + 54c2f7e commit fad85aa

File tree

33 files changed

+402
-193
lines changed

33 files changed

+402
-193
lines changed

src/02hardware/CMakeLists.txt

+4-6
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,15 @@ cmake_minimum_required(VERSION 3.12 FATAL_ERROR)
22
project(hardware VERSION 0.0.0 LANGUAGES CXX)
33
message(STATUS "Project " ${PROJECT_NAME} " version " ${PROJECT_VERSION})
44

5-
# Source files
65
file(GLOB_RECURSE HARDWARE_SRC src/*.cc src/*.cpp)
6+
add_library(hardware STATIC ${HARDWARE_SRC})
7+
target_link_libraries(hardware PUBLIC common)
8+
target_include_directories(hardware PUBLIC include)
79

810
if(USE_CUDA)
9-
file(GLOB_RECURSE HARDWARE_CUDA_SRC src/devices/nvidia/*.cu)
11+
target_include_directories(hardware PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
1012
endif()
1113

12-
add_library(hardware STATIC ${HARDWARE_SRC} ${HARDWARE_CUDA_SRC} ${HARDWARE_BANG_SRC})
13-
target_link_libraries(hardware PUBLIC common)
14-
target_include_directories(hardware PUBLIC include)
15-
1614
file(GLOB_RECURSE HARDWARE_TEST test/*.cpp)
1715
if(HARDWARE_TEST)
1816
add_executable(hardware_test ${HARDWARE_TEST})

src/02hardware/include/hardware/device.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ namespace refactor::hardware {
5252

5353
virtual ~Device() = default;
5454
virtual Type type() const noexcept = 0;
55-
virtual void setContext() const noexcept;
55+
virtual void setContext() const;
5656

5757
Arc<Blob> malloc(size_t);
5858
Arc<Blob> absorb(Arc<Blob> &&);

src/02hardware/include/hardware/devices/nvidia.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace refactor::hardware {
88
class Nvidia final : public Device {
99
public:
1010
explicit Nvidia(int32_t card);
11-
void setContext() const noexcept final;
11+
void setContext() const final;
1212
Type type() const noexcept final {
1313
return Type::Nvidia;
1414
}

src/02hardware/src/device.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ namespace refactor::hardware {
5656
Device::Device(decltype(_card) card, decltype(_mem) mem)
5757
: _card(card), _mem(std::move(mem)) {}
5858

59-
void Device::setContext() const noexcept {}
59+
void Device::setContext() const {}
6060
auto Device::malloc(size_t size) -> Arc<Blob> {
6161
return Arc<Blob>(new Blob(this, size));
6262
}

src/02hardware/src/devices/mlu/device.cc

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,29 @@
11
#include "functions.hh"
22
#include "hardware/devices/mlu.h"
33
#include "hardware/mem_pool.h"
4+
5+
#ifdef USE_BANG
6+
#include "cnrt.h"
47
#include "memory.hh"
58

9+
#define BANG_ASSERT(STATUS) \
10+
if (auto status = (STATUS); status != CNRT_RET_SUCCESS) { \
11+
RUNTIME_ERROR(fmt::format("bang failed on \"" #STATUS "\" with \"{}\" ({})", \
12+
cnrtGetErrorStr(status), (int) status)); \
13+
}
14+
15+
#endif
616
namespace refactor::hardware {
717

818
static Arc<Memory> bangMemory(int32_t card) {
919
#ifdef USE_BANG
10-
ASSERT(0 <= card && card < getDeviceCount(), "Invalid card id: {}", card);
11-
setDevice(card);
12-
auto [free, total] = getMemInfo();
20+
unsigned deviceCount;
21+
BANG_ASSERT(cnrtGetDeviceCount(&deviceCount));
22+
ASSERT(0 <= card && card < deviceCount, "Invalid card id: {}", card);
23+
BANG_ASSERT(cnrtSetDevice(card));
24+
25+
size_t free, total;
26+
BANG_ASSERT(cnrtMemGetInfo(&free, &total));
1327
auto size = std::min(free, std::max(5ul << 30, total * 4 / 5));
1428
fmt::println("initializing Cambricon MLU {}, memory {} / {}, alloc {}",
1529
card, free, total, size);
@@ -25,7 +39,9 @@ namespace refactor::hardware {
2539
Mlu::Mlu(int32_t card) : Device(card, bangMemory(card)) {}
2640

2741
void Mlu::setContext() const noexcept {
28-
setDevice(_card);
42+
#ifdef USE_BANG
43+
BANG_ASSERT(cnrtSetDevice(_card));
44+
#endif
2945
}
3046

3147
}// namespace refactor::hardware

src/02hardware/src/devices/mlu/functions.cc

-21
This file was deleted.

src/02hardware/src/devices/mlu/functions.hh

-28
This file was deleted.

src/02hardware/src/devices/mlu/memory.cc

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
1+
#ifdef USE_BANG
2+
13
#include "memory.hh"
2-
#include "functions.hh"
4+
#include "cnrt.h"
5+
#include "common.h"
6+
7+
#define BANG_ASSERT(STATUS) \
8+
if (auto status = (STATUS); status != CNRT_RET_SUCCESS) { \
9+
RUNTIME_ERROR(fmt::format("bang failed on \"" #STATUS "\" with \"{}\" ({})", \
10+
cnrtGetErrorStr(status), (int) status)); \
11+
}
312

413
namespace refactor::hardware {
5-
#ifdef USE_BANG
14+
615
using M = MluMemory;
716

817
void *M::malloc(size_t size) {
@@ -28,6 +37,6 @@ namespace refactor::hardware {
2837
CNRT_MEM_TRANS_DIR_PEER2PEER));
2938
return dst;
3039
}
31-
#endif
3240

3341
}// namespace refactor::hardware
42+
#endif
src/02hardware/src/devices/nvidia/device.cc

+28-9
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,50 @@
11
#include "functions.cuh"
22
#include "hardware/devices/nvidia.h"
33
#include "hardware/mem_pool.h"
4-
#include "memory.cuh"
4+
5+
#ifdef USE_CUDA
6+
#include "memory.hh"
7+
#include <cuda_runtime.h>
8+
9+
#define CUDA_ASSERT(STATUS) \
10+
if (auto status = (STATUS); status != cudaSuccess) { \
11+
RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
12+
cudaGetErrorString(status), (int) status)); \
13+
}
14+
#endif
515

616
namespace refactor::hardware {
717

818
static Arc<Memory> cudaMemory(int32_t card) {
919
#ifdef USE_CUDA
10-
ASSERT(0 <= card && card < getDeviceCount(), "Invalid card id: {}", card);
11-
setDevice(card);
12-
auto [free, total] = getMemInfo();
20+
int deviceCount;
21+
CUDA_ASSERT(cudaGetDeviceCount(&deviceCount));
22+
ASSERT(0 <= card && card < deviceCount, "Invalid card id: {}", card);
23+
CUDA_ASSERT(cudaSetDevice(card));
24+
25+
size_t free, total;
26+
CUDA_ASSERT(cudaMemGetInfo(&free, &total));
1327
auto size = std::min(free, std::max(5ul << 30, total * 4 / 5));
14-
fmt::println("initializing Nvidia GPU {}, memory {} / {}, alloc {}",
15-
card, free, total, size);
28+
cudaDeviceProp prop;
29+
CUDA_ASSERT(cudaGetDeviceProperties(&prop, 0));
30+
size_t alignment = prop.textureAlignment;
31+
fmt::println("initializing Nvidia GPU {}, memory {} / {}, alloc {}, alignment {}",
32+
card, free, total, size, alignment);
1633
return std::make_shared<MemPool>(
1734
std::make_shared<NvidiaMemory>(),
1835
size,
19-
256ul);
36+
alignment);
2037
#else
2138
return nullptr;
2239
#endif
2340
}
2441

2542
Nvidia::Nvidia(int32_t card) : Device(card, cudaMemory(card)) {}
2643

27-
void Nvidia::setContext() const noexcept {
28-
setDevice(_card);
44+
void Nvidia::setContext() const {
45+
#ifdef USE_CUDA
46+
CUDA_ASSERT(cudaSetDevice(_card));
47+
#endif
2948
}
3049

3150
}// namespace refactor::hardware

src/02hardware/src/devices/nvidia/functions.cu

-19
This file was deleted.

src/02hardware/src/devices/nvidia/functions.cuh

-24
This file was deleted.

src/02hardware/src/devices/nvidia/memory.cu renamed to src/02hardware/src/devices/nvidia/memory.cc

+13-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1-
#include "functions.cuh"
2-
#include "memory.cuh"
1+
#ifdef USE_CUDA
2+
3+
#include "memory.hh"
4+
#include "common.h"
5+
#include <cuda_runtime.h>
6+
7+
#define CUDA_ASSERT(STATUS) \
8+
if (auto status = (STATUS); status != cudaSuccess) { \
9+
RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \
10+
cudaGetErrorString(status), (int) status)); \
11+
}
312

413
namespace refactor::hardware {
514
using M = NvidiaMemory;
@@ -29,3 +38,5 @@ namespace refactor::hardware {
2938
}
3039

3140
}// namespace refactor::hardware
41+
42+
#endif

src/03runtime/include/runtime/stream.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@ namespace refactor::runtime {
4242
decltype(_device));
4343

4444
decltype(_graph) const &graph() const noexcept { return _graph; }
45-
void setData(count_t, void const *, size_t);
45+
auto setData(count_t, size_t) -> Arc<hardware::Device::Blob>;
4646
void setData(count_t, Arc<hardware::Device::Blob>);
47-
bool getData(count_t, void *, size_t) const;
47+
auto getData(count_t) const -> Arc<hardware::Device::Blob>;
48+
void setData(count_t, void const *, size_t);
49+
bool copyData(count_t, void *, size_t) const;
4850
void run();
4951
auto bench(void (*sync)()) -> std::vector<std::chrono::nanoseconds>;
5052
void trace(std::function<void(count_t, void const *const *, void const *const *)>);

src/03runtime/src/stream.cc

+9-3
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,21 @@ namespace refactor::runtime {
1818
std::move(edges),
1919
} {}
2020

21+
auto Stream::setData(count_t i, size_t size) -> Arc<hardware::Device::Blob> {
22+
return _graph.edges[i].blob = _device->malloc(size);
23+
}
24+
void Stream::setData(count_t i, Arc<hardware::Device::Blob> blob) {
25+
_graph.edges[i].blob = std::move(blob);
26+
}
2127
void Stream::setData(count_t i, void const *data, size_t size) {
2228
auto blob = _device->malloc(size);
2329
blob->copyFromHost(data, size);
2430
_graph.edges[i].blob = std::move(blob);
2531
}
26-
void Stream::setData(count_t i, Arc<hardware::Device::Blob> blob) {
27-
_graph.edges[i].blob = std::move(blob);
32+
auto Stream::getData(count_t i) const -> Arc<hardware::Device::Blob> {
33+
return _graph.edges[i].blob;
2834
}
29-
bool Stream::getData(count_t i, void *data, size_t size) const {
35+
bool Stream::copyData(count_t i, void *data, size_t size) const {
3036
if (!_graph.edges[i].blob) { return false; }
3137
_graph.edges[i].blob->copyToHost(data, size);
3238
return true;

src/04kernel/include/kernel/collectors/simple_binary.h

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ namespace refactor::kernel {
1414
And,
1515
Or,
1616
Xor,
17+
Mod,
18+
Fmod,
1719
};
1820

1921
std::string_view opName(SimpleBinaryType type);

src/04kernel/src/collectors/simple_binary.cc

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ namespace refactor::kernel {
2020
CASE(And);
2121
CASE(Or);
2222
CASE(Xor);
23+
CASE(Mod);
24+
CASE(Fmod);
2325
default:
2426
UNREACHABLE();
2527
}

0 commit comments

Comments
 (0)