Expression Templates

Inefficient object-oriented code

Object orientation sometimes sacrifices efficiency for readability and ease of code maintenance:


Array a, b, c, d, e;

a = b + c + d + e;

This will generate the following pseudo-code:


// pseudocode for a = b + c + d + e
//
// Each binary '+' stores its result in a freshly allocated temporary of
// N elements, so the single statement costs three allocations and four
// loops over the data.

double* _t1 = new double[N];
for ( int i=0; i<N; ++i)
    _t1[i] = b[i] + c[i];

double* _t2 = new double[N];    // same length as every operand
for ( int i=0; i<N; ++i)
    _t2[i] = _t1[i] + d[i];

double* _t3 = new double[N];    // same length as every operand
for ( int i=0; i<N; ++i)
    _t3[i] = _t2[i] + e[i];

// copy the final temporary into the destination
for ( int i=0; i<N; ++i)
    a[i] = _t3[i];

delete [] _t3;
delete [] _t2;
delete [] _t1;

Performance problems

- For small arrays, new and delete result in poor performance: about 1/10 of the speed of C.

- For medium arrays, the overhead of the extra loops and memory accesses adds about 50%.

- For large arrays, the cost of the temporaries is the limiting factor: according to Veldhuizen, performance can drop to 1/7 or even 1/27 of the C or Fortran versions.

Recursive Templates


// Empty class template with two type parameters. Because an X can itself
// be passed as Left or Right, instantiations nest recursively, as the
// list and tree declarations below show.
template <class Left, class Right>
class X { };

X<A, X<B, X<C, X<D, END> > > >   a;     // list
X< X<A,B>, X<C,D> >     a;      // tree

Using recursive templates we are able to build a parse tree:


Array A, B, C, D;

D = A + B +C ;


A + B + C

X< X<Array, plus, Array>, plus, Array >


struct plus { };    // tag type: represents the addition operation
class  Array { };   // operand type: appears as a leaf of the parse tree


// Interior node of the parse tree: operation Op applied to a Left operand
// and a Right operand. In this sketch only the type carries information.
template <typename Left, typename Op, typename Right>
class X { };


// Adding an Array to any left operand T (an Array or an already built
// X node) yields a value whose type records the operation and both
// operand types.
template <class T>
X<T, plus, Array> operator+( T, Array)
{
    return X< T, plus, Array>();
}

Array A, B, C, D;

D = A + B + C;

  = X<Array,plus,Array>() + C;

  = X< X<Array,plus,Array>, plus, Array>();

Minimal Implementation


#include <iostream>

using namespace std;

// Operation policy for "+": used as the Op argument of X, which calls
// the static apply() for every element.
struct plus
{
    // Returns the sum of the two element values.
    static double apply( double lhs, double rhs)
    {
        const double sum = lhs + rhs;
        return sum;
    }
};


// Parse-tree node for "left Op right". The node keeps copies of both
// operands and computes nothing until an element is requested.
template <typename Left, typename Op, typename Right>
struct X
{
    // Stores copies of the two operands.
    X( Left lhs, Right rhs)
        : left(lhs),
          right(rhs)
    {
    }

    // Element i of the expression: Op applied to element i of each operand.
    double operator[](int i)
    {
        return Op::apply( left[i], right[i] );
    }

    Left    left;
    Right   right;
};


// Array over caller-supplied storage: it keeps the pointer and the
// element count it was given; the class defines no destructor.
struct Array
{
    double *data_;
    int N_;

    // Remembers the buffer and its element count.
    Array( double *data, int N)
        : data_(data),
          N_(N)
    {
    }

    // Evaluates the expression element by element and stores each
    // result in this array: one loop over N_ elements.
    template <typename Left, typename Op, typename Right>
    void operator=( X<Left,Op,Right> expr)
    {
        int i = 0;
        while ( i < N_ )
        {
            data_[i] = expr[i];
            ++i;
        }
    }

    // Returns element i by value.
    double operator[](int i)
    {
        return data_[i];
    }
};


