//          Copyright (C) Hal Finkel 2009.
// Distributed under the Boost Software License, Version 1.0.
//    (See accompanying file LICENSE_1_0.txt or copy at
//          http://www.boost.org/LICENSE_1_0.txt)

#include <complex>
#include <cstddef>
#include <ctime>
#include <iostream>
#include <sstream>

#include <boost/array.hpp>
#include <boost/typeof/typeof.hpp>
#include <boost/mpl/assert.hpp>
#include <boost/mpl/int.hpp>
#include <boost/proto/core.hpp>
#include <boost/proto/context.hpp>

namespace mpl = boost::mpl;
namespace proto = boost::proto;
using proto::_;

struct lazy_complex_domain;

// Proto evaluation context that computes a single component of a lazy complex
// expression.  Subscript 0 selects the real part and subscript 1 the imaginary
// part (the multiplies/divides cases below build rctx(0) / ictx(1) contexts to
// fetch the real and imaginary parts of their operands).
//
// Size: integral type used to store the requested component index.
template<typename Size = std::size_t>
struct lazy_subscript_context
{
    // subscript: index of the component this context evaluates.
    lazy_subscript_context(Size subscript)
      : subscript_(subscript)
    {}

    // Fallback for every node without a specialization below (e.g. plus and
    // minus): proto's default evaluation applies the C++ operator to the
    // children, each evaluated with this same context.  Complex addition and
    // subtraction act component-wise, so this yields the correct component.
    template<typename Expr, typename Tag = typename Expr::proto_tag>
    struct eval
      : proto::default_eval< Expr, lazy_subscript_context >
    {};

    // Leaf: index the terminal's stored value (a boost::array<T, 2> for
    // lazy_complex) with subscript_.  The result type is that container's
    // value_type.  No bounds check is performed on subscript_.
    template<typename Expr>
    struct eval<Expr, proto::tag::terminal>
    {
        typedef typename proto::result_of::value<Expr>::type::value_type result_type;

        result_type operator ()( Expr const & expr, lazy_subscript_context & ctx ) const
        {
            return proto::value( expr )[ ctx.subscript_ ];
        }
    };

    // Product: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
    // Both components of both operands are evaluated regardless of which
    // component was requested.  For nested products/quotients the total work
    // therefore grows exponentially with the nesting depth of * and / nodes.
    template<typename Expr>
    struct eval<Expr, proto::tag::multiplies>
    {
    private:
        // Declared but never defined: used only inside BOOST_TYPEOF (an
        // unevaluated context) to name the type of the left operand's component.
        static Expr &s_expr;
        static lazy_subscript_context &s_ctx;

    public:
        // The result has the same type as one component of the left operand.
        typedef BOOST_TYPEOF(proto::eval(proto::left(s_expr), s_ctx)) result_type;

        result_type operator()(Expr const & expr, lazy_subscript_context & ctx ) const
        {
            // a + bi is the left operand, c + di the right operand.
            lazy_subscript_context rctx(0), ictx(1);
            BOOST_AUTO(a, proto::eval(proto::left(expr), rctx));
            BOOST_AUTO(b, proto::eval(proto::left(expr), ictx));
            BOOST_AUTO(c, proto::eval(proto::right(expr), rctx));
            BOOST_AUTO(d, proto::eval(proto::right(expr), ictx));
            
            // Subscript 0 -> real part; ANY nonzero subscript -> imaginary part.
            return !ctx.subscript_ ?
                a * c
              - b * d :
                a * d
              + b * c;
        }
    };

    // Quotient: (a + bi)/(c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2).
    // No check for a zero denominator: the result follows the arithmetic of the
    // component type.
    // NOTE(review): this is the textbook formula; c*c + d*d can overflow or
    // underflow for very large/small |c|, |d| (a scaled algorithm such as
    // Smith's avoids that) -- confirm whether that matters for callers.
    template<typename Expr>
    struct eval<Expr, proto::tag::divides>
    {
    private:
        // Type-deduction helpers only; see the multiplies case above.
        static Expr &s_expr;
        static lazy_subscript_context &s_ctx;

    public:
        typedef BOOST_TYPEOF(proto::eval(proto::left(s_expr), s_ctx)) result_type;

        result_type operator()(Expr const & expr, lazy_subscript_context & ctx ) const
        {
            // a + bi is the dividend, c + di the divisor.
            lazy_subscript_context rctx(0), ictx(1);
            BOOST_AUTO(a, proto::eval(proto::left(expr), rctx));
            BOOST_AUTO(b, proto::eval(proto::left(expr), ictx));
            BOOST_AUTO(c, proto::eval(proto::right(expr), rctx));
            BOOST_AUTO(d, proto::eval(proto::right(expr), ictx));
            BOOST_AUTO(den, c*c + d*d);
            
            // Subscript 0 -> real part; ANY nonzero subscript -> imaginary part.
            return (!ctx.subscript_ ?
                a * c
              + b * d :
                b * c
              - a * d)/den;
        }
    };

    // Index of the component (0 = real, 1 = imaginary) this context evaluates.
    Size subscript_;
};

// Wrapper that lazy_complex_domain's generator puts around every expression
// node built in the domain.  It adds component access on top of proto's
// expression interface: indexing (or real()/imag()) evaluates the whole
// expression tree for a single component with lazy_subscript_context.
template<typename Expr>
struct lazy_complex_expr
  : proto::extends<Expr, lazy_complex_expr<Expr>, lazy_complex_domain>
{
    typedef proto::extends<Expr, lazy_complex_expr<Expr>, lazy_complex_domain> base_type;

    // Wraps `expr`; with no argument the wrapped expression is Expr().
    lazy_complex_expr( Expr const & expr = Expr() )
      : base_type( expr )
    {}

    // Evaluates component `subscript` of this expression (0 = real part,
    // 1 = imaginary part).  This member template hides the subscript operator
    // proto::extends would otherwise provide.
    template< typename Size >
    typename proto::result_of::eval< Expr, lazy_subscript_context<Size> >::type
    operator []( Size subscript ) const
    {
        lazy_subscript_context<Size> component_ctx(subscript);
        return proto::eval(*this, component_ctx);
    }

    // Real part: component 0, evaluated through an int-indexed context.
    typename proto::result_of::eval< Expr, lazy_subscript_context<int> >::type real() const
    {
        return (*this)[0];
    }

    // Imaginary part: component 1, evaluated through an int-indexed context.
    typename proto::result_of::eval< Expr, lazy_subscript_context<int> >::type imag() const
    {
        return (*this)[1];
    }
};

// Concrete lazy complex number: a proto terminal whose value is a
// boost::array<T, 2> holding the real part at index 0 and the imaginary part
// at index 1.  Arithmetic on lazy_complex values builds proto expression
// trees; they are evaluated component by component when indexed or when
// applied through a compound assignment.
template< typename T >
struct lazy_complex
  : lazy_complex_expr< typename proto::terminal< boost::array<T, 2> >::type >
{
    typedef T value_type;
    typedef typename proto::terminal< boost::array<T, 2> >::type expr_type;

    // Constructs re + im*i; both parts default to T().
    lazy_complex( T const & re = T(), T const & im = T() )
    {
        proto::value(*this)[0] = re;
        proto::value(*this)[1] = im;
    }

    // Converting constructor from std::complex<T>.
    lazy_complex( std::complex<T> const & n )
    {
        proto::value(*this)[0] = n.real();
        proto::value(*this)[1] = n.imag();
    }

    // Stored real part (hides the expression-evaluating lazy_complex_expr::real).
    T real() const
    {
        return proto::value(*this)[0];
    }

    // Stored imaginary part (hides lazy_complex_expr::imag).
    T imag() const
    {
        return proto::value(*this)[1];
    }

    // Compound assignments.  `expr` is first wrapped with
    // proto::as_expr<lazy_complex_domain>, then evaluated component-wise by the
    // private helpers below.  `expr` may refer to *this (e.g. v += v * w).
    template< typename Expr >
    lazy_complex &operator += (Expr const & expr)
    {
        return this->plus_assign(proto::as_expr< lazy_complex_domain >(expr));
    }

    template< typename Expr >
    lazy_complex &operator -= (Expr const & expr)
    {
        return this->minus_assign(proto::as_expr< lazy_complex_domain >(expr));
    }

    template< typename Expr >
    lazy_complex &operator *= (Expr const & expr)
    {
        return this->multiplies_assign(proto::as_expr< lazy_complex_domain >(expr));
    }

    template< typename Expr >
    lazy_complex &operator /= (Expr const & expr)
    {
        return this->divides_assign(proto::as_expr< lazy_complex_domain >(expr));
    }

private:
    // Every helper reads BOTH components of `expr` before writing either
    // component of *this.  `expr` may reference *this, and the imaginary part
    // of a product/quotient depends on the real parts of its operands.
    // Updating component 0 first would therefore corrupt component 1
    // (e.g. v += v * w).

    // *this += expr
    template< typename Expr >
    lazy_complex &plus_assign(Expr const & expr)
    {
        T const re = expr[0];
        T const im = expr[1];

        proto::value(*this)[0] += re;
        proto::value(*this)[1] += im;
        return *this;
    }

    // *this -= expr
    template< typename Expr >
    lazy_complex &minus_assign(Expr const & expr)
    {
        T const re = expr[0];
        T const im = expr[1];

        proto::value(*this)[0] -= re;
        proto::value(*this)[1] -= im;
        return *this;
    }

    // *this *= expr, using (a + bi)(c + di) = (ac - bd) + (bc + ad)i.
    template< typename Expr >
    lazy_complex &multiplies_assign(Expr const & expr)
    {
        T a = proto::value(*this)[0], b = proto::value(*this)[1],
          c = expr[0],                d = expr[1];

        proto::value(*this)[0] = a*c - b*d;
        proto::value(*this)[1] = b*c + a*d;
        return *this;
    }

    // *this /= expr, using the textbook formula
    // (a + bi)/(c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2).
    // No zero-denominator check; results follow T's arithmetic.
    template< typename Expr >
    lazy_complex &divides_assign(Expr const & expr)
    {
        T a = proto::value(*this)[0], b = proto::value(*this)[1],
          c = expr[0],                d = expr[1];
        T den = c*c + d*d;

        proto::value(*this)[0] = (a*c + b*d)/den;
        proto::value(*this)[1] = (b*c - a*d)/den;
        return *this;
    }
};

// Grammar of expressions accepted by lazy_complex_domain (it is the domain's
// grammar parameter).  An expression matches if it is one of the alternatives
// below, recursively.
// Neither scalar terminals (e.g. a plain double) nor unary operators (e.g.
// negation) appear among the alternatives, so expressions such as v * 2.0 or
// -v do not match this grammar.
struct LazyComplexGrammar
  : proto::or_<
        // A fully generic proto::terminal< _ > alternative is left disabled:
        // proto::terminal< _ >
        // Exact-type alternative for lazy_complex<double> terminals.
        // NOTE(review): looks redundant with the next alternative, which already
        // matches terminals holding boost::array<double, 2>; confirm before removing.
        lazy_complex< /* _ */ double >
      , proto::terminal< boost::array< _, 2 > >                      // any terminal holding a 2-element boost::array
      , proto::plus< LazyComplexGrammar, LazyComplexGrammar >        // lhs + rhs
      , proto::minus< LazyComplexGrammar, LazyComplexGrammar >       // lhs - rhs
      , proto::multiplies< LazyComplexGrammar, LazyComplexGrammar >  // lhs * rhs
      , proto::divides< LazyComplexGrammar, LazyComplexGrammar >     // lhs / rhs
    >
{};

// Proto domain for lazy complex arithmetic.  Its generator wraps every
// expression built in the domain in lazy_complex_expr (which provides
// component access via operator[], real() and imag()).  Its grammar is
// LazyComplexGrammar.
struct lazy_complex_domain
  : proto::domain<proto::generator<lazy_complex_expr>, LazyComplexGrammar>
{};

// Writes n as "(re,im)", following the same scheme std::complex's inserter
// uses.  The text is first built in a temporary stream that copies o's format
// flags, locale and precision.  It is then inserted into o as one string, so a
// field width set on o (e.g. via std::setw) pads the whole "(re,im)" instead
// of only the opening parenthesis.  Without a field width the output is
// unchanged.
template <typename T, class charT, class traits>
std::basic_ostream<charT, traits>& operator << (std::basic_ostream<charT, traits> &o, lazy_complex<T> const &n) {
    std::basic_ostringstream<charT, traits> buf;
    buf.flags(o.flags());
    buf.imbue(o.getloc());
    buf.precision(o.precision());
    buf << '(' << n.real() << ',' << n.imag() << ')';
    return o << buf.str();
}

// Micro-benchmark for a complex-number type CT.  It runs 2^24 iterations of
// mixed +, * and / expressions and accumulates scaled real parts into `acc`
// (presumably so the optimizer cannot discard the work).  It then prints:
// the checksum, the processor time spent in the loop (in std::clock ticks,
// i.e. divide by CLOCKS_PER_SEC for seconds), and v1.
//
// CT must be constructible from two doubles and from two ints, support
// binary + * /, expose real(), and be insertable into std::cout.
template <typename CT>
void speed_test() {
    CT v1( 4.0, 1.0 ), v2( 4.0, 2.0 ), v3( 4.0, 3.0 );

    // <ctime> only guarantees these names in namespace std.
    std::clock_t const start = std::clock();

    double acc = 0;
    for (int i = 0; i < (1 << 24); ++i) {
        CT v4(i+1, -(i+1));
        acc += (
            ((v1+v2+v3+v4)/v2 + (v3+v4)/v3)*((v4+v4+v2+v1)/v4)
        ).real()/1e16;

        acc += ((v1 + v2 + v3 + v4)/v2 + v3).real()/1e16;
        acc += ((v2 * v4)/v4).real()/1e12;
        acc += (v3 / v4).real()/1e2;
    }

    // Sample the clock immediately after the loop so that the timing does not
    // include the output (and flush) below.
    std::clock_t const elapsed = std::clock() - start;

    std::cout << acc << std::endl;
    std::cout << "time: " << elapsed << std::endl;
    std::cout << v1 << std::endl;
}

int main()
{
    {
        BOOST_MPL_ASSERT(( proto::matches< lazy_complex<double>, proto::terminal< boost::array<double, 2> > > ));

        const lazy_complex<double> x;
        BOOST_MPL_ASSERT(( proto::matches< BOOST_TYPEOF(x+x), proto::plus< proto::terminal< boost::array<double, 2> >, proto::terminal< boost::array<double, 2> > > > ));
    }

    speed_test< lazy_complex<double> >();
    speed_test< std::complex<double> >();

    lazy_complex<double> v1( 4.0, 1.0 ), v2( 4.0, 2.0 ), v3( 4.0, 3.0 );

#if 0
    double d1 = ( v2 + v3 )[ 1 ];
    std::cout << d1 << std::endl;
    double d2 = ( v1 * v2 )[ 0 ];
    std::cout << d2 << std::endl;

    v1 += v2 - v3;
    std::cout << '{' << v1[0] << ',' << v1[1]
               << '}' << std::endl;
    std::cout << v1 << std::endl;
#endif

    return 0;
}


