In this tutorial we will learn how to build, compose, and transform Iteration/Expression Trees (IETs).

# Part II - Bottom Up

`Dimensions` are the building blocks of both `Iterations` and `Expressions`.

``````python
from devito import SpaceDimension, TimeDimension

dims = {'i': SpaceDimension(name='i'),
        'j': SpaceDimension(name='j'),
        'k': SpaceDimension(name='k'),
        't0': TimeDimension(name='t0'),
        't1': TimeDimension(name='t1')}

dims
``````
``{'i': i, 'j': j, 'k': k, 't0': t0, 't1': t1}``

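Symbolic objects such as `Scalar`, `Constant`, `Array`, `Function` and `TimeFunction` are used to build SymPy equations. Calling `.indexify()` on an array-like object yields an indexed access such as `c[i]`.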

``````python
from devito import Grid, Constant, Function, TimeFunction
from devito.types import Array, Scalar

grid = Grid(shape=(10, 10))
symbs = {'a': Scalar(name='a'),
         'b': Constant(name='b'),
         'c': Array(name='c', shape=(3,), dimensions=(dims['i'],)).indexify(),
         'd': Array(name='d',
                    shape=(3, 3),
                    dimensions=(dims['j'], dims['k'])).indexify(),
         'e': Function(name='e',
                       shape=(3, 3, 3),
                       dimensions=(dims['t0'], dims['t1'], dims['i'])).indexify(),
         'f': TimeFunction(name='f', grid=grid).indexify()}
symbs
``````
``{'a': a, 'b': b, 'c': c[i], 'd': d[j, k], 'e': e[t0, t1, i], 'f': f[t, x, y]}``

An IET `Expression` wraps a SymPy equation. Below, `DummyEq` is a subclass of `sympy.Eq` with some metadata attached. Which metadata are attached, and when and how, is irrelevant here.

``````python
from devito.ir.iet import Expression
from devito.ir.equations import DummyEq
from devito.tools import pprint

def get_exprs(a, b, c, d, e, f):
    return [Expression(DummyEq(a, b + c + 5.)),
            Expression(DummyEq(d, e - f)),
            Expression(DummyEq(a, 4 * (b * a))),
            Expression(DummyEq(a, (6. / b) + (8. * a)))]

exprs = get_exprs(symbs['a'],
                  symbs['b'],
                  symbs['c'],
                  symbs['d'],
                  symbs['e'],
                  symbs['f'])

pprint(exprs)
``````
``````
<Expression a = b + c[i] + 5.0>
<Expression d[j, k] = e[t0, t1, i] - f[t, x, y]>
<Expression a = 4*b*a>
<Expression a = 8.0*a + 6.0/b>
``````

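An `Iteration` wraps one or more nodes, typically `Expression`s or other `Iteration`s. Each `Iteration` below is given a `Dimension` and a `(min, max, step)` limits tuple. As the generated C code later shows, the upper bound is inclusive (`i <= 3`).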

``````python
from devito.ir.iet import Iteration

# Each entry is a factory: given a body, it returns an Iteration over one Dimension
def get_iters(dims):
    return [lambda ex: Iteration(ex, dims['i'], (0, 3, 1)),
            lambda ex: Iteration(ex, dims['j'], (0, 5, 1)),
            lambda ex: Iteration(ex, dims['k'], (0, 7, 1)),
            lambda ex: Iteration(ex, dims['t0'], (0, 4, 1)),
            lambda ex: Iteration(ex, dims['t1'], (0, 4, 1))]

iters = get_iters(dims)
``````
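The following pure-Python sketch makes the `(min, max, step)` convention concrete. It does not use Devito. It renders a limits tuple as the C loop header that appears in the generated kernels, and it counts the iterations. The helper names `loop_header` and `trip_count` are hypothetical. The inclusive upper bound is inferred from the generated code (`i <= 3`).

``````python
def loop_header(name, limits):
    """Render a C for-loop header from a (min, max, step) tuple.

    Hypothetical helper: mimics the loop headers in the generated
    kernels, where the upper bound is inclusive (<=).
    """
    lo, hi, step = limits
    return 'for (int %s = %d; %s <= %d; %s += %d)' % (name, lo, name, hi, name, step)


def trip_count(limits):
    """Number of iterations of an inclusive (min, max, step) range."""
    lo, hi, step = limits
    if hi < lo:
        return 0
    return (hi - lo) // step + 1


print(loop_header('i', (0, 3, 1)))  # for (int i = 0; i <= 3; i += 1)
print(trip_count((0, 3, 1)))        # 4
print(trip_count((0, 7, 2)))        # 4  (k = 0, 2, 4, 6)
``````

Because the upper bound is inclusive, `(0, 3, 1)` runs four times rather than three.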

Nesting these `Iteration`s around `Expression`s builds loop nests. A nest is perfect when every statement sits in the innermost loop, and non-perfect (imperfect) otherwise.

``````python
def get_block1(exprs, iters):
    # Perfect loop nest:
    # for i
    #   for j
    #     for k
    #       expr0
    return iters[0](iters[1](iters[2](exprs[0])))

def get_block2(exprs, iters):
    # Non-perfect simple loop nest:
    # for i
    #   expr0
    #   for j
    #     for k
    #       expr1
    return iters[0]([exprs[0], iters[1](iters[2](exprs[1]))])

def get_block3(exprs, iters):
    # Non-perfect non-trivial loop nest:
    # for i
    #   for t0
    #     expr0
    #   for j
    #     for k
    #       expr1
    #       expr2
    #   for t1
    #     expr3
    return iters[0]([iters[3](exprs[0]),
                     iters[1](iters[2]([exprs[1], exprs[2]])),
                     iters[4](exprs[3])])

block1 = get_block1(exprs, iters)
block2 = get_block2(exprs, iters)
block3 = get_block3(exprs, iters)

pprint(block1)
print()
pprint(block2)
print()
pprint(block3)
``````
``````
<Iteration i::i::(0, 3, 1)>
  <Iteration j::j::(0, 5, 1)>
    <Iteration k::k::(0, 7, 1)>
      <Expression a = b + c[i] + 5.0>

<Iteration i::i::(0, 3, 1)>
  <Expression a = b + c[i] + 5.0>
  <Iteration j::j::(0, 5, 1)>
    <Iteration k::k::(0, 7, 1)>
      <Expression d[j, k] = e[t0, t1, i] - f[t, x, y]>

<Iteration i::i::(0, 3, 1)>
  <Iteration t0::t0::(0, 4, 1)>
    <Expression a = b + c[i] + 5.0>
  <Iteration j::j::(0, 5, 1)>
    <Iteration k::k::(0, 7, 1)>
      <Expression d[j, k] = e[t0, t1, i] - f[t, x, y]>
      <Expression a = 4*b*a>
  <Iteration t1::t1::(0, 4, 1)>
    <Expression a = 8.0*a + 6.0/b>
``````
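One way to read these trees is to ask how many times each expression executes. The sketch below is a hypothetical pure-Python model, not Devito's classes:

- a loop is a `(name, (min, max, step), body)` tuple;
- an expression is a string;
- a list holds sibling nodes.

A recursive walk multiplies together the inclusive trip counts of the enclosing loops. The toy tree has the same shape as `block3`.

``````python
def trip_count(limits):
    """Iterations of an inclusive (min, max, step) range (0 if empty)."""
    lo, hi, step = limits
    return max(0, (hi - lo) // step + 1)


def exec_counts(node, outer=1, counts=None):
    """Map each expression to how many times it runs inside `node`."""
    if counts is None:
        counts = {}
    if isinstance(node, str):            # leaf: an expression
        counts[node] = counts.get(node, 0) + outer
    elif isinstance(node, list):         # sequence of sibling nodes
        for child in node:
            exec_counts(child, outer, counts)
    else:                                # (dim name, limits, body): a loop
        _, limits, body = node
        exec_counts(body, outer * trip_count(limits), counts)
    return counts


# Hypothetical toy tree with the same shape as block3 above
toy_block3 = ('i', (0, 3, 1), [
    ('t0', (0, 4, 1), 'expr0'),
    ('j', (0, 5, 1), ('k', (0, 7, 1), ['expr1', 'expr2'])),
    ('t1', (0, 4, 1), 'expr3'),
])

print(exec_counts(toy_block3))
# {'expr0': 20, 'expr1': 192, 'expr2': 192, 'expr3': 20}
``````

In the non-perfect nest, `expr0` and `expr3` sit under `i` and one short loop each (4 × 5 = 20 executions). `expr1` and `expr2` sit under `i`, `j` and `k` (4 × 6 × 8 = 192 executions).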

Finally, we can wrap each loop nest in a `Callable`, a kernel from which C code can be generated. Note that `Operator` is a subclass of `Callable`.

``````python
from devito.ir.iet import Callable

# Each kernel: name 'foo', a loop nest as body, 'void' return type, no parameters
kernels = [Callable('foo', block1, 'void', ()),
           Callable('foo', block2, 'void', ()),
           Callable('foo', block3, 'void', ())]

print('kernel no.1:\n' + str(kernels[0].ccode) + '\n')
print('kernel no.2:\n' + str(kernels[1].ccode) + '\n')
print('kernel no.3:\n' + str(kernels[2].ccode) + '\n')
``````
``````
kernel no.1:
void foo()
{
  for (int i = 0; i <= 3; i += 1)
  {
    for (int j = 0; j <= 5; j += 1)
    {
      for (int k = 0; k <= 7; k += 1)
      {
        a = b + c[i] + 5.0F;
      }
    }
  }
}

kernel no.2:
void foo()
{
  for (int i = 0; i <= 3; i += 1)
  {
    a = b + c[i] + 5.0F;

    for (int j = 0; j <= 5; j += 1)
    {
      for (int k = 0; k <= 7; k += 1)
      {
        d[j][k] = e[t0][t1][i] - f[t][x][y];
      }
    }
  }
}

kernel no.3:
void foo()
{
  for (int i = 0; i <= 3; i += 1)
  {
    for (int t0 = 0; t0 <= 4; t0 += 1)
    {
      a = b + c[i] + 5.0F;
    }
    for (int j = 0; j <= 5; j += 1)
    {
      for (int k = 0; k <= 7; k += 1)
      {
        d[j][k] = e[t0][t1][i] - f[t][x][y];
        a = 4*b*a;
      }
    }
    for (int t1 = 0; t1 <= 4; t1 += 1)
    {
      a = 8.0F*a + 6.0F/b;
    }
  }
}

``````
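A tiny recursive printer sketches how a nested tree turns into the C shown above. This is a hypothetical pure-Python model. It is not how Devito produces `ccode`, since Devito builds `cgen` objects instead. In the model, loops are `(name, (min, max, step), body)` tuples and statements are strings. Each nesting level adds two spaces of indentation, which reproduces the layout of kernel no.1.

``````python
def to_c(node, depth=0):
    """Return the C lines for a toy loop-nest node, indented by `depth` levels."""
    pad = '  ' * depth
    if isinstance(node, str):                   # a statement
        return [pad + node + ';']
    if isinstance(node, list):                  # sibling nodes
        lines = []
        for child in node:
            lines.extend(to_c(child, depth))
        return lines
    name, (lo, hi, step), body = node           # a loop
    header = '%sfor (int %s = %d; %s <= %d; %s += %d)' % (
        pad, name, lo, name, hi, name, step)
    return [header, pad + '{'] + to_c(body, depth + 1) + [pad + '}']


def kernel(name, body):
    """Wrap a body in a `void name()` function, mimicking the kernels above."""
    return '\n'.join(['void %s()' % name, '{'] + to_c(body, 1) + ['}'])


# Hypothetical toy tree with the same shape as block1
toy_block1 = ('i', (0, 3, 1),
              ('j', (0, 5, 1),
               ('k', (0, 7, 1), 'a = b + c[i] + 5.0F')))
print(kernel('foo', toy_block1))
``````

The recursion mirrors the tree. Each loop emits its header and braces and then asks its body to print itself one level deeper.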

An IET is immutable. It can be “transformed” by replacing or dropping some of its inner nodes, but what actually happens is that a new IET is created. IETs are transformed by `Transformer` visitors. A `Transformer` takes as input a dictionary that maps existing nodes to their replacements.

``````python
from devito.ir.iet import Transformer

# Replace the body of a Callable (block1) with another loop nest (block2)
transformer = Transformer({block1: block2})
kernel_alt = transformer.visit(kernels[0])
print(kernel_alt)
``````
``````
void foo()
{
  for (int i = 0; i <= 3; i += 1)
  {
    a = b + c[i] + 5.0F;

    for (int j = 0; j <= 5; j += 1)
    {
      for (int k = 0; k <= 7; k += 1)
      {
        d[j][k] = e[t0][t1][i] - f[t][x][y];
      }
    }
  }
}
``````
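The rebuild-instead-of-mutate idea behind `Transformer` fits in a few lines of plain Python. This is an illustration under stated assumptions, not Devito's implementation:

- nodes are immutable `namedtuple`s compared by value (Devito's matching rules may differ);
- a `{node: replacement}` mapping drives the rewrite;
- mapping a node to `None` drops it;
- replacements are not visited again.

`ToyTransformer` and the `Expr`/`Loop` types are hypothetical.

``````python
from collections import namedtuple

# Hypothetical immutable node types (namedtuples cannot be mutated in place)
Expr = namedtuple('Expr', ['text'])
Loop = namedtuple('Loop', ['dim', 'limits', 'body'])  # body: tuple of nodes


class ToyTransformer:
    """Rebuild a tree, applying a {node: replacement} mapping.

    A replacement of None drops the node; the input tree is never modified.
    """

    def __init__(self, mapper):
        self.mapper = mapper

    def visit(self, node):
        if node in self.mapper:
            return self.mapper[node]              # replacements are not revisited
        if isinstance(node, Loop):
            children = (self.visit(child) for child in node.body)
            new_body = tuple(n for n in children if n is not None)
            return node._replace(body=new_body)   # a brand-new Loop object
        return node


e0 = Expr('a = b + c[i] + 5.0')
e1 = Expr('d[j][k] = e[t0][t1][i] - f[t][x][y]')
tree = Loop('i', (0, 3, 1), (e0, Loop('j', (0, 5, 1), (e1,))))

swapped = ToyTransformer({e0: e1}).visit(tree)
dropped = ToyTransformer({e0: None}).visit(tree)

print(tree.body[0].text)     # a = b + c[i] + 5.0   (original unchanged)
print(swapped.body[0].text)  # d[j][k] = e[t0][t1][i] - f[t][x][y]
print(len(dropped.body))     # 1
``````

Because every `visit` returns a new object instead of editing one in place, `tree` stays valid and can be transformed again with a different mapping.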

Specific `Expression`s within the loop nest can also be substituted.

``````python
# Replace an expression with another
transformer = Transformer({exprs[0]: exprs[1]})
newblock = transformer.visit(block1)
newcode = str(newblock.ccode)
print(newcode)
``````
``````
for (int i = 0; i <= 3; i += 1)
{
  for (int j = 0; j <= 5; j += 1)
  {
    for (int k = 0; k <= 7; k += 1)
    {
      d[j][k] = e[t0][t1][i] - f[t][x][y];
    }
  }
}
``````
An `Expression` can also be replaced by an arbitrary node. Here it is replaced by a `Block` built from a raw `cgen` line.

``````python
from devito.ir.iet import Block
import cgen as c

# Replace an expression with a Block built from a raw C comment line
line1 = '// Replaced expression'
replacer = Block(c.Line(line1))
transformer = Transformer({exprs[1]: replacer})
newblock = transformer.visit(block2)
newcode = str(newblock.ccode)
print(newcode)
``````
``````
for (int i = 0; i <= 3; i += 1)
{
  a = b + c[i] + 5.0F;

  for (int j = 0; j <= 5; j += 1)
  {
    for (int k = 0; k <= 7; k += 1)
    {
      // Replaced expression
      {
      }
    }
  }
}
``````
The same mechanism can wrap a node instead of replacing it. Here `exprs[0]` is placed between two comment lines.

``````python
# Wrap an expression between an opening and a closing comment
line1 = '// This is the opening comment'
line2 = '// This is the closing comment'
wrapper = lambda n: Block(c.Line(line1), n, c.Line(line2))
transformer = Transformer({exprs[0]: wrapper(exprs[0])})
newblock = transformer.visit(block1)
newcode = str(newblock.ccode)
print(newcode)
``````
``````
for (int i = 0; i <= 3; i += 1)
{
  for (int j = 0; j <= 5; j += 1)
  {
    for (int k = 0; k <= 7; k += 1)
    {
      // This is the opening comment
      {
        a = b + c[i] + 5.0F;
      }
      // This is the closing comment
    }
  }
}
``````
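The two outputs above suggest how a `Block` is laid out. Lines given before the body are printed ahead of a brace-delimited body, and lines given after it follow the closing brace. An empty body still produces a pair of braces. The hypothetical pure-Python helper `render_block` below models that layout. It is inferred from the printed code and is not part of Devito, and it reproduces both results.

``````python
def render_block(header=(), body=(), footer=(), indent='  '):
    """Lay out header lines, a braced body, then footer lines.

    Hypothetical helper modelled on the generated code above; an empty
    body still produces a pair of braces.
    """
    lines = list(header) + ['{']
    lines += [indent + line for line in body]
    lines += ['}'] + list(footer)
    return '\n'.join(lines)


print(render_block(header=['// Replaced expression']))
print(render_block(header=['// This is the opening comment'],
                   body=['a = b + c[i] + 5.0F;'],
                   footer=['// This is the closing comment']))
``````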