// Builds the parse-tree node for "a + b" without computing anything:
// the result only stores copies of both operands. Left is either an
// Array or a node produced by an earlier '+'.
//
// The operation tag is written as ::plus on purpose: because of the
// `using namespace std;` above, an unqualified `plus` becomes ambiguous
// with std::plus as soon as a standard header declares it.
template <typename Left>
X<Left, ::plus, Array> operator+( Left a, Array b)
{
    return X<Left, ::plus, Array>(a,b);
}


int main()
{
    // operand storage; the result buffer is filled by the assignment below
    double first[]  = { 2, 3, 5, 9 };
    double second[] = { 1, 0, 0, 1 };
    double third[]  = { 3, 0, 2, 5 };
    double sum[4];

    Array A(first, 4), B(second, 4), C(third, 4);
    Array D(sum, 4);

    // builds the expression object, then evaluates it in Array::operator=
    D = A + B + C;

    // print the four results separated by blanks
    int i = 0;
    while ( i < 4 )
    {
        std::cout << D[i] << " ";
        ++i;
    }
    std::cout << std::endl;
}

What happens at compile time?


    D = A + B + C;

      = X<Array,plus,Array>(A,B) + C;

      = X< X<Array,plus,Array>, plus, Array>( X<Array,plus,Array>(A,B), C);

then it matches template Array::operator=


D.operator=(X<X<Array,plus,Array>,plus,Array>(X<Array,plus,Array>(A,B),C) expr)
{
    for ( int i = 0; i < N_; ++i)
        data_[i] = expr[i];
}

The expr[i] is expanded by inlining operator[] of each node of the parse tree:


data_[i] = plus::apply( X<Array,plus,Array>(A,B)[i], C[i]);

         = plus::apply( plus::apply( A[i], B[i]), C[i]);

         = A[i] + B[i] + C[i];

So the final result of D = A + B + C is:


for ( int i = 0; i < N_; ++i)
    D.data_[i] = A.data_[i] + B.data_[i] + C.data_[i];

No temporaries, and a single loop!

Do not write an expression-template implementation yourself, except for fun. There are several good implementations, such as:

- Blitz++ http://www.oonumerics.org/blitz

- PETE http://www.acl.lanl.gov/pete

Expression Templates in Java

Just when you thought your little language was safe: Expression Templates in Java - Todd Veldhuizen, Erfurt 2000


W = X + Y * Z


public class DoArrayStuff
{
    /**
     * Hand-written loop computing w[i] = x[i] + y[i] * z[i] for every
     * index of w. The arrays x, y and z are indexed up to w.length - 1.
     */
    public static void apply( float[] w, float[] x, float[] y, float[] z)
    {
        for ( int i = 0; i < w.length; ++i)
            w[i] = x[i] + y[i] * z[i];
    }
}
// or:
public static apply( Array w, Array x, Array y, Array z)
{
    w = x + y * z;
}

JavaTran == Java+Fortran


int n = 1000;

Array w = new Array(n);
Array x = new Array(n);
Array y = new Array(n);
Array z = new Array(n);

w = x + y * z;


// since we do not have operator overloading:

w.assign(x.plus(y.times(z)));

/*

        Expr                    BinaryOperator
        /  \                        /    \
       /    \                      /      \
   Array  BinaryOpExpr           Plus   Times

*/

//  y * z
Expr e = new BinaryOpExpr( y, z, new Times());


/**
 * Base class of every node in the expression tree. Concrete nodes
 * provide eval(); plus() and times() combine nodes into larger trees.
 */
public abstract class Expr
{
    /** Value of this expression at index i. */
    public abstract float eval(int i);

    /** Returns a new node representing (this + b). */
    public Expr plus(Expr b)
    {
        return new BinaryOpExpr(this, b, new Plus());
    }

    /** Returns a new node representing (this * b). */
    public Expr times(Expr b)
    {
        return new BinaryOpExpr(this, b, new Times());
    }
}

/**
 * Interior node of the expression tree: a binary operator together
 * with its two operand expressions.
 */
public class BinaryOpExpr extends Expr
{
    Expr a;
    Expr b;
    BinaryOperator op;

    /** Stores the two operands and the operator; nothing is evaluated here. */
    public BinaryOpExpr( Expr a_, Expr b_, BinaryOperator op_)
    {
        this.a  = a_;
        this.b  = b_;
        this.op = op_;
    }

    /** Evaluates both operands at index i and combines them with op. */
    public float eval(int i)
    {
        float left  = a.eval(i);
        float right = b.eval(i);
        return op.apply(left, right);
    }
}

/** A binary operation on two float values. */
public abstract class BinaryOperator
{
    /** Combines a and b and returns the result. */
    public abstract float apply(float a, float b);
}

public class Plus extends BinaryOperator
{
        public float apply(float a, float b) { return a+b; }
}
public class Times extends BinaryOperator
{
        public float apply(float a, float b) { return a*b; }
}


/**
 * Leaf of the expression tree: a float array whose eval(i) returns
 * element i.
 */
public class Array extends Expr
{
    float data[];
    int length;

    /** Creates an array holding n floats. */
    public Array(int n)
    {
        this.data   = new float[n];
        this.length = n;
    }

    /** Returns element i. */
    public float eval(int i)
    {
        return this.data[i];
    }

    /** Stores val at index i. */
    public void set(int i, float val)
    {
        this.data[i] = val;
    }

    /** Evaluates e at every index and stores the results in this array. */
    public void assign(Expr e)
    {
        int i = 0;
        while ( i < length )
        {
            data[i] = e.eval(i);
            ++i;
        }
    }
}

public class Test
{
    public static void main(java.lang.String[] args)
    {
        int n = 10000;

        Array w = new Array(n);
        Array x = new Array(n);
        Array y = new Array(n);
        Array z = new Array(n);

        // initialize x, y, z

        // w = x + y + z
        w.assign(x.plus(y.times(z)));
}

The problem is: there is no partial evaluation in Java compilers


w = x + y * z

To evaluate each e.eval(i)

- 6 virtual function calls

- 3 bound checks

- numerous pointer indirections

dead slow


Lunar compiler:

Java -> intermediate form -> partial evaluation -> C code


JVM                         JavaTran        ExpressionTemplates
==================          ========        ===================
Lunar                       33.1            2.4
Sun Hotspot 1.3 (JIT)        1.4            0.4
Transvirtual Kaffe           4.3            0.7

Mflops/s  n=1000

Bubble Sort

Here we present another expression template: bubble sort. Bubble sort is very inefficient for large N, but quite reasonable for small N.


// Exchanges the values of a and b.
inline void swap( int& a, int& b)
{
    const int old_b = b;
    b = a;
    a = old_b;
}
//
// bubble sort is very inefficient for large N,
// but quite reasonable for small N
//
// Sorts data[0] .. data[N-1] into ascending order in place.
// Each outer pass bubbles the largest remaining element up to index i.
void bubbleSort( int *data, int N)
{
    for ( int i = N-1; i > 0; --i)
        for ( int j = 0; j < i; ++j)
            if ( data[j] > data[j+1] )
                swap( data[j], data[j+1]);
}
//
// the inline version for N=3
//
// Fully unrolled bubble sort of data[0], data[1], data[2]:
// three compare-and-exchange steps on the pairs (0,1), (1,2), (0,1).
inline void bubbleSort3( int *data)
{
    // first pass: moves the largest value to data[2]
    if ( data[1] < data[0] )
    {
        const int held = data[1];
        data[1] = data[0];
        data[0] = held;
    }
    if ( data[2] < data[1] )
    {
        const int held = data[2];
        data[2] = data[1];
        data[1] = held;
    }

    // second pass: orders the remaining two values
    if ( data[1] < data[0] )
    {
        const int held = data[1];
        data[1] = data[0];
        data[0] = held;
    }
}
//
// we had two loops. To reduce the loops define bubble sort recursively
//
// One bubbling pass over data[0] .. data[N-1] moves the largest element
// to the end; the first N-1 elements are then handled by the recursive
// call. The recursion stops once N is 2 or less.
void bubbleSort( int *data, int N)
{
    const int last = N - 1;

    int j = 0;
    while ( j < last )
    {
        if ( data[j+1] < data[j] )
            swap( data[j], data[j+1]);
        ++j;
    }

    if ( N <= 2 )
        return;

    bubbleSort( data, last);
}
//
// Now the sort consists of a loop and a recursive call to itself
// this is simple to implement with recursive templates
//
// Declared here because IntBubbleSort::sort refers to it; the full
// definition appears further down.
template <int I, int J>
class IntBubbleSortLoop;

// Bubble sort of data[0] .. data[N-1] expressed as a recursive template:
// sort() runs one pass via IntBubbleSortLoop<N-1,0>, then delegates the
// rest to IntBubbleSort<N-1>.
template <int N>
struct IntBubbleSort
{
    static inline void sort(int *data)
    {
        IntBubbleSortLoop<N-1,0>::loop(data);
        IntBubbleSort<N-1>::sort(data);
    }
};

// End of the recursion: sorting a single element does nothing.
template <>
struct IntBubbleSort<1>
{
    static inline void sort(int *data) { }
};

//
//  IntBubbleSortLoop<N-1,0>::loop(data) will replace the for loop in j
//  and then makes a recursive call to itself
//
//  For N=4 this will be the effect:
//
static inline void IntBubbleSort<4>::sort(int *data)
{
    IntBubbleSortLoop<3,0>::loop(data);
    IntBubbleSortLoop<2,0>::loop(data);
    IntBubbleSortLoop<1,0>::loop(data);
}

//
//  the first template argument 3, 2, 1 plays the role of i in the original version
//  the second is equivalent to j
//
// Declared here because IntBubbleSortLoop::loop refers to it; the full
// definition appears further down.
template <int I, int J>
class IntSwap;

// One step of the inner loop: compares and swaps data[J] and data[J+1],
// then continues with J+1 for as long as J <= I-2.
template <int I, int J>
class IntBubbleSortLoop
{
private:
    // true while another step of this pass remains
    enum { go = ( J <= I-2 ) };
public:
    static inline void loop(int *data)
    {
        IntSwap<J, J+1>::compareAndSwap(data);
        // when go is false both arguments become 0, which selects the
        // empty <0,0> specialization below and ends the recursion
        IntBubbleSortLoop< go ? I : 0, go ? (J+1) : 0 >::loop(data);
    }
};

// End of the recursion: does nothing.
template <>
class IntBubbleSortLoop<0,0>
{
public:
    static inline void loop(int *)  { }
};
//
// writing the terminal case of the recursion is a bit more difficult,
// because we have two template parameters. The solution: both should
// revert to 0, when the base case is reached.
// We use a loop flag (go), and when go is false, we set parameters to 0
//
// For N=4 here is the expansion:
//
static inline void IntBubbleSort<4>::sort(int *data)
{
    IntSwap<0,1>::compareAndSwap(data);
    IntSwap<1,2>::compareAndSwap(data);
    IntSwap<2,3>::compareAndSwap(data);
    IntSwap<0,1>::compareAndSwap(data);
    IntSwap<1,2>::compareAndSwap(data);
    IntSwap<0,1>::compareAndSwap(data);
}
//
// the last remaining definition is IntSwap<I,J>::compareAndSwap(..)
//
// Compare-and-exchange on two indices fixed at compile time:
// swaps data[I] and data[J] when data[I] is the larger one.
template <int I, int J>
class IntSwap
{
public:
    static inline void compareAndSwap(int *data)
    {
        if ( data[I] > data[J] )
            swap( data[I], data[J]);
    }
};
//
// the swap() is the original
//
//
// For N=4 here is the expansion:
//
static inline void IntBubbleSort<4>::sort(int *data)
{
    if ( data[0] > data[1]) swap( data[0], data[1]);
    if ( data[1] > data[2]) swap( data[1], data[2]);
    if ( data[2] > data[3]) swap( data[2], data[3]);
    if ( data[0] > data[1]) swap( data[0], data[1]);
    if ( data[1] > data[2]) swap( data[1], data[2]);
    if ( data[0] > data[1]) swap( data[0], data[1]);
}