Halide 13.0.2
Halide compiler and libraries
IRMatch.h
Go to the documentation of this file.
1#ifndef HALIDE_IR_MATCH_H
2#define HALIDE_IR_MATCH_H
3
4/** \file
5 * Defines a method to match a fragment of IR against a pattern containing wildcards
6 */
7
8#include <map>
9#include <random>
10#include <set>
11#include <vector>
12
13#include "IR.h"
14#include "IREquality.h"
15#include "IROperator.h"
16
17namespace Halide {
18namespace Internal {
19
20/** Does the first expression have the same structure as the second?
21 * Variables in the first expression with the name * are interpreted
22 * as wildcards, and their matching equivalent in the second
23 * expression is placed in the vector give as the third argument.
24 * Wildcards require the types to match. For the type bits and width,
25 * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26 * integer vectors of any width (including scalars), and a UInt(0, 0)
27 * will match any unsigned integer type.
28 *
29 * For example:
30 \code
31 Expr x = Variable::make(Int(32), "*");
32 match(x + x, 3 + (2*k), result)
33 \endcode
34 * should return true, and set result[0] to 3 and
35 * result[1] to 2*k.
36 */
37bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38
39/** Does the first expression have the same structure as the second?
40 * Variables are matched consistently. The first time a variable is
41 * matched, it assumes the value of the matching part of the second
42 * expression. Subsequent matches must be equal to the first match.
43 *
44 * For example:
45 \code
46 Var x("x"), y("y");
47 match(x*(x + y), a*(a + b), result)
48 \endcode
49 * should return true, and set result["x"] = a, and result["y"] = b.
50 */
51bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52
53/** Rewrite the expression x to have `lanes` lanes. This is useful
54 * for substituting the results of expr_match into a pattern expression. */
55Expr with_lanes(const Expr &x, int lanes);
56
58
59/** An alternative template-metaprogramming approach to expression
60 * matching. Potentially more efficient. We lift the expression
61 * pattern into a type, and then use force-inlined functions to
62 * generate efficient matching and reconstruction code for any
63 * pattern. Pattern elements are either one of the classes in the
64 * namespace IRMatcher, or are non-null Exprs (represented as
65 * BaseExprNode &).
66 *
67 * Pattern elements that are fully specified by their pattern can be
68 * built into an expression using the make method. Some patterns,
69 * such as a broadcast that matches any number of lanes, don't have
70 * enough information to recreate an Expr.
71 */
72namespace IRMatcher {
73
74constexpr int max_wild = 6;
75
76static const halide_type_t i64_type = {halide_type_int, 64, 1};
77
78/** To save stack space, the matcher objects are largely stateless and
79 * immutable. This state object is built up during matching and then
80 * consumed when constructing a replacement Expr.
81 */
85
86 // values of the lanes field with special meaning.
87 static constexpr uint16_t signed_integer_overflow = 0x8000;
88 static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89
91
93 void set_binding(int i, const BaseExprNode &n) noexcept {
94 bindings[i] = &n;
95 }
96
98 const BaseExprNode *get_binding(int i) const noexcept {
99 return bindings[i];
100 }
101
103 void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104 bound_const[i].u.i64 = s;
105 bound_const_type[i] = t;
106 }
107
109 void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110 bound_const[i].u.u64 = u;
111 bound_const_type[i] = t;
112 }
113
115 void set_bound_const(int i, double f, halide_type_t t) noexcept {
116 bound_const[i].u.f64 = f;
117 bound_const_type[i] = t;
118 }
119
122 bound_const[i] = val;
123 bound_const_type[i] = t;
124 }
125
127 void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128 val = bound_const[i];
129 type = bound_const_type[i];
130 }
131
133 // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
134 MatcherState() noexcept {
135 }
136};
137
138template<typename T,
139 typename = typename std::remove_reference<T>::type::pattern_tag>
141 struct type {};
142};
143
// Compile-time query for the wildcard bits bound by pattern type T.
// Reference qualifiers on T are stripped before reading T::binds.
template<typename T>
struct bindings {
    static constexpr uint32_t mask = std::remove_reference<T>::type::binds;
};
148
151 ty.lanes &= ~MatcherState::special_values_mask;
154 }
155 // unreachable
156 return Expr();
157}
158
161 halide_type_t scalar_type = ty;
162 if (scalar_type.lanes & MatcherState::special_values_mask) {
163 return make_const_special_expr(scalar_type);
164 }
165
166 const int lanes = scalar_type.lanes;
167 scalar_type.lanes = 1;
168
169 Expr e;
170 switch (scalar_type.code) {
171 case halide_type_int:
172 e = IntImm::make(scalar_type, val.u.i64);
173 break;
174 case halide_type_uint:
175 e = UIntImm::make(scalar_type, val.u.u64);
176 break;
179 e = FloatImm::make(scalar_type, val.u.f64);
180 break;
181 default:
182 // Unreachable
183 return Expr();
184 }
185 if (lanes > 1) {
186 e = Broadcast::make(e, lanes);
187 }
188 return e;
189}
190
191bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept;
192
193// A fast version of expression equality that assumes a well-typed non-null expression tree.
195bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept {
196 // Early out
197 return (&a == &b) ||
198 ((a.type == b.type) &&
199 (a.node_type == b.node_type) &&
200 equal_helper(a, b));
201}
202
203// A pattern that matches a specific expression
205 struct pattern_tag {};
206
207 constexpr static uint32_t binds = 0;
208
209 // What is the weakest and strongest IR node this could possibly be
212 constexpr static bool canonical = true;
213
215
216 template<uint32_t bound>
217 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
218 return equal(expr, e);
219 }
220
222 Expr make(MatcherState &state, halide_type_t type_hint) const {
223 return Expr(&expr);
224 }
225
226 constexpr static bool foldable = false;
227};
228
229inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
230 s << Expr(&e.expr);
231 return s;
232}
233
234template<int i>
236 struct pattern_tag {};
237
238 constexpr static uint32_t binds = 1 << i;
239
242 constexpr static bool canonical = true;
243
244 template<uint32_t bound>
245 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
246 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
247 const BaseExprNode *op = &e;
248 if (op->node_type == IRNodeType::Broadcast) {
249 op = ((const Broadcast *)op)->value.get();
250 }
251 if (op->node_type != IRNodeType::IntImm) {
252 return false;
253 }
254 int64_t value = ((const IntImm *)op)->value;
255 if (bound & binds) {
257 halide_type_t type;
258 state.get_bound_const(i, val, type);
259 return (halide_type_t)e.type == type && value == val.u.i64;
260 }
261 state.set_bound_const(i, value, e.type);
262 return true;
263 }
264
265 template<uint32_t bound>
266 HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
267 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
268 if (bound & binds) {
270 halide_type_t type;
271 state.get_bound_const(i, val, type);
272 return type == i64_type && value == val.u.i64;
273 }
274 state.set_bound_const(i, value, i64_type);
275 return true;
276 }
277
279 Expr make(MatcherState &state, halide_type_t type_hint) const {
281 halide_type_t type;
282 state.get_bound_const(i, val, type);
283 return make_const_expr(val, type);
284 }
285
286 constexpr static bool foldable = true;
287
290 state.get_bound_const(i, val, ty);
291 }
292};
293
294template<int i>
295std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
296 s << "ci" << i;
297 return s;
298}
299
300template<int i>
302 struct pattern_tag {};
303
304 constexpr static uint32_t binds = 1 << i;
305
308 constexpr static bool canonical = true;
309
310 template<uint32_t bound>
311 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
312 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
313 const BaseExprNode *op = &e;
314 if (op->node_type == IRNodeType::Broadcast) {
315 op = ((const Broadcast *)op)->value.get();
316 }
317 if (op->node_type != IRNodeType::UIntImm) {
318 return false;
319 }
320 uint64_t value = ((const UIntImm *)op)->value;
321 if (bound & binds) {
323 halide_type_t type;
324 state.get_bound_const(i, val, type);
325 return (halide_type_t)e.type == type && value == val.u.u64;
326 }
327 state.set_bound_const(i, value, e.type);
328 return true;
329 }
330
332 Expr make(MatcherState &state, halide_type_t type_hint) const {
334 halide_type_t type;
335 state.get_bound_const(i, val, type);
336 return make_const_expr(val, type);
337 }
338
339 constexpr static bool foldable = true;
340
342 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
343 state.get_bound_const(i, val, ty);
344 }
345};
346
347template<int i>
348std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
349 s << "cu" << i;
350 return s;
351}
352
353template<int i>
355 struct pattern_tag {};
356
357 constexpr static uint32_t binds = 1 << i;
358
361 constexpr static bool canonical = true;
362
363 template<uint32_t bound>
364 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
365 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
366 const BaseExprNode *op = &e;
367 if (op->node_type == IRNodeType::Broadcast) {
368 op = ((const Broadcast *)op)->value.get();
369 }
370 if (op->node_type != IRNodeType::FloatImm) {
371 return false;
372 }
373 double value = ((const FloatImm *)op)->value;
374 if (bound & binds) {
376 halide_type_t type;
377 state.get_bound_const(i, val, type);
378 return (halide_type_t)e.type == type && value == val.u.f64;
379 }
380 state.set_bound_const(i, value, e.type);
381 return true;
382 }
383
385 Expr make(MatcherState &state, halide_type_t type_hint) const {
387 halide_type_t type;
388 state.get_bound_const(i, val, type);
389 return make_const_expr(val, type);
390 }
391
392 constexpr static bool foldable = true;
393
395 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
396 state.get_bound_const(i, val, ty);
397 }
398};
399
400template<int i>
401std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
402 s << "cf" << i;
403 return s;
404}
405
406// Matches and binds to any constant Expr. Does not support constant-folding.
407template<int i>
408struct WildConst {
409 struct pattern_tag {};
410
411 constexpr static uint32_t binds = 1 << i;
412
415 constexpr static bool canonical = true;
416
417 template<uint32_t bound>
418 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
419 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
420 const BaseExprNode *op = &e;
421 if (op->node_type == IRNodeType::Broadcast) {
422 op = ((const Broadcast *)op)->value.get();
423 }
424 switch (op->node_type) {
426 return WildConstInt<i>().template match<bound>(e, state);
428 return WildConstUInt<i>().template match<bound>(e, state);
430 return WildConstFloat<i>().template match<bound>(e, state);
431 default:
432 return false;
433 }
434 }
435
436 template<uint32_t bound>
437 HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
438 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
439 return WildConstInt<i>().template match<bound>(e, state);
440 }
441
443 Expr make(MatcherState &state, halide_type_t type_hint) const {
445 halide_type_t type;
446 state.get_bound_const(i, val, type);
447 return make_const_expr(val, type);
448 }
449
450 constexpr static bool foldable = true;
451
453 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
454 state.get_bound_const(i, val, ty);
455 }
456};
457
458template<int i>
459std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
460 s << "c" << i;
461 return s;
462}
463
464// Matches and binds to any Expr
465template<int i>
466struct Wild {
467 struct pattern_tag {};
468
469 constexpr static uint32_t binds = 1 << (i + 16);
470
473 constexpr static bool canonical = true;
474
475 template<uint32_t bound>
476 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
477 if (bound & binds) {
478 return equal(*state.get_binding(i), e);
479 }
480 state.set_binding(i, e);
481 return true;
482 }
483
485 Expr make(MatcherState &state, halide_type_t type_hint) const {
486 return state.get_binding(i);
487 }
488
489 constexpr static bool foldable = true;
491 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
492 const auto *e = state.get_binding(i);
493 ty = e->type;
494 switch (e->node_type) {
496 val.u.u64 = ((const UIntImm *)e)->value;
497 return;
499 val.u.i64 = ((const IntImm *)e)->value;
500 return;
502 val.u.f64 = ((const FloatImm *)e)->value;
503 return;
504 default:
505 // The function is noexcept, so silent failure. You
506 // shouldn't be calling this if you haven't already
507 // checked it's going to be a constant (e.g. with
508 // is_const, or because you manually bound a constant Expr
509 // to the state).
510 val.u.u64 = 0;
511 }
512 }
513};
514
515template<int i>
516std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
517 s << "_" << i;
518 return s;
519}
520
521// Matches a specific constant or broadcast of that constant. The
522// constant must be representable as an int64_t.
524 struct pattern_tag {};
526
527 constexpr static uint32_t binds = 0;
528
531 constexpr static bool canonical = true;
532
535 : v(v) {
536 }
537
538 template<uint32_t bound>
539 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
540 const BaseExprNode *op = &e;
541 if (e.node_type == IRNodeType::Broadcast) {
542 op = ((const Broadcast *)op)->value.get();
543 }
544 switch (op->node_type) {
546 return ((const IntImm *)op)->value == (int64_t)v;
548 return ((const UIntImm *)op)->value == (uint64_t)v;
550 return ((const FloatImm *)op)->value == (double)v;
551 default:
552 return false;
553 }
554 }
555
556 template<uint32_t bound>
557 HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
558 return v == val;
559 }
560
561 template<uint32_t bound>
562 HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
563 return v == b.v;
564 }
565
567 Expr make(MatcherState &state, halide_type_t type_hint) const {
568 return make_const(type_hint, v);
569 }
570
571 constexpr static bool foldable = true;
572
574 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
575 // Assume type is already correct
576 switch (ty.code) {
577 case halide_type_int:
578 val.u.i64 = v;
579 break;
580 case halide_type_uint:
581 val.u.u64 = (uint64_t)v;
582 break;
585 val.u.f64 = (double)v;
586 break;
587 default:
588 // Unreachable
589 ;
590 }
591 }
592};
593
595 return t.v;
596}
597
598// Convert a provided pattern, expr, or constant int into the internal
599// representation we use in the matcher trees.
600template<typename T,
601 typename = typename std::decay<T>::type::pattern_tag>
603 return t;
604}
607 return IntLiteral{x};
608}
609
610template<typename T>
612 static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
613 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
614}
615
617 return {*e.get()};
618}
619
620// Helpers to deref SpecificExprs to const BaseExprNode & rather than
621// passing them by value anywhere (incurring lots of refcounting)
622template<typename T,
623 // T must be a pattern node
624 typename = typename std::decay<T>::type::pattern_tag,
625 // But T may not be SpecificExpr
626 typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
628 return t;
629}
630
633 return e.expr;
634}
635
636inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
637 s << op.v;
638 return s;
639}
640
641template<typename Op>
643
644template<typename Op>
646
647template<typename Op>
648double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
649
650constexpr bool commutative(IRNodeType t) {
651 return (t == IRNodeType::Add ||
652 t == IRNodeType::Mul ||
653 t == IRNodeType::And ||
654 t == IRNodeType::Or ||
655 t == IRNodeType::Min ||
656 t == IRNodeType::Max ||
657 t == IRNodeType::EQ ||
658 t == IRNodeType::NE);
659}
660
661// Matches one of the binary operators
662template<typename Op, typename A, typename B>
663struct BinOp {
664 struct pattern_tag {};
665 A a;
666 B b;
667
669
670 constexpr static IRNodeType min_node_type = Op::_node_type;
671 constexpr static IRNodeType max_node_type = Op::_node_type;
672
673 // For commutative bin ops, we expect the weaker IR node type on
674 // the right. That is, for the rule to be canonical it must be
675 // possible that A is at least as strong as B.
676 constexpr static bool canonical =
677 A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
678
679 template<uint32_t bound>
680 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
681 if (e.node_type != Op::_node_type) {
682 return false;
683 }
684 const Op &op = (const Op &)e;
685 return (a.template match<bound>(*op.a.get(), state) &&
686 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
687 }
688
689 template<uint32_t bound, typename Op2, typename A2, typename B2>
690 HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
691 return (std::is_same<Op, Op2>::value &&
692 a.template match<bound>(unwrap(op.a), state) &&
693 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
694 }
695
696 constexpr static bool foldable = A::foldable && B::foldable;
697
699 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
700 halide_scalar_value_t val_a, val_b;
701 if (std::is_same<A, IntLiteral>::value) {
702 b.make_folded_const(val_b, ty, state);
703 if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
704 (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
705 // Short circuit
706 val = val_b;
707 return;
708 }
709 const uint16_t l = ty.lanes;
710 a.make_folded_const(val_a, ty, state);
711 ty.lanes |= l; // Make sure the overflow bits are sticky
712 } else {
713 a.make_folded_const(val_a, ty, state);
714 if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
715 (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
716 // Short circuit
717 val = val_a;
718 return;
719 }
720 const uint16_t l = ty.lanes;
721 b.make_folded_const(val_b, ty, state);
722 ty.lanes |= l;
723 }
724 switch (ty.code) {
725 case halide_type_int:
726 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
727 break;
728 case halide_type_uint:
729 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
730 break;
733 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
734 break;
735 default:
736 // unreachable
737 ;
738 }
739 }
740
742 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
743 Expr ea, eb;
744 if (std::is_same<A, IntLiteral>::value) {
745 eb = b.make(state, type_hint);
746 ea = a.make(state, eb.type());
747 } else {
748 ea = a.make(state, type_hint);
749 eb = b.make(state, ea.type());
750 }
751 // We sometimes mix vectors and scalars in the rewrite rules,
752 // so insert a broadcast if necessary.
753 if (ea.type().is_vector() && !eb.type().is_vector()) {
754 eb = Broadcast::make(eb, ea.type().lanes());
755 }
756 if (eb.type().is_vector() && !ea.type().is_vector()) {
757 ea = Broadcast::make(ea, eb.type().lanes());
758 }
759 return Op::make(std::move(ea), std::move(eb));
760 }
761};
762
763template<typename Op>
765
766template<typename Op>
768
769template<typename Op>
770uint64_t constant_fold_cmp_op(double, double) noexcept;
771
772// Matches one of the comparison operators
773template<typename Op, typename A, typename B>
774struct CmpOp {
775 struct pattern_tag {};
776 A a;
777 B b;
778
780
781 constexpr static IRNodeType min_node_type = Op::_node_type;
782 constexpr static IRNodeType max_node_type = Op::_node_type;
783 constexpr static bool canonical = (A::canonical &&
784 B::canonical &&
785 (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
786 (Op::_node_type != IRNodeType::GE) &&
787 (Op::_node_type != IRNodeType::GT));
788
789 template<uint32_t bound>
790 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
791 if (e.node_type != Op::_node_type) {
792 return false;
793 }
794 const Op &op = (const Op &)e;
795 return (a.template match<bound>(*op.a.get(), state) &&
796 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
797 }
798
799 template<uint32_t bound, typename Op2, typename A2, typename B2>
800 HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
801 return (std::is_same<Op, Op2>::value &&
802 a.template match<bound>(unwrap(op.a), state) &&
803 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
804 }
805
806 constexpr static bool foldable = A::foldable && B::foldable;
807
809 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
810 halide_scalar_value_t val_a, val_b;
811 // If one side is an untyped const, evaluate the other side first to get a type hint.
812 if (std::is_same<A, IntLiteral>::value) {
813 b.make_folded_const(val_b, ty, state);
814 const uint16_t l = ty.lanes;
815 a.make_folded_const(val_a, ty, state);
816 ty.lanes |= l;
817 } else {
818 a.make_folded_const(val_a, ty, state);
819 const uint16_t l = ty.lanes;
820 b.make_folded_const(val_b, ty, state);
821 ty.lanes |= l;
822 }
823 switch (ty.code) {
824 case halide_type_int:
825 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
826 break;
827 case halide_type_uint:
828 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
829 break;
832 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
833 break;
834 default:
835 // unreachable
836 ;
837 }
838 ty.code = halide_type_uint;
839 ty.bits = 1;
840 }
841
843 Expr make(MatcherState &state, halide_type_t type_hint) const {
844 // If one side is an untyped const, evaluate the other side first to get a type hint.
845 Expr ea, eb;
846 if (std::is_same<A, IntLiteral>::value) {
847 eb = b.make(state, {});
848 ea = a.make(state, eb.type());
849 } else {
850 ea = a.make(state, {});
851 eb = b.make(state, ea.type());
852 }
853 // We sometimes mix vectors and scalars in the rewrite rules,
854 // so insert a broadcast if necessary.
855 if (ea.type().is_vector() && !eb.type().is_vector()) {
856 eb = Broadcast::make(eb, ea.type().lanes());
857 }
858 if (eb.type().is_vector() && !ea.type().is_vector()) {
859 ea = Broadcast::make(ea, eb.type().lanes());
860 }
861 return Op::make(std::move(ea), std::move(eb));
862 }
863};
864
865template<typename A, typename B>
866std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
867 s << "(" << op.a << " + " << op.b << ")";
868 return s;
869}
870
871template<typename A, typename B>
872std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
873 s << "(" << op.a << " - " << op.b << ")";
874 return s;
875}
876
877template<typename A, typename B>
878std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
879 s << "(" << op.a << " * " << op.b << ")";
880 return s;
881}
882
883template<typename A, typename B>
884std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
885 s << "(" << op.a << " / " << op.b << ")";
886 return s;
887}
888
889template<typename A, typename B>
890std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
891 s << "(" << op.a << " && " << op.b << ")";
892 return s;
893}
894
895template<typename A, typename B>
896std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
897 s << "(" << op.a << " || " << op.b << ")";
898 return s;
899}
900
901template<typename A, typename B>
902std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
903 s << "min(" << op.a << ", " << op.b << ")";
904 return s;
905}
906
907template<typename A, typename B>
908std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
909 s << "max(" << op.a << ", " << op.b << ")";
910 return s;
911}
912
913template<typename A, typename B>
914std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
915 s << "(" << op.a << " <= " << op.b << ")";
916 return s;
917}
918
919template<typename A, typename B>
920std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
921 s << "(" << op.a << " < " << op.b << ")";
922 return s;
923}
924
925template<typename A, typename B>
926std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
927 s << "(" << op.a << " >= " << op.b << ")";
928 return s;
929}
930
931template<typename A, typename B>
932std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
933 s << "(" << op.a << " > " << op.b << ")";
934 return s;
935}
936
937template<typename A, typename B>
938std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
939 s << "(" << op.a << " == " << op.b << ")";
940 return s;
941}
942
943template<typename A, typename B>
944std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
945 s << "(" << op.a << " != " << op.b << ")";
946 return s;
947}
948
949template<typename A, typename B>
950std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
951 s << "(" << op.a << " % " << op.b << ")";
952 return s;
953}
954
955template<typename A, typename B>
956HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
957 assert_is_lvalue_if_expr<A>();
958 assert_is_lvalue_if_expr<B>();
959 return {pattern_arg(a), pattern_arg(b)};
960}
961
962template<typename A, typename B>
963HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
964 assert_is_lvalue_if_expr<A>();
965 assert_is_lvalue_if_expr<B>();
966 return IRMatcher::operator+(a, b);
967}
968
969template<>
971 t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
972 int dead_bits = 64 - t.bits;
973 // Drop the high bits then sign-extend them back
974 return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
975}
976
977template<>
979 uint64_t ones = (uint64_t)(-1);
980 return (a + b) & (ones >> (64 - t.bits));
981}
982
983template<>
984HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
985 return a + b;
986}
987
988template<typename A, typename B>
989HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
990 assert_is_lvalue_if_expr<A>();
991 assert_is_lvalue_if_expr<B>();
992 return {pattern_arg(a), pattern_arg(b)};
993}
994
995template<typename A, typename B>
996HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
997 assert_is_lvalue_if_expr<A>();
998 assert_is_lvalue_if_expr<B>();
999 return IRMatcher::operator-(a, b);
1000}
1001
1002template<>
1004 t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
1005 // Drop the high bits then sign-extend them back
1006 int dead_bits = 64 - t.bits;
1007 return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
1008}
1009
1010template<>
1012 uint64_t ones = (uint64_t)(-1);
1013 return (a - b) & (ones >> (64 - t.bits));
1014}
1015
1016template<>
1017HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
1018 return a - b;
1019}
1020
1021template<typename A, typename B>
1022HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1023 assert_is_lvalue_if_expr<A>();
1024 assert_is_lvalue_if_expr<B>();
1025 return {pattern_arg(a), pattern_arg(b)};
1026}
1027
1028template<typename A, typename B>
1029HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
1030 assert_is_lvalue_if_expr<A>();
1031 assert_is_lvalue_if_expr<B>();
1032 return IRMatcher::operator*(a, b);
1033}
1034
1035template<>
1037 t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
1038 int dead_bits = 64 - t.bits;
1039 // Drop the high bits then sign-extend them back
1040 return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
1041}
1042
1043template<>
1045 uint64_t ones = (uint64_t)(-1);
1046 return (a * b) & (ones >> (64 - t.bits));
1047}
1048
1049template<>
1050HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1051 return a * b;
1052}
1053
1054template<typename A, typename B>
1055HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1056 assert_is_lvalue_if_expr<A>();
1057 assert_is_lvalue_if_expr<B>();
1058 return {pattern_arg(a), pattern_arg(b)};
1059}
1060
1061template<typename A, typename B>
1062HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1063 return IRMatcher::operator/(a, b);
1064}
1065
1066template<>
1068 return div_imp(a, b);
1069}
1070
1071template<>
1073 return div_imp(a, b);
1074}
1075
1076template<>
1077HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1078 return div_imp(a, b);
1079}
1080
1081template<typename A, typename B>
1082HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1083 assert_is_lvalue_if_expr<A>();
1084 assert_is_lvalue_if_expr<B>();
1085 return {pattern_arg(a), pattern_arg(b)};
1086}
1087
1088template<typename A, typename B>
1089HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1090 assert_is_lvalue_if_expr<A>();
1091 assert_is_lvalue_if_expr<B>();
1092 return IRMatcher::operator%(a, b);
1093}
1094
1095template<>
1097 return mod_imp(a, b);
1098}
1099
1100template<>
1102 return mod_imp(a, b);
1103}
1104
1105template<>
1106HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1107 return mod_imp(a, b);
1108}
1109
1110template<typename A, typename B>
1111HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1112 assert_is_lvalue_if_expr<A>();
1113 assert_is_lvalue_if_expr<B>();
1114 return {pattern_arg(a), pattern_arg(b)};
1115}
1116
1117template<>
1119 return std::min(a, b);
1120}
1121
1122template<>
1124 return std::min(a, b);
1125}
1126
1127template<>
1128HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1129 return std::min(a, b);
1130}
1131
1132template<typename A, typename B>
1133HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1134 assert_is_lvalue_if_expr<A>();
1135 assert_is_lvalue_if_expr<B>();
1136 return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1137}
1138
1139template<>
1141 return std::max(a, b);
1142}
1143
1144template<>
1146 return std::max(a, b);
1147}
1148
1149template<>
1150HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1151 return std::max(a, b);
1152}
1153
1154template<typename A, typename B>
1155HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1156 return {pattern_arg(a), pattern_arg(b)};
1157}
1158
1159template<typename A, typename B>
1160HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1161 return IRMatcher::operator<(a, b);
1162}
1163
1164template<>
1166 return a < b;
1167}
1168
1169template<>
1171 return a < b;
1172}
1173
1174template<>
1176 return a < b;
1177}
1178
1179template<typename A, typename B>
1180HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1181 return {pattern_arg(a), pattern_arg(b)};
1182}
1183
1184template<typename A, typename B>
1185HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1186 return IRMatcher::operator>(a, b);
1187}
1188
1189template<>
1191 return a > b;
1192}
1193
1194template<>
1196 return a > b;
1197}
1198
1199template<>
1201 return a > b;
1202}
1203
1204template<typename A, typename B>
1205HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1206 return {pattern_arg(a), pattern_arg(b)};
1207}
1208
1209template<typename A, typename B>
1210HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1211 return IRMatcher::operator<=(a, b);
1212}
1213
1214template<>
1216 return a <= b;
1217}
1218
1219template<>
1221 return a <= b;
1222}
1223
1224template<>
1226 return a <= b;
1227}
1228
1229template<typename A, typename B>
1230HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1231 return {pattern_arg(a), pattern_arg(b)};
1232}
1233
1234template<typename A, typename B>
1235HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1236 return IRMatcher::operator>=(a, b);
1237}
1238
1239template<>
1241 return a >= b;
1242}
1243
1244template<>
1246 return a >= b;
1247}
1248
1249template<>
1251 return a >= b;
1252}
1253
1254template<typename A, typename B>
1255HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1256 return {pattern_arg(a), pattern_arg(b)};
1257}
1258
1259template<typename A, typename B>
1260HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1261 return IRMatcher::operator==(a, b);
1262}
1263
1264template<>
1266 return a == b;
1267}
1268
1269template<>
1271 return a == b;
1272}
1273
1274template<>
1276 return a == b;
1277}
1278
1279template<typename A, typename B>
1280HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1281 return {pattern_arg(a), pattern_arg(b)};
1282}
1283
1284template<typename A, typename B>
1285HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1286 return IRMatcher::operator!=(a, b);
1287}
1288
1289template<>
1291 return a != b;
1292}
1293
1294template<>
1296 return a != b;
1297}
1298
1299template<>
1301 return a != b;
1302}
1303
1304template<typename A, typename B>
1305HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1306 return {pattern_arg(a), pattern_arg(b)};
1307}
1308
1309template<typename A, typename B>
1310HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1311 return IRMatcher::operator||(a, b);
1312}
1313
1314template<>
1316 return (a | b) & 1;
1317}
1318
1319template<>
1321 return (a | b) & 1;
1322}
1323
1324template<>
1325HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1326 // Unreachable, as it would be a type mismatch.
1327 return 0;
1328}
1329
1330template<typename A, typename B>
1331HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1332 return {pattern_arg(a), pattern_arg(b)};
1333}
1334
1335template<typename A, typename B>
1336HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1337 return IRMatcher::operator&&(a, b);
1338}
1339
1340template<>
1342 return a & b & 1;
1343}
1344
1345template<>
1347 return a & b & 1;
1348}
1349
1350template<>
1351HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1352 // Unreachable
1353 return 0;
1354}
1355
/** Compile-time fold helpers used to combine per-argument pattern traits. */

// Empty OR-reduction: no bits set.
constexpr inline uint32_t bitwise_or_reduce() {
    return 0u;
}

// Bitwise OR of every argument (OR is commutative, so the order of combination is irrelevant).
template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t head, Args... tail) {
    return bitwise_or_reduce(tail...) | head;
}

// Empty AND-reduction: vacuously true.
constexpr inline bool and_reduce() {
    return true;
}

// Logical AND of every argument; stops recursing at the first false.
template<typename... Args>
constexpr bool and_reduce(bool head, Args... tail) {
    return head ? and_reduce(tail...) : false;
}

// Smaller of two ints, usable in constant expressions.
// TODO: this can be replaced with std::min() once we require C++14 or later
constexpr int const_min(int a, int b) {
    return b < a ? b : a;
}
1378
1379template<typename... Args>
1380struct Intrin {
1381 struct pattern_tag {};
1383 std::tuple<Args...> args;
1384
1386
1389 constexpr static bool canonical = and_reduce((Args::canonical)...);
1390
1391 template<int i,
1392 uint32_t bound,
1393 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1394 HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1395 using T = decltype(std::get<i>(args));
1396 return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1397 match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
1398 }
1399
1400 template<int i, uint32_t binds>
1401 HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1402 return true;
1403 }
1404
1405 template<uint32_t bound>
1406 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1407 if (e.node_type != IRNodeType::Call) {
1408 return false;
1409 }
1410 const Call &c = (const Call &)e;
1411 return (c.is_intrinsic(intrin) && match_args<0, bound>(0, c, state));
1412 }
1413
1414 template<int i,
1415 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1416 HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1417 s << std::get<i>(args);
1418 if (i + 1 < sizeof...(Args)) {
1419 s << ", ";
1420 }
1421 print_args<i + 1>(0, s);
1422 }
1423
1424 template<int i>
1425 HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1426 }
1427
1429 void print_args(std::ostream &s) const {
1430 print_args<0>(0, s);
1431 }
1432
1434 Expr make(MatcherState &state, halide_type_t type_hint) const {
1435 Expr arg0 = std::get<0>(args).make(state, type_hint);
1436 if (intrin == Call::likely) {
1437 return likely(arg0);
1438 } else if (intrin == Call::likely_if_innermost) {
1439 return likely_if_innermost(arg0);
1440 } else if (intrin == Call::abs) {
1441 return abs(arg0);
1442 }
1443
1444 Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1445 if (intrin == Call::absd) {
1446 return absd(arg0, arg1);
1447 } else if (intrin == Call::widening_add) {
1448 return widening_add(arg0, arg1);
1449 } else if (intrin == Call::widening_sub) {
1450 return widening_sub(arg0, arg1);
1451 } else if (intrin == Call::widening_mul) {
1452 return widening_mul(arg0, arg1);
1453 } else if (intrin == Call::saturating_add) {
1454 return saturating_add(arg0, arg1);
1455 } else if (intrin == Call::saturating_sub) {
1456 return saturating_sub(arg0, arg1);
1457 } else if (intrin == Call::halving_add) {
1458 return halving_add(arg0, arg1);
1459 } else if (intrin == Call::halving_sub) {
1460 return halving_sub(arg0, arg1);
1461 } else if (intrin == Call::rounding_halving_add) {
1462 return rounding_halving_add(arg0, arg1);
1463 } else if (intrin == Call::rounding_halving_sub) {
1464 return rounding_halving_sub(arg0, arg1);
1465 } else if (intrin == Call::shift_left) {
1466 return arg0 << arg1;
1467 } else if (intrin == Call::shift_right) {
1468 return arg0 >> arg1;
1469 } else if (intrin == Call::rounding_shift_left) {
1470 return rounding_shift_left(arg0, arg1);
1471 } else if (intrin == Call::rounding_shift_right) {
1472 return rounding_shift_right(arg0, arg1);
1473 }
1474
1475 Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1477 return mul_shift_right(arg0, arg1, arg2);
1478 } else if (intrin == Call::rounding_mul_shift_right) {
1479 return rounding_mul_shift_right(arg0, arg1, arg2);
1480 }
1481
1482 internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1483 return Expr();
1484 }
1485
1486 constexpr static bool foldable = false;
1487
1490 : intrin(intrin), args(args...) {
1491 }
1492};
1493
1494template<typename... Args>
1495std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
1496 s << op.intrin << "(";
1497 op.print_args(s);
1498 s << ")";
1499 return s;
1500}
1501
1502template<typename... Args>
1503HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
1504 return {intrinsic_op, pattern_arg(args)...};
1505}
1506
1507template<typename A, typename B>
1508auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1510}
1511template<typename A, typename B>
1512auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1514}
1515template<typename A, typename B>
1516auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1518}
1519template<typename A, typename B>
1520auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1522}
1523template<typename A, typename B>
1524auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1526}
1527template<typename A, typename B>
1528auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1529 return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
1530}
1531template<typename A, typename B>
1532auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1533 return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
1534}
1535template<typename A, typename B>
1536auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1538}
1539template<typename A, typename B>
1540auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1542}
1543template<typename A, typename B>
1544auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1545 return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
1546}
1547template<typename A, typename B>
1548auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1549 return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
1550}
1551template<typename A, typename B>
1552auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1554}
1555template<typename A, typename B>
1556auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1558}
1559template<typename A, typename B, typename C>
1560auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1562}
1563template<typename A, typename B, typename C>
1564auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1566}
1567
1568template<typename A>
1569struct NotOp {
1570 struct pattern_tag {};
1571 A a;
1572
1574
1577 constexpr static bool canonical = A::canonical;
1578
1579 template<uint32_t bound>
1580 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1581 if (e.node_type != IRNodeType::Not) {
1582 return false;
1583 }
1584 const Not &op = (const Not &)e;
1585 return (a.template match<bound>(*op.a.get(), state));
1586 }
1587
1588 template<uint32_t bound, typename A2>
1589 HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1590 return a.template match<bound>(unwrap(op.a), state);
1591 }
1592
1594 Expr make(MatcherState &state, halide_type_t type_hint) const {
1595 return Not::make(a.make(state, type_hint));
1596 }
1597
1598 constexpr static bool foldable = A::foldable;
1599
1600 template<typename A1 = A>
1602 a.make_folded_const(val, ty, state);
1603 val.u.u64 = ~val.u.u64;
1604 val.u.u64 &= 1;
1605 }
1606};
1607
1608template<typename A>
1609HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1610 assert_is_lvalue_if_expr<A>();
1611 return {pattern_arg(a)};
1612}
1613
1614template<typename A>
1616 assert_is_lvalue_if_expr<A>();
1617 return IRMatcher::operator!(a);
1618}
1619
1620template<typename A>
1621inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1622 s << "!(" << op.a << ")";
1623 return s;
1624}
1625
1626template<typename C, typename T, typename F>
1627struct SelectOp {
1628 struct pattern_tag {};
1630 T t;
1631 F f;
1632
1634
1637
1638 constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1639
1640 template<uint32_t bound>
1641 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1642 if (e.node_type != Select::_node_type) {
1643 return false;
1644 }
1645 const Select &op = (const Select &)e;
1646 return (c.template match<bound>(*op.condition.get(), state) &&
1647 t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1648 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1649 }
1650 template<uint32_t bound, typename C2, typename T2, typename F2>
1651 HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1652 return (c.template match<bound>(unwrap(instance.c), state) &&
1653 t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1654 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1655 }
1656
1658 Expr make(MatcherState &state, halide_type_t type_hint) const {
1659 return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1660 }
1661
1662 constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1663
1664 template<typename C1 = C>
1666 halide_scalar_value_t c_val, t_val, f_val;
1667 halide_type_t c_ty;
1668 c.make_folded_const(c_val, c_ty, state);
1669 if ((c_val.u.u64 & 1) == 1) {
1670 t.make_folded_const(val, ty, state);
1671 } else {
1672 f.make_folded_const(val, ty, state);
1673 }
1674 ty.lanes |= c_ty.lanes & MatcherState::special_values_mask;
1675 }
1676};
1677
1678template<typename C, typename T, typename F>
1679std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1680 s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1681 return s;
1682}
1683
1684template<typename C, typename T, typename F>
1685HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1686 assert_is_lvalue_if_expr<C>();
1687 assert_is_lvalue_if_expr<T>();
1688 assert_is_lvalue_if_expr<F>();
1689 return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1690}
1691
1692template<typename A, typename B>
1694 struct pattern_tag {};
1695 A a;
1697
1699
1702
1703 constexpr static bool canonical = A::canonical && B::canonical;
1704
1705 template<uint32_t bound>
1706 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1707 if (e.node_type == Broadcast::_node_type) {
1708 const Broadcast &op = (const Broadcast &)e;
1709 if (a.template match<bound>(*op.value.get(), state) &&
1710 lanes.template match<bound>(op.lanes, state)) {
1711 return true;
1712 }
1713 }
1714 return false;
1715 }
1716
1717 template<uint32_t bound, typename A2, typename B2>
1718 HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1719 return (a.template match<bound>(unwrap(op.a), state) &&
1720 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1721 }
1722
1724 Expr make(MatcherState &state, halide_type_t type_hint) const {
1725 halide_scalar_value_t lanes_val;
1726 halide_type_t ty;
1727 lanes.make_folded_const(lanes_val, ty, state);
1728 int32_t l = (int32_t)lanes_val.u.i64;
1729 type_hint.lanes /= l;
1730 Expr val = a.make(state, type_hint);
1731 if (l == 1) {
1732 return val;
1733 } else {
1734 return Broadcast::make(std::move(val), l);
1735 }
1736 }
1737
1738 constexpr static bool foldable = false;
1739
1740 template<typename A1 = A>
1742 halide_scalar_value_t lanes_val;
1743 halide_type_t lanes_ty;
1744 lanes.make_folded_const(lanes_val, lanes_ty, state);
1745 uint16_t l = (uint16_t)lanes_val.u.i64;
1746 a.make_folded_const(val, ty, state);
1747 ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1748 }
1749};
1750
1751template<typename A, typename B>
1752inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1753 s << "broadcast(" << op.a << ", " << op.lanes << ")";
1754 return s;
1755}
1756
1757template<typename A, typename B>
1758HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1759 assert_is_lvalue_if_expr<A>();
1760 return {pattern_arg(a), pattern_arg(lanes)};
1761}
1762
1763template<typename A, typename B, typename C>
1764struct RampOp {
1765 struct pattern_tag {};
1766 A a;
1767 B b;
1769
1771
1774
1775 constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1776
1777 template<uint32_t bound>
1778 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1779 if (e.node_type != Ramp::_node_type) {
1780 return false;
1781 }
1782 const Ramp &op = (const Ramp &)e;
1783 if (a.template match<bound>(*op.base.get(), state) &&
1784 b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1785 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1786 return true;
1787 } else {
1788 return false;
1789 }
1790 }
1791
1792 template<uint32_t bound, typename A2, typename B2, typename C2>
1793 HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1794 return (a.template match<bound>(unwrap(op.a), state) &&
1795 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1796 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1797 }
1798
1800 Expr make(MatcherState &state, halide_type_t type_hint) const {
1801 halide_scalar_value_t lanes_val;
1802 halide_type_t ty;
1803 lanes.make_folded_const(lanes_val, ty, state);
1804 int32_t l = (int32_t)lanes_val.u.i64;
1805 type_hint.lanes /= l;
1806 Expr ea, eb;
1807 eb = b.make(state, type_hint);
1808 ea = a.make(state, eb.type());
1809 return Ramp::make(ea, eb, l);
1810 }
1811
1812 constexpr static bool foldable = false;
1813};
1814
1815template<typename A, typename B, typename C>
1816std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1817 s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1818 return s;
1819}
1820
1821template<typename A, typename B, typename C>
1822HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1823 assert_is_lvalue_if_expr<A>();
1824 assert_is_lvalue_if_expr<B>();
1825 assert_is_lvalue_if_expr<C>();
1826 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1827}
1828
1829template<typename A, typename B, VectorReduce::Operator reduce_op>
1831 struct pattern_tag {};
1832 A a;
1834
1836
1839 constexpr static bool canonical = A::canonical;
1840
1841 template<uint32_t bound>
1842 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1843 if (e.node_type == VectorReduce::_node_type) {
1844 const VectorReduce &op = (const VectorReduce &)e;
1845 if (op.op == reduce_op &&
1846 a.template match<bound>(*op.value.get(), state) &&
1847 lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1848 return true;
1849 }
1850 }
1851 return false;
1852 }
1853
1854 template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1856 return (reduce_op == reduce_op_2 &&
1857 a.template match<bound>(unwrap(op.a), state) &&
1858 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1859 }
1860
1862 Expr make(MatcherState &state, halide_type_t type_hint) const {
1863 halide_scalar_value_t lanes_val;
1864 halide_type_t ty;
1865 lanes.make_folded_const(lanes_val, ty, state);
1866 int l = (int)lanes_val.u.i64;
1867 return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1868 }
1869
1870 constexpr static bool foldable = false;
1871};
1872
1873template<typename A, typename B, VectorReduce::Operator reduce_op>
1874inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1875 s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1876 return s;
1877}
1878
1879template<typename A, typename B>
1880HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1881 assert_is_lvalue_if_expr<A>();
1882 return {pattern_arg(a), pattern_arg(lanes)};
1883}
1884
1885template<typename A, typename B>
1886HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1887 assert_is_lvalue_if_expr<A>();
1888 return {pattern_arg(a), pattern_arg(lanes)};
1889}
1890
1891template<typename A, typename B>
1892HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1893 assert_is_lvalue_if_expr<A>();
1894 return {pattern_arg(a), pattern_arg(lanes)};
1895}
1896
1897template<typename A, typename B>
1898HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1899 assert_is_lvalue_if_expr<A>();
1900 return {pattern_arg(a), pattern_arg(lanes)};
1901}
1902
1903template<typename A, typename B>
1904HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1905 assert_is_lvalue_if_expr<A>();
1906 return {pattern_arg(a), pattern_arg(lanes)};
1907}
1908
1909template<typename A>
1910struct NegateOp {
1911 struct pattern_tag {};
1912 A a;
1913
1915
1918
1919 constexpr static bool canonical = A::canonical;
1920
1921 template<uint32_t bound>
1922 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1923 if (e.node_type != Sub::_node_type) {
1924 return false;
1925 }
1926 const Sub &op = (const Sub &)e;
1927 return (a.template match<bound>(*op.b.get(), state) &&
1928 is_const_zero(op.a));
1929 }
1930
1931 template<uint32_t bound, typename A2>
1932 HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1933 return a.template match<bound>(unwrap(p.a), state);
1934 }
1935
1937 Expr make(MatcherState &state, halide_type_t type_hint) const {
1938 Expr ea = a.make(state, type_hint);
1939 Expr z = make_zero(ea.type());
1940 return Sub::make(std::move(z), std::move(ea));
1941 }
1942
1943 constexpr static bool foldable = A::foldable;
1944
1945 template<typename A1 = A>
1947 a.make_folded_const(val, ty, state);
1948 int dead_bits = 64 - ty.bits;
1949 switch (ty.code) {
1950 case halide_type_int:
1951 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1952 // Trying to negate the most negative signed int for a no-overflow type.
1954 } else {
1955 // Negate, drop the high bits, and then sign-extend them back
1956 val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
1957 }
1958 break;
1959 case halide_type_uint:
1960 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
1961 break;
1962 case halide_type_float:
1963 case halide_type_bfloat:
1964 val.u.f64 = -val.u.f64;
1965 break;
1966 default:
1967 // unreachable
1968 ;
1969 }
1970 }
1971};
1972
1973template<typename A>
1974std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
1975 s << "-" << op.a;
1976 return s;
1977}
1978
1979template<typename A>
1980HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
1981 assert_is_lvalue_if_expr<A>();
1982 return {pattern_arg(a)};
1983}
1984
1985template<typename A>
1987 assert_is_lvalue_if_expr<A>();
1988 return IRMatcher::operator-(a);
1989}
1990
1991template<typename A>
1992struct CastOp {
1993 struct pattern_tag {};
1995 A a;
1996
1998
2001 constexpr static bool canonical = A::canonical;
2002
2003 template<uint32_t bound>
2004 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2005 if (e.node_type != Cast::_node_type) {
2006 return false;
2007 }
2008 const Cast &op = (const Cast &)e;
2009 return (e.type == t &&
2010 a.template match<bound>(*op.value.get(), state));
2011 }
2012 template<uint32_t bound, typename A2>
2013 HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2014 return t == op.t && a.template match<bound>(unwrap(op.a), state);
2015 }
2016
2018 Expr make(MatcherState &state, halide_type_t type_hint) const {
2019 return cast(t, a.make(state, {}));
2020 }
2021
2022 constexpr static bool foldable = false;
2023};
2024
2025template<typename A>
2026std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2027 s << "cast(" << op.t << ", " << op.a << ")";
2028 return s;
2029}
2030
2031template<typename A>
2032HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2033 assert_is_lvalue_if_expr<A>();
2034 return {t, pattern_arg(a)};
2035}
2036
2037template<typename A>
2038struct Fold {
2039 struct pattern_tag {};
2040 A a;
2041
2043
2046 constexpr static bool canonical = true;
2047
2049 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
2051 halide_type_t ty = type_hint;
2052 a.make_folded_const(c, ty, state);
2053
2054 // The result of the fold may have an underspecified type
2055 // (e.g. because it's from an int literal). Make the type code
2056 // and bits match the required type, if there is one (we can
2057 // tell from the bits field).
2058 if (type_hint.bits) {
2059 if (((int)ty.code == (int)halide_type_int) &&
2060 ((int)type_hint.code == (int)halide_type_float)) {
2061 int64_t x = c.u.i64;
2062 c.u.f64 = (double)x;
2063 }
2064 ty.code = type_hint.code;
2065 ty.bits = type_hint.bits;
2066 }
2067
2068 Expr e = make_const_expr(c, ty);
2069 return e;
2070 }
2071
2072 constexpr static bool foldable = A::foldable;
2073
2074 template<typename A1 = A>
2076 a.make_folded_const(val, ty, state);
2077 }
2078};
2079
2080template<typename A>
2081HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2082 assert_is_lvalue_if_expr<A>();
2083 return {pattern_arg(a)};
2084}
2085
2086template<typename A>
2087std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2088 s << "fold(" << op.a << ")";
2089 return s;
2090}
2091
2092template<typename A>
2094 struct pattern_tag {};
2095 A a;
2096
2098
2099 // This rule is a predicate, so it always evaluates to a boolean,
2100 // which has IRNodeType UIntImm
2103 constexpr static bool canonical = true;
2104
2105 constexpr static bool foldable = A::foldable;
2106
2107 template<typename A1 = A>
2109 a.make_folded_const(val, ty, state);
2110 ty.code = halide_type_uint;
2111 ty.bits = 64;
2112 val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2113 ty.lanes = 1;
2114 }
2115};
2116
2117template<typename A>
2118HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2119 assert_is_lvalue_if_expr<A>();
2120 return {pattern_arg(a)};
2121}
2122
2123template<typename A>
2124std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2125 s << "overflows(" << op.a << ")";
2126 return s;
2127}
2128
2129struct Overflow {
2130 struct pattern_tag {};
2131
2132 constexpr static uint32_t binds = 0;
2133
2134 // Overflow is an intrinsic, represented as a Call node
2137 constexpr static bool canonical = true;
2138
2139 template<uint32_t bound>
2140 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2141 if (e.node_type != Call::_node_type) {
2142 return false;
2143 }
2144 const Call &op = (const Call &)e;
2146 }
2147
2149 Expr make(MatcherState &state, halide_type_t type_hint) const {
2151 return make_const_special_expr(type_hint);
2152 }
2153
2154 constexpr static bool foldable = true;
2155
2158 val.u.u64 = 0;
2160 }
2161};
2162
2163inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2164 s << "overflow()";
2165 return s;
2166}
2167
2168template<typename A>
2169struct IsConst {
2170 struct pattern_tag {};
2171
2173
2174 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2177 constexpr static bool canonical = true;
2178
2179 A a;
2182
2183 constexpr static bool foldable = true;
2184
2185 template<typename A1 = A>
2187 Expr e = a.make(state, {});
2188 ty.code = halide_type_uint;
2189 ty.bits = 64;
2190 ty.lanes = 1;
2191 if (check_v) {
2192 val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2193 } else {
2194 val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2195 }
2196 }
2197};
2198
2199template<typename A>
2200HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2201 assert_is_lvalue_if_expr<A>();
2202 return {pattern_arg(a), false, 0};
2203}
2204
2205template<typename A>
2206HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2207 assert_is_lvalue_if_expr<A>();
2208 return {pattern_arg(a), true, value};
2209}
2210
2211template<typename A>
2212std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2213 if (op.check_v) {
2214 s << "is_const(" << op.a << ")";
2215 } else {
2216 s << "is_const(" << op.a << ", " << op.v << ")";
2217 }
2218 return s;
2219}
2220
2221template<typename A, typename Prover>
2222struct CanProve {
2223 struct pattern_tag {};
2224 A a;
2225 Prover *prover; // An existing simplifying mutator
2226
2228
2229 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2232 constexpr static bool canonical = true;
2233
2234 constexpr static bool foldable = true;
2235
2236 // Includes a raw call to an inlined make method, so don't inline.
2238 Expr condition = a.make(state, {});
2239 condition = prover->mutate(condition, nullptr);
2240 val.u.u64 = is_const_one(condition);
2242 ty.bits = 1;
2243 ty.lanes = condition.type().lanes();
2244 }
2245};
2246
2247template<typename A, typename Prover>
2248HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2249 assert_is_lvalue_if_expr<A>();
2250 return {pattern_arg(a), p};
2251}
2252
2253template<typename A, typename Prover>
2254std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2255 s << "can_prove(" << op.a << ")";
2256 return s;
2257}
2258
/** Boolean-valued predicate pattern: folds to whether the type of the Expr
 * built from `a` is a floating-point type. */
