Halide 13.0.2
Halide compiler and libraries
check_call_graphs.h
Go to the documentation of this file.
1#ifndef CHECK_CALL_GRAPHS_H
2#define CHECK_CALL_GRAPHS_H
3
#include <algorithm>
#include <assert.h>
#include <functional>
#include <iostream>
#include <map>
#include <numeric>
#include <stdio.h>
#include <string.h>
#include <string>
#include <type_traits>
#include <vector>

#include "Halide.h"
13
/// Call graph: maps each caller name to the list of names it calls.
using CallGraphs = std::map<std::string, std::vector<std::string>>;
15
16// For each producer node, find all functions that it calls.
//
// IR visitor that fills `calls`. For every ProducerConsumer node with
// is_producer set, calls[<producer name>] receives the names of the Load nodes
// visited while that producer is current. Names are deduplicated and kept in
// first-seen order.
//
// NOTE(review): judging by the embedded numbering, this listing skips original
// lines 17, 23, 34 and 39. Presumably these are the class head (an IRVisitor
// subclass, given `override` and `accept(this)` below), a line in the private
// section, the body of the consumer 'else' branch, and the first line of the
// Load visitor. Confirm against the real header before relying on the
// traversal behaviour described here.
18public:
// Result: producer name -> names it loads from. Every visited producer gets an
// entry, even if it loads nothing.
19 CallGraphs calls; // Caller -> vector of callees
// Name of the producer whose body is currently being visited. It is empty
// outside any producer, and Loads seen then are not recorded.
20 std::string producer = "";
21
22private:
24
// Producer nodes: make op->name the current producer while its body is
// visited, then restore the previous one. Loads inside nested producers are
// therefore attributed to the innermost enclosing producer.
25 void visit(const Halide::Internal::ProducerConsumer *op) override {
26 if (op->is_producer) {
27 std::string old_producer = producer;
28 producer = op->name;
29 calls[producer]; // Make sure each producer is allocated a slot
30 // Group the callees of the 'produce' and 'update' together
31 op->body.accept(this);
32 producer = old_producer;
33 } else {
// NOTE(review): original line 34 is missing here. As shown, consumer nodes
// are not traversed at all. Presumably the missing line forwards to the base
// visitor; confirm.
35 }
36 }
37
// Load nodes: record op->name as a callee of the current producer, at most once.
// NOTE(review): original line 39 is missing. Whether this visitor also
// traverses the Load's child expressions cannot be determined from here.
38 void visit(const Halide::Internal::Load *op) override {
40 if (!producer.empty()) {
// The ProducerConsumer visitor creates calls[producer] before visiting the body.
41 assert(calls.count(producer) > 0);
42 std::vector<std::string> &callees = calls[producer];
43 if (std::find(callees.begin(), callees.end(), op->name) == callees.end()) {
44 callees.push_back(op->name);
45 }
46 }
47 }
48};
49
50// These are declared "inline" to avoid "unused function" warnings
51inline int check_call_graphs(CallGraphs &result, CallGraphs &expected) {
52 if (result.size() != expected.size()) {
53 printf("Expect %d callers instead of %d\n", (int)expected.size(), (int)result.size());
54 return -1;
55 }
56 for (auto &iter : expected) {
57 if (result.count(iter.first) == 0) {
58 printf("Expect %s to be in the call graphs\n", iter.first.c_str());
59 return -1;
60 }
61 std::vector<std::string> &expected_callees = iter.second;
62 std::vector<std::string> &result_callees = result[iter.first];
63 std::sort(expected_callees.begin(), expected_callees.end());
64 std::sort(result_callees.begin(), result_callees.end());
65 if (expected_callees != result_callees) {
66 std::string expected_str = std::accumulate(
67 expected_callees.begin(), expected_callees.end(), std::string{},
68 [](const std::string &a, const std::string &b) {
69 return a.empty() ? b : a + ", " + b;
70 });
71 std::string result_str = std::accumulate(
72 result_callees.begin(), result_callees.end(), std::string{},
73 [](const std::string &a, const std::string &b) {
74 return a.empty() ? b : a + ", " + b;
75 });
76
77 printf("Expect calless of %s to be (%s); got (%s) instead\n",
78 iter.first.c_str(), expected_str.c_str(), result_str.c_str());
79 return -1;
80 }
81 }
82 return 0;
83}
84
85template<typename T, typename F>
86inline int check_image2(const Halide::Buffer<T> &im, const F &func) {
87 for (int y = 0; y < im.height(); y++) {
88 for (int x = 0; x < im.width(); x++) {
89 T correct = func(x, y);
90 if (im(x, y) != correct) {
91 std::cout << "im(" << x << ", " << y << ") = " << im(x, y)
92 << " instead of " << correct << "\n";
93 return -1;
94 }
95 }
96 }
97 return 0;
98}
99
100template<typename T, typename F>
101inline int check_image3(const Halide::Buffer<T> &im, const F &func) {
102 for (int z = 0; z < im.channels(); z++) {
103 for (int y = 0; y < im.height(); y++) {
104 for (int x = 0; x < im.width(); x++) {
105 T correct = func(x, y, z);
106 if (im(x, y, z) != correct) {
107 std::cout << "im(" << x << ", " << y << ", " << z << ") = "
108 << im(x, y, z) << " instead of " << correct << "\n";
109 return -1;
110 }
111 }
112 }
113 }
114 return 0;
115}
116
117template<typename T, typename F>
118inline auto // SFINAE: returns int if F has arity of 2
119check_image(const Halide::Buffer<T> &im, const F &func) -> decltype(std::declval<F>()(0, 0), int()) {
120 return check_image2(im, func);
121}
122
123template<typename T, typename F>
124inline auto // SFINAE: returns int if F has arity of 3
125check_image(const Halide::Buffer<T> &im, const F &func) -> decltype(std::declval<F>()(0, 0, 0), int()) {
126 return check_image3(im, func);
127}
128
129#endif
int check_image2(const Halide::Buffer< T > &im, const F &func)
int check_call_graphs(CallGraphs &result, CallGraphs &expected)
auto check_image(const Halide::Buffer< T > &im, const F &func) -> decltype(std::declval< F >()(0, 0), int())
int check_image3(const Halide::Buffer< T > &im, const F &func)
std::map< std::string, std::vector< std::string > > CallGraphs
std::string producer
CallGraphs calls
A Halide::Buffer is a named shared reference to a Halide::Runtime::Buffer.
Definition: Buffer.h:115
A base class for algorithms that need to recursively walk over the IR.
Definition: IRVisitor.h:19
virtual void visit(const IntImm *)
void accept(IRVisitor *v) const
Dispatch to the correct visitor method for this node.
Definition: Expr.h:190
Load a value from a named symbol if predicate is true.
Definition: IR.h:199
std::string name
Definition: IR.h:200
This node is a helpful annotation to do with permissions.
Definition: IR.h:297