Halide 13.0.2
Halide compiler and libraries
gpu_context_common.h
Go to the documentation of this file.
1#include "printer.h"
2#include "scoped_mutex_lock.h"
3
4namespace Halide {
5namespace Internal {
6
7template<typename ContextT, typename ModuleStateT>
9 struct CachedCompilation {
10 ContextT context{};
11 ModuleStateT module_state{};
12 uint32_t kernel_id{};
13 uint32_t use_count{0};
14
15 CachedCompilation(ContextT context, ModuleStateT module_state,
16 uint32_t kernel_id, uint32_t use_count)
17 : context(context), module_state(module_state),
18 kernel_id(kernel_id), use_count(use_count) {
19 }
20 };
21
22 halide_mutex mutex;
23
24 static constexpr float kLoadFactor{.5f};
25 static constexpr int kInitialTableBits{7};
26 int log2_compilations_size{0}; // number of bits in index into compilations table.
27 CachedCompilation *compilations{nullptr};
28 int count{0};
29
30 static constexpr uint32_t kInvalidId{0};
31 static constexpr uint32_t kDeletedId{1};
32
33 uint32_t unique_id{2}; // zero is an invalid id
34
35public:
36 static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uint32_t id, uint32_t bits) {
37 uintptr_t addr = (uintptr_t)context + id;
38 // Fibonacci hashing. The golden ratio is 1.9E3779B97F4A7C15F39...
39 // in hexadecimal.
40 if (sizeof(uintptr_t) >= 8) {
41 return (addr * (uintptr_t)0x9E3779B97F4A7C15) >> (64 - bits);
42 } else {
43 return (addr * (uintptr_t)0x9E3779B9) >> (32 - bits);
44 }
45 }
46
47 HALIDE_MUST_USE_RESULT bool insert(const CachedCompilation &entry) {
48 if (log2_compilations_size == 0) {
49 if (!resize_table(kInitialTableBits)) {
50 return false;
51 }
52 }
53 if ((count + 1) > (1 << log2_compilations_size) * kLoadFactor) {
54 if (!resize_table(log2_compilations_size + 1)) {
55 return false;
56 }
57 }
58 count += 1;
59 uintptr_t index = kernel_hash(entry.context, entry.kernel_id, log2_compilations_size);
60 for (int i = 0; i < (1 << log2_compilations_size); i++) {
61 uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
62 if (compilations[effective_index].kernel_id <= kDeletedId) {
63 compilations[effective_index] = entry;
64 return true;
65 }
66 }
67 // This is a logic error that should never occur. It means the table is
68 // full, but it should have been resized.
69 halide_assert(nullptr, false);
70 return false;
71 }
72
74 ModuleStateT *&module_state, int increment) {
75 if (log2_compilations_size == 0) {
76 return false;
77 }
78 uintptr_t index = kernel_hash(context, id, log2_compilations_size);
79 for (int i = 0; i < (1 << log2_compilations_size); i++) {
80 uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
81
82 if (compilations[effective_index].kernel_id == kInvalidId) {
83 return false;
84 }
85 if (compilations[effective_index].context == context &&
86 compilations[effective_index].kernel_id == id) {
87 module_state = &compilations[effective_index].module_state;
88 if (increment != 0) {
89 compilations[effective_index].use_count += increment;
90 }
91 return true;
92 }
93 }
94 return false;
95 }
96
97 HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state) {
98 ScopedMutexLock lock_guard(&mutex);
99 uint32_t id = (uint32_t)(uintptr_t)state_ptr;
100 ModuleStateT *mod_ptr;
101 if (find_internal(context, id, mod_ptr, 0)) {
102 module_state = *mod_ptr;
103 return true;
104 }
105 return false;
106 }
107
109 if (size_bits != log2_compilations_size) {
110 int new_size = (1 << size_bits);
111 int old_size = (1 << log2_compilations_size);
112 CachedCompilation *new_table = (CachedCompilation *)malloc(new_size * sizeof(CachedCompilation));
113 if (new_table == nullptr) {
114 // signal error.
115 return false;
116 }
117 memset(new_table, 0, new_size * sizeof(CachedCompilation));
118 CachedCompilation *old_table = compilations;
119 compilations = new_table;
120 log2_compilations_size = size_bits;
121
122 if (count > 0) { // Mainly to catch empty initial table case
123 for (int32_t i = 0; i < old_size; i++) {
124 if (old_table[i].kernel_id != kInvalidId &&
125 old_table[i].kernel_id != kDeletedId) {
126 bool result = insert(old_table[i]);
127 halide_assert(nullptr, result); // Resizing the table while resizing the table is a logic error.
128 }
129 }
130 }
131 free(old_table);
132 }
133 return true;
134 }
135
136 template<typename FreeModuleT>
137 void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f) {
138 if (count == 0) {
139 return;
140 }
141
142 for (int i = 0; i < (1 << log2_compilations_size); i++) {
143 if (compilations[i].kernel_id > kInvalidId &&
144 (all || (compilations[i].context == context)) &&
145 compilations[i].use_count == 0) {
146 debug(user_context) << "Releasing cached compilation: " << compilations[i].module_state
147 << " id " << compilations[i].kernel_id
148 << " context " << compilations[i].context << "\n";
149 f(compilations[i].module_state);
150 compilations[i].module_state = nullptr;
151 compilations[i].kernel_id = kDeletedId;
152 count--;
153 }
154 }
155 }
156
157 template<typename FreeModuleT>
158 void delete_context(void *user_context, ContextT context, FreeModuleT &f) {
159 ScopedMutexLock lock_guard(&mutex);
160
161 release_context(user_context, false, context, f);
162 }
163
164 template<typename FreeModuleT>
165 void release_all(void *user_context, FreeModuleT &f) {
166 ScopedMutexLock lock_guard(&mutex);
167
168 release_context(user_context, true, nullptr, f);
169 // Some items may have been in use, so can't free.
170 if (count == 0) {
171 free(compilations);
172 compilations = nullptr;
173 log2_compilations_size = 0;
174 }
175 }
176
177 template<typename CompileModuleT, typename... Args>
179 ContextT context, ModuleStateT &result,
180 CompileModuleT f,
181 Args... args) {
182 ScopedMutexLock lock_guard(&mutex);
183
184 uint32_t *id_ptr = (uint32_t *)state_ptr;
185 if (*id_ptr == 0) {
186 *id_ptr = unique_id++;
187 }
188
189 ModuleStateT *mod;
190 if (find_internal(context, *id_ptr, mod, 1)) {
191 result = *mod;
192 return true;
193 }
194
195 // TODO(zvookin): figure out the calling signature here...
196 ModuleStateT compiled_module = f(args...);
197 debug(user_context) << "Caching compiled kernel: " << compiled_module
198 << " id " << *id_ptr << " context " << context << "\n";
199 if (compiled_module == nullptr) {
200 return false;
201 }
202
203 if (!insert({context, compiled_module, *id_ptr, 1})) {
204 return false;
205 }
206 result = compiled_module;
207
208 return true;
209 }
210
211 void release_hold(void *user_context, ContextT context, void *state_ptr) {
212 ModuleStateT *mod;
213 uint32_t id = (uint32_t)(uintptr_t)state_ptr;
214 bool result = find_internal(context, id, mod, -1);
215 halide_assert(user_context, result); // Value must be in cache to be released
216 }
217};
218
219} // namespace Internal
220} // namespace Halide
#define HALIDE_MUST_USE_RESULT
Definition: HalideRuntime.h:54
HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state)
static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uint32_t id, uint32_t bits)
void release_hold(void *user_context, ContextT context, void *state_ptr)
HALIDE_MUST_USE_RESULT bool insert(const CachedCompilation &entry)
HALIDE_MUST_USE_RESULT bool kernel_state_setup(void *user_context, void **state_ptr, ContextT context, ModuleStateT &result, CompileModuleT f, Args... args)
void release_all(void *user_context, FreeModuleT &f)
void delete_context(void *user_context, ContextT context, FreeModuleT &f)
HALIDE_MUST_USE_RESULT bool resize_table(int size_bits)
void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f)
HALIDE_MUST_USE_RESULT bool find_internal(ContextT context, uint32_t id, ModuleStateT *&module_state, int increment)
For optional debugging during codegen, use the debug class as follows:
Definition: Debug.h:49
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1089
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
void * user_context
Definition: printer.h:33
void * malloc(size_t)
signed __INT32_TYPE__ int32_t
#define ALWAYS_INLINE
void * memset(void *s, int val, size_t n)
unsigned __INT32_TYPE__ uint32_t
#define halide_assert(user_context, cond)
void free(void *)
Cross-platform mutex.