template<typename A>
struct IsFloat {
    struct pattern_tag {};
    A a;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    // NOTE(review): the signature line of the folding method is not shown in this listing.
    // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
    Type t = a.make(state, {}).type();
    val.u.u64 = t.is_float();
    // NOTE(review): a line setting ty.code appears to be missing from this listing.
    ty.bits = 1;
    // The result has as many lanes as the tested expression.
    ty.lanes = t.lanes();
    }
};
2283
2284template<typename A>
2285HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2286 assert_is_lvalue_if_expr<A>();
2287 return {pattern_arg(a)};
2288}
2289
2290template<typename A>
2291std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2292 s << "is_float(" << op.a << ")";
2293 return s;
2294}
2295
/** Boolean-valued predicate pattern: folds to whether the Expr built from `a`
 * has a signed integer type. When bits is non-zero, the type must also have
 * exactly that many bits; bits == 0 accepts any width. */
template<typename A>
struct IsInt {
    struct pattern_tag {};
    A a;
    int bits;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    // NOTE(review): the signature line of the folding method is not shown in this listing.
    // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
    Type t = a.make(state, {}).type();
    val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits);
    // NOTE(review): a line setting ty.code appears to be missing from this listing.
    ty.bits = 1;
    ty.lanes = t.lanes();
    }
};
2321
2322template<typename A>
2323HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2324 assert_is_lvalue_if_expr<A>();
2325 return {pattern_arg(a), bits};
2326}
2327
2328template<typename A>
2329std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2330 s << "is_int(" << op.a;
2331 if (op.bits > 0) {
2332 s << ", " << op.bits;
2333 }
2334 s << ")";
2335 return s;
2336}
2337
/** Boolean-valued predicate pattern: folds to whether the Expr built from `a`
 * has an unsigned integer type. When bits is non-zero, the type must also
 * have exactly that many bits; bits == 0 accepts any width. */
template<typename A>
struct IsUInt {
    struct pattern_tag {};
    A a;
    int bits;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    // NOTE(review): the signature line of the folding method is not shown in this listing.
    // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
    Type t = a.make(state, {}).type();
    val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits);
    // NOTE(review): a line setting ty.code appears to be missing from this listing.
    ty.bits = 1;
    ty.lanes = t.lanes();
    }
};
2363
2364template<typename A>
2365HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2366 assert_is_lvalue_if_expr<A>();
2367 return {pattern_arg(a), bits};
2368}
2369
2370template<typename A>
2371std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2372 s << "is_uint(" << op.a;
2373 if (op.bits > 0) {
2374 s << ", " << op.bits;
2375 }
2376 s << ")";
2377 return s;
2378}
2379
/** Boolean-valued predicate pattern: folds to whether the type of the Expr
 * built from `a` is scalar (Type::is_scalar). */
template<typename A>
struct IsScalar {
    struct pattern_tag {};
    A a;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    // NOTE(review): the signature line of the folding method is not shown in this listing.
    // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
    Type t = a.make(state, {}).type();
    val.u.u64 = t.is_scalar();
    // NOTE(review): a line setting ty.code appears to be missing from this listing.
    ty.bits = 1;
    ty.lanes = t.lanes();
    }
};
2404
2405template<typename A>
2406HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2407 assert_is_lvalue_if_expr<A>();
2408 return {pattern_arg(a)};
2409}
2410
2411template<typename A>
2412std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2413 s << "is_scalar(" << op.a << ")";
2414 return s;
2415}
2416
/** Boolean-valued predicate pattern: constant-folds `a` and yields 1 if the
 * result equals the largest value representable in its integer type, else 0.
 * Non-integer types always yield 0.
 * NOTE(review): the `struct IsMaxValue {` header line (the type returned by
 * is_max_value()) and the folding method's signature are not shown in this
 * listing. */
template<typename A>
    struct pattern_tag {};
    A a;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
    a.make_folded_const(val, ty, state);
    // All-ones shifted right: uint gives 2^bits - 1; int drops one more bit
    // (the sign bit), giving 2^(bits-1) - 1.
    const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
    if (ty.code == halide_type_uint || ty.code == halide_type_int) {
        val.u.u64 = (val.u.u64 == max_bits);
    } else {
        val.u.u64 = 0;
    }
    // NOTE(review): a line setting ty.code appears to be missing from this listing.
    // ty.lanes is not modified here, so it keeps whatever folding `a` produced.
    ty.bits = 1;
    }
};
2445
2446template<typename A>
2447HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2448 assert_is_lvalue_if_expr<A>();
2449 return {pattern_arg(a)};
2450}
2451
2452template<typename A>
2453std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2454 s << "is_max_value(" << op.a << ")";
2455 return s;
2456}
2457
/** Boolean-valued predicate pattern: constant-folds `a` and yields 1 if the
 * result equals the smallest value representable in its integer type.
 * For uint that is 0; non-integer types always yield 0.
 * NOTE(review): the `struct IsMinValue {` header line and the folding
 * method's signature are not shown in this listing. */
template<typename A>
    struct pattern_tag {};
    A a;

    // This rule is a boolean-valued predicate. Bools have type UIntImm.
    constexpr static bool canonical = true;

    constexpr static bool foldable = true;

    // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
    a.make_folded_const(val, ty, state);
    if (ty.code == halide_type_int) {
        // All-ones shifted left by (bits - 1): the minimum of a bits-wide signed
        // integer, sign-extended to 64 bits.
        // NOTE(review): this assumes the folded int value is stored sign-extended in u64.
        const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
        val.u.u64 = (val.u.u64 == min_bits);
    } else if (ty.code == halide_type_uint) {
        val.u.u64 = (val.u.u64 == 0);
    } else {
        val.u.u64 = 0;
    }
    // NOTE(review): a line setting ty.code appears to be missing from this listing.
    ty.bits = 1;
    }
};
2488
2489template<typename A>
2490HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2491 assert_is_lvalue_if_expr<A>();
2492 return {pattern_arg(a)};
2493}
2494
2495template<typename A>
2496std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2497 s << "is_min_value(" << op.a << ")";
2498 return s;
2499}
2500
// Verify properties of each rewrite rule. Currently just fuzz tests them.
// This overload is selected only when both `before` and `after` are
// constant-foldable (see the enable_if). It binds random constants to every
// wildcard slot, evaluates both sides, and logs a failure if they disagree.
template<typename Before,
         typename After,
         typename Predicate,
         typename = typename std::enable_if<std::decay<Before>::type::foldable &&
                                            std::decay<After>::type::foldable>::type>
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
                                        halide_type_t wildcard_type, halide_type_t output_type) noexcept {

    // We only validate the rules in the scalar case
    wildcard_type.lanes = output_type.lanes = 1;

    // Track which types this rule has been tested for before.
    // One set exists per template instantiation (i.e. per rule). It is keyed
    // by the raw bits of the scalar wildcard type and is modified without any
    // locking visible here.
    static std::set<uint32_t> tested;

    if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
        return;
    }

    // Print it in a form where it can be piped into a python/z3 validator
    debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";

    // Substitute some random constants into the before and after
    // expressions and see if the rule holds true. This should catch
    // silly errors, but not necessarily corner cases.
    // Fixed seed: runs are reproducible.
    static std::mt19937_64 rng(0);
    MatcherState state;

    Expr exprs[max_wild];

    for (int trials = 0; trials < 100; trials++) {
        // We want to test small constants more frequently than
        // large ones, otherwise we'll just get coverage of
        // overflow rules.
        // The mask keeps the shift in [0, bits) only when bits is a power of
        // two, which presumably holds for the wildcard types used here.
        int shift = (int)(rng() & (wildcard_type.bits - 1));

        for (int i = 0; i < max_wild; i++) {
            // Bind all the exprs and constants.
            // Each slot gets two independent random values: one as a bound
            // constant (c<i>) and one as a bound Expr (_<i>).
            switch (wildcard_type.code) {
            case halide_type_uint: {
                // Normalize to the type's range by adding zero
                uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
                state.set_bound_const(i, val, wildcard_type);
                val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
                exprs[i] = make_const(wildcard_type, val);
                // Redundant with the set_binding after the switch, but harmless.
                state.set_binding(i, *exprs[i].get());
            } break;
            case halide_type_int: {
                int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
                state.set_bound_const(i, val, wildcard_type);
                val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
                exprs[i] = make_const(wildcard_type, val);
            } break;
            case halide_type_float:
            case halide_type_bfloat: {
                // Use a very narrow range of precise floats, so
                // that none of the rules a human is likely to
                // write have instabilities.
                // Values are multiples of 0.5 in [-4, 3.5].
                double val = ((int64_t)(rng() & 15) - 8) / 2.0;
                state.set_bound_const(i, val, wildcard_type);
                val = ((int64_t)(rng() & 15) - 8) / 2.0;
                exprs[i] = make_const(wildcard_type, val);
            } break;
            default:
                return;  // Don't care about handles
            }
            state.set_binding(i, *exprs[i].get());
        }

        // val_pred is declared but not used below.
        halide_scalar_value_t val_pred, val_before, val_after;
        halide_type_t type = output_type;
        // Skip trials where the rule's predicate does not hold.
        if (!evaluate_predicate(pred, state)) {
            continue;
        }
        before.make_folded_const(val_before, type, state);
        uint16_t lanes = type.lanes;
        after.make_folded_const(val_after, type, state);
        // Accumulate lane bits from both sides; presumably this carries any
        // special-value flags set during folding.
        lanes |= type.lanes;

        // NOTE(review): the condition guarding this `continue` is not present
        // in this listing. Confirm against the real header.
            continue;
        }

        bool ok = true;
        switch (output_type.code) {
        case halide_type_uint:
            // Compare normalized representations
            ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
                   constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
            break;
        case halide_type_int:
            ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
                   constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
            break;
        case halide_type_float:
        case halide_type_bfloat: {
            double error = std::abs(val_before.u.f64 - val_after.u.f64);
            // We accept an equal bit pattern (e.g. inf vs inf),
            // a small floating point difference, or turning a nan into not-a-nan.
            ok &= (error < 0.01 ||
                   val_before.u.u64 == val_after.u.u64 ||
                   std::isnan(val_before.u.f64));
            break;
        }
        default:
            return;
        }

        if (!ok) {
            // Dump every binding and both folded results for diagnosis.
            debug(0) << "Fails with values:\n";
            for (int i = 0; i < max_wild; i++) {
                // NOTE(review): the declaration of `val` is not shown in this listing.
                state.get_bound_const(i, val, wildcard_type);
                debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
            }
            for (int i = 0; i < max_wild; i++) {
                debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
            }
            debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
            debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
            debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
            // NOTE(review): one further line of this branch is not shown in this listing.
        }
    }
}
2626
2627template<typename Before,
2628 typename After,
2629 typename Predicate,
2630 typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2631 std::decay<After>::type::foldable)>::type>
2632HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2633 halide_type_t, halide_type_t, int dummy = 0) noexcept {
2634 // We can't verify rewrite rules that can't be constant-folded.
2635}
2636
2638bool evaluate_predicate(bool x, MatcherState &) noexcept {
2639 return x;
2640}
2641
/** Evaluate a constant-foldable predicate pattern against the current
 * bindings. It succeeds only when the folded value is non-zero and no
 * special-value flags (e.g. overflow) are set in the folded lanes field.
 * NOTE(review): the function signature line and the declaration of `c` are
 * not shown in this listing. */
template<typename Pattern,
         typename = typename enable_if_pattern<Pattern>::type>
    // Start from a bool type; make_folded_const may overwrite ty.
    halide_type_t ty = halide_type_of<bool>();
    p.make_folded_const(c, ty, state);
    // Overflow counts as a failed predicate
    return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
}
2651
// #defines for testing
// These are defined unconditionally to 0, so enabling one means editing this
// header. They gate the #if blocks inside Rewriter::operator().

// Print all successful or failed matches
#define HALIDE_DEBUG_MATCHED_RULES 0
#define HALIDE_DEBUG_UNMATCHED_RULES 0

// Set to true if you want to fuzz test every rewrite passed to
// operator() to ensure the input and the output have the same value
// for lots of random values of the wildcards. Run
// correctness_simplify with this on.
#define HALIDE_FUZZ_TEST_RULES 0
2663
/** Applies rewrite rules to a captured instance.
 * Each operator() overload tries to match `before` against the instance.
 * If the match succeeds (and the predicate, if any, evaluates true), it
 * stores the replacement in `result` and returns true; otherwise it returns
 * false.
 * NOTE(review): the declarations of the state / output_type / wildcard_type /
 * result members, the constructor's signature line, and build_replacement's
 * signature line are not shown in this listing. */
template<typename Instance>
struct Rewriter {
    Instance instance;

        : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
    }

    // Build `after` from the current bindings at the output type and store it in result.
    template<typename After>
        result = after.make(state, output_type);
    }

    /** Rule with a pattern on both sides. All wildcards used by `after` must
     * be bound by `before` (checked at compile time). */
    template<typename Before,
             typename After,
             typename = typename enable_if_pattern<Before>::type,
             typename = typename enable_if_pattern<After>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
        static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
        static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
#if HALIDE_FUZZ_TEST_RULES
        fuzz_test_rule(before, after, true, wildcard_type, output_type);
#endif
        if (before.template match<0>(unwrap(instance), state)) {
            build_replacement(after);
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }

    /** Rule whose replacement is a concrete Expr, used as-is. Not fuzz tested. */
    template<typename Before,
             typename = typename enable_if_pattern<Before>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
        if (before.template match<0>(unwrap(instance), state)) {
            result = after;
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }

    /** Rule whose replacement is an integer literal, made at output_type. */
    template<typename Before,
             typename = typename enable_if_pattern<Before>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
#if HALIDE_FUZZ_TEST_RULES
        fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
#endif
        if (before.template match<0>(unwrap(instance), state)) {
            result = make_const(output_type, after);
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }

    /** Pattern-to-pattern rule guarded by a constant-foldable predicate.
     * The predicate is evaluated only after `before` has matched. */
    template<typename Before,
             typename After,
             typename Predicate,
             typename = typename enable_if_pattern<Before>::type,
             typename = typename enable_if_pattern<After>::type,
             typename = typename enable_if_pattern<Predicate>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
        static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
        static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
        static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
        static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");

#if HALIDE_FUZZ_TEST_RULES
        fuzz_test_rule(before, after, pred, wildcard_type, output_type);
#endif
        if (before.template match<0>(unwrap(instance), state) &&
            evaluate_predicate(pred, state)) {
            build_replacement(after);
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }

    /** Predicated rule whose replacement is a concrete Expr. Not fuzz tested. */
    template<typename Before,
             typename Predicate,
             typename = typename enable_if_pattern<Before>::type,
             typename = typename enable_if_pattern<Predicate>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
        static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");

        if (before.template match<0>(unwrap(instance), state) &&
            evaluate_predicate(pred, state)) {
            result = after;
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }

    /** Predicated rule whose replacement is an integer literal at output_type. */
    template<typename Before,
             typename Predicate,
             typename = typename enable_if_pattern<Before>::type,
             typename = typename enable_if_pattern<Predicate>::type>
    HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
        static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
        static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
#if HALIDE_FUZZ_TEST_RULES
        fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
#endif
        if (before.template match<0>(unwrap(instance), state) &&
            evaluate_predicate(pred, state)) {
            result = make_const(output_type, after);
#if HALIDE_DEBUG_MATCHED_RULES
            debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
#endif
            return true;
        } else {
#if HALIDE_DEBUG_UNMATCHED_RULES
            debug(0) << instance << " does not match " << before << "\n";
#endif
            return false;
        }
    }
};
2825
2826/** Construct a rewriter for the given instance, which may be a pattern
2827 * with concrete expressions as leaves, or just an expression. The
2828 * second optional argument (wildcard_type) is a hint as to what the
2829 * type of the wildcards is likely to be. If omitted it uses the same
2830 * type as the expression itself. They are not required to be this
2831 * type, but the rule will only be tested for wildcards of that type
2832 * when testing is enabled.
2833 *
2834 * The rewriter can be used to check to see if the instance is one of
2835 * some number of patterns and if so rewrite it into another form,
2836 * using its operator() method. See Simplify.cpp for a bunch of
2837 * example usage.
2838 *
2839 * Important: Any Exprs in patterns are captured by reference, not by
2840 * value, so ensure they outlive the rewriter.
2841 */
2842// @{
2843template<typename Instance,
2844 typename = typename enable_if_pattern<Instance>::type>
2845HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
2846 return {pattern_arg(instance), output_type, wildcard_type};
2847}
2848
2849template<typename Instance,
2850 typename = typename enable_if_pattern<Instance>::type>
2851HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
2852 return {pattern_arg(instance), output_type, output_type};
2853}
2854
2856auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
2857 return {pattern_arg(e), e.type(), wildcard_type};
2858}
2859
2861auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
2862 return {pattern_arg(e), e.type(), e.type()};
2863}
2864// @}
2865
2866} // namespace IRMatcher
2867
2868} // namespace Internal
2869} // namespace Halide
2870
2871#endif
#define internal_error
Definition: Errors.h:23
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
Definition: HalideRuntime.h:39
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:38
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
Definition: Debug.h:49
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1552
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1544
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition: IRMatch.h:2845
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition: IRMatch.h:602
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition: IRMatch.h:1310
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:1609
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1111
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2323
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition: IRMatch.h:2638
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1067
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition: IRMatch.h:1285
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition: IRMatch.h:1986
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1205
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:956
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2447
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition: IRMatch.h:229
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition: IRMatch.h:1336
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition: IRMatch.h:1898
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition: IRMatch.h:1185
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2200
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
Definition: IRMatch.h:1503
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1215
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1022
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1536
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1556
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition: IRMatch.h:963
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition: IRMatch.h:1062
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1520
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition: IRMatch.h:1029
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1133
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1822
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1055
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1516
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1096
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1341
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition: IRMatch.h:594
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1180
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:2032
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition: IRMatch.h:2118
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1508
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition: IRMatch.h:611
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1082
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1003
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition: IRMatch.h:2406
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition: IRMatch.h:2081
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition: IRMatch.h:1615
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1528
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1140
constexpr bool and_reduce()
Definition: IRMatch.h:1365
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1305
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1512
constexpr int max_wild
Definition: IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1280
HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept
Definition: IRMatch.h:195
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition: IRMatch.h:2285
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1230
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1155
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1331
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition: IRMatch.h:1904
constexpr bool commutative(IRNodeType t)
Definition: IRMatch.h:650
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition: IRMatch.h:996
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition: IRMatch.h:1892
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition: IRMatch.h:1758
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition: IRMatch.h:1685
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2490
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1118
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition: IRMatch.h:2507
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1190
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1532
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1524
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1036
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2365
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1560
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1548
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1240
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:989
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition: IRMatch.h:1210
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition: IRMatch.h:1160
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2206
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1165
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition: IRMatch.h:1886
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition: IRMatch.h:1880
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1315
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition: IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition: IRMatch.h:1356
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1564
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1265
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition: IRMatch.h:149
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition: IRMatch.h:1235
constexpr int const_min(int a, int b)
Definition: IRMatch.h:1375
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1290
auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1540
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1089
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1255
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:970
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition: IRMatch.h:2248
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition: IRMatch.h:1260
T div_imp(T a, T b)
Definition: IROperator.h:260
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
constexpr IRNodeType StrongestExprNodeType
Definition: Expr.h:79
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition: IROperator.h:239
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have `lanes` lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition: Expr.h:25
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
@ C
No name mangling.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition: Expr.h:256
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:320
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:314
The sum of two expressions.
Definition: IR.h:38
Logical and - are both expressions true.
Definition: IR.h:157
A base class for expression nodes.
Definition: Expr.h:141
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:241
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition: IR.h:247
A function call.
Definition: IR.h:466
@ signed_integer_overflow
Definition: IR.h:554
@ rounding_mul_shift_right
Definition: IR.h:545
bool is_intrinsic() const
Definition: IR.h:646
static const IRNodeType _node_type
Definition: IR.h:691
The actual IR nodes begin here.
Definition: IR.h:29
static const IRNodeType _node_type
Definition: IR.h:34
The ratio of two expressions.
Definition: IR.h:65
Is the first expression equal to the second.
Definition: IR.h:103
Floating point constants.
Definition: Expr.h:234
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Definition: IR.h:148
Is the first expression greater than the second.
Definition: IR.h:139
constexpr static uint32_t binds
Definition: IRMatch.h:668
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:671
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:699
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:680
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:670
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:742
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:690
constexpr static bool canonical
Definition: IRMatch.h:676
constexpr static bool foldable
Definition: IRMatch.h:696
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1724
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1718
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1706
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1700
constexpr static uint32_t binds
Definition: IRMatch.h:1698
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1741
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1701
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2237
constexpr static bool foldable
Definition: IRMatch.h:2234
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2230
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2231
constexpr static uint32_t binds
Definition: IRMatch.h:2227
constexpr static bool canonical
Definition: IRMatch.h:2232
constexpr static bool canonical
Definition: IRMatch.h:2001
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2000
constexpr static bool foldable
Definition: IRMatch.h:2022
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2004
constexpr static uint32_t binds
Definition: IRMatch.h:1997
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1999
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:2013
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2018
constexpr static bool canonical
Definition: IRMatch.h:783
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:843
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:781
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:782
constexpr static bool foldable
Definition: IRMatch.h:806
constexpr static uint32_t binds
Definition: IRMatch.h:779
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:790
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:809
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:800
constexpr static bool foldable
Definition: IRMatch.h:2072
constexpr static uint32_t binds
Definition: IRMatch.h:2042
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2044
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2045
constexpr static bool canonical
Definition: IRMatch.h:2046
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:2049
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2075
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:539
constexpr static bool canonical
Definition: IRMatch.h:531
constexpr static uint32_t binds
Definition: IRMatch.h:527
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition: IRMatch.h:534
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition: IRMatch.h:562
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:574
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:529
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:530
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:567
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition: IRMatch.h:557
constexpr static bool foldable
Definition: IRMatch.h:571
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1401
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1434
constexpr static bool canonical
Definition: IRMatch.h:1389
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition: IRMatch.h:1429
constexpr static bool foldable
Definition: IRMatch.h:1486
static constexpr uint32_t binds
Definition: IRMatch.h:1385
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1394
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition: IRMatch.h:1416
std::tuple< Args... > args
Definition: IRMatch.h:1383
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1406
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition: IRMatch.h:1425
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
Definition: IRMatch.h:1489
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1388
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1387
constexpr static bool canonical
Definition: IRMatch.h:2177
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2186
constexpr static bool foldable
Definition: IRMatch.h:2183
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2176
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2175
constexpr static uint32_t binds
Definition: IRMatch.h:2172
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2267
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2274
constexpr static bool canonical
Definition: IRMatch.h:2269
constexpr static uint32_t binds
Definition: IRMatch.h:2264
constexpr static bool foldable
Definition: IRMatch.h:2271
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2268
constexpr static uint32_t binds
Definition: IRMatch.h:2302
constexpr static bool foldable
Definition: IRMatch.h:2309
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2305
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2312
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2306
constexpr static bool canonical
Definition: IRMatch.h:2307
constexpr static bool canonical
Definition: IRMatch.h:2427
constexpr static bool foldable
Definition: IRMatch.h:2429
constexpr static uint32_t binds
Definition: IRMatch.h:2422
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2425
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2426
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2432
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2466
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2467
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2473
constexpr static bool canonical
Definition: IRMatch.h:2468
constexpr static uint32_t binds
Definition: IRMatch.h:2463
constexpr static bool foldable
Definition: IRMatch.h:2470
constexpr static bool foldable
Definition: IRMatch.h:2392
constexpr static bool canonical
Definition: IRMatch.h:2390
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2395
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2389
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2388
constexpr static uint32_t binds
Definition: IRMatch.h:2385
constexpr static bool canonical
Definition: IRMatch.h:2349
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2354
constexpr static uint32_t binds
Definition: IRMatch.h:2344
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2347
constexpr static bool foldable
Definition: IRMatch.h:2351
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2348
To save stack space, the matcher objects are largely stateless and immutable.
Definition: IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition: IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition: IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition: IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition: IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition: IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition: IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition: IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition: IRMatch.h:134
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition: IRMatch.h:98
halide_scalar_value_t bound_const[max_wild]
Definition: IRMatch.h:84
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition: IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition: IRMatch.h:87
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1916
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1917
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1922
constexpr static uint32_t binds
Definition: IRMatch.h:1914
constexpr static bool canonical
Definition: IRMatch.h:1919
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1937
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition: IRMatch.h:1932
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1946
constexpr static bool foldable
Definition: IRMatch.h:1943
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1575
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1580
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1589
constexpr static uint32_t binds
Definition: IRMatch.h:1573
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1576
constexpr static bool foldable
Definition: IRMatch.h:1598
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1594
constexpr static bool canonical
Definition: IRMatch.h:1577
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1601
constexpr static bool canonical
Definition: IRMatch.h:2137
constexpr static bool foldable
Definition: IRMatch.h:2154
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2140
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2149
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2136
constexpr static uint32_t binds
Definition: IRMatch.h:2132
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2157
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2135
constexpr static bool foldable
Definition: IRMatch.h:2105
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2108
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2102
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2101
constexpr static uint32_t binds
Definition: IRMatch.h:2097
constexpr static bool canonical
Definition: IRMatch.h:2103
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1800
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1773
constexpr static bool canonical
Definition: IRMatch.h:1775
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1772
constexpr static bool foldable
Definition: IRMatch.h:1812
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1793
constexpr static uint32_t binds
Definition: IRMatch.h:1770
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1778
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition: IRMatch.h:2678
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition: IRMatch.h:2752
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition: IRMatch.h:2727
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition: IRMatch.h:2673
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition: IRMatch.h:2781
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition: IRMatch.h:2709
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition: IRMatch.h:2804
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition: IRMatch.h:2686
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1636
constexpr static bool canonical
Definition: IRMatch.h:1638
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1665
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1635
constexpr static uint32_t binds
Definition: IRMatch.h:1633
constexpr static bool foldable
Definition: IRMatch.h:1662
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition: IRMatch.h:1651
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1641
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1658
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:210
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:217
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:222
constexpr static uint32_t binds
Definition: IRMatch.h:207
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:211
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1855
constexpr static uint32_t binds
Definition: IRMatch.h:1835
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1838
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1842
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1837
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1862
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:364
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:360
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:385
constexpr static uint32_t binds
Definition: IRMatch.h:357
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:359
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:395
constexpr static uint32_t binds
Definition: IRMatch.h:411
constexpr static bool foldable
Definition: IRMatch.h:450
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:443
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:414
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:418
constexpr static bool canonical
Definition: IRMatch.h:415
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:453
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:413
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition: IRMatch.h:437
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:240
constexpr static uint32_t binds
Definition: IRMatch.h:238
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:279
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:289
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:245
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition: IRMatch.h:266
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:241
constexpr static uint32_t binds
Definition: IRMatch.h:304
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:307
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:306
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:311
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:342
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:332
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:472
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:491
constexpr static bool foldable
Definition: IRMatch.h:489
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:485
constexpr static bool canonical
Definition: IRMatch.h:473
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:471
constexpr static uint32_t binds
Definition: IRMatch.h:469
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:476
constexpr static uint32_t mask
Definition: IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:111
Integer constants.
Definition: Expr.h:216
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Definition: IR.h:130
Is the first expression less than the second.
Definition: IR.h:121
The greater of two values.
Definition: IR.h:94
The lesser of two values.
Definition: IR.h:85
The remainder of a / b.
Definition: IR.h:76
The product of two expressions.
Definition: IR.h:56
Is the first expression not equal to the second.
Definition: IR.h:112
Logical not - true if the expression is false.
Definition: IR.h:175
static Expr make(Expr a)
Logical or - is at least one of the expressions true.
Definition: IR.h:166
A linear ramp vector node.
Definition: IR.h:229
static const IRNodeType _node_type
Definition: IR.h:235
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition: IR.h:186
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition: IR.h:191
The difference of two expressions.
Definition: IR.h:47
static const IRNodeType _node_type
Definition: IR.h:52
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition: Expr.h:225
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition: IR.h:885
static const IRNodeType _node_type
Definition: IR.h:904
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition: Type.h:265
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition: Type.h:402
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition: Type.h:333
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition: Type.h:408
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition: Type.h:327
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
Definition: Type.h:377
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition: Type.h:384
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition: Type.h:390
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@3 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.