Multiple dispatch over covariant functions

Today we have another guest post by Alfredo Correa. In his previous article, Functions of variants are covariant, Alfredo described the concept of a covariant function, that is, how a function return type can depend on the type of input parameters at runtime. In this article he shows how the concept and implementation can be generalized to covariant functions of multiple arguments, effectively achieving runtime multiple dispatch. A prototype implementation is available at the [Boost.]Covariant library.

Introduction

A covariant function is a particular kind of function that maps a sum type into the same or another sum type, while at the same time inducing a (one-to-one or many-to-one) map between the individual input and output alternative types. A covariant function can be seen as the “sum” of many individual functions applying to specific types (an overload set in C++). Although sum types are well represented in C++ by the std::variant template types, there is currently no facility provided in the standard to represent a covariant function. The standard std::visit allows selecting between different functions based on the runtime type held by the std::variant input but, since the return type is fixed, the runtime options cannot propagate through function applications.
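To make this limitation concrete, here is a minimal sketch of what std::visit requires when the alternatives of an overload set return different types: the caller has to spell out the result variant by hand. The overload helper is the common C++17 idiom, assumed here in place of the one from the previous article, and add_one is a hypothetical name.

```cpp
#include <cassert>
#include <variant>

// Common C++17 idiom: inherit the call operators of several lambdas
template<class... Fs> struct overload : Fs... { using Fs::operator()...; };
template<class... Fs> overload(Fs...) -> overload<Fs...>;

// Hypothetical example: the result variant type must be written out manually,
// because std::visit requires every branch to return the same type
std::variant<int, double> add_one(std::variant<int, double> const& v){
    auto f = overload{
        [](int i){ return i + 1; },     // returns int
        [](double d){ return d + 1; }   // returns double
    };
    return std::visit(
        [&](auto x) -> std::variant<int, double> { return f(x); },
        v
    );
}
```

This bookkeeping is what a covariant function is meant to automate, by deducing the result variant from the overload set itself.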

In the previous article we saw how, by implementing a few smart function objects and some metaprogramming trickery, an overload set (a function that can act on many types statically) can be converted naturally into a covariant function, allowing variant type information to propagate at runtime in a functional form. In this example, the runtime type information contained in v effectively propagates into the result w:

std::variant<int, double> v = 1.2; // a *runtime* double
auto w = covariant(
  [](int    i){return i + 1;},
  [](double d){return d + 1;}
)(v);
// w is of type std::variant<int, double> but happens to contain a *runtime* double
assert( std::get<double>(w) == 1.2 + 1 );

The concept of a covariant function can be generalized to multiple arguments. In fact, the real power of covariant functions is fully achieved when multiple arguments participate in the determination of the return type. This is sometimes called multiple dispatching. For example, here we would like the runtime type information contained in v1 and v2 to propagate into w.

std::variant<int, double> v1 = 1.2; // runtime double (decided at runtime)
std::variant<int, double> v2 = 3;   // runtime int (decided at runtime)
auto covariant_sum = covariant(     // a multiple-argument covariant use here 
  [](int a   , int    b)->int   {return         a +        b ;},
  [](double a, int    b)->double{return         a + double(b);},
  [](int a   , double b)->double{return double(a) +        b ;},
  [](double a, double b)->double{return         a +        b ;}
);
auto w = covariant_sum(v1, v2);
assert( std::get<double>(w) == double(1.2 + 3) );
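For comparison, std::visit already accepts several variants at once and dispatches on all of their runtime types; what it lacks is the automatic computation of the result type. A hedged sketch with the result variant chosen by hand (the alias num and the function add are hypothetical):

```cpp
#include <cassert>
#include <variant>

using num = std::variant<int, double>;  // hypothetical alias for the example

// std::visit dispatches on the runtime types of both arguments (4 combinations),
// but the common return type num must be chosen by hand
num add(num const& a, num const& b){
    return std::visit([](auto x, auto y) -> num { return x + y; }, a, b);
}
```

Listing the result type manually is manageable for two alternatives, but it quickly becomes error-prone as the number of arguments and alternatives grows.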

In the following, we will see how to implement and use such a general covariant function, with the help of a surprisingly small amount of code.

Implementation

Multiple dispatching is an old technique that is native to some dynamic languages (notably Julia; Python offers it through third-party packages) but is hard to achieve using C++’s virtual functions and classic runtime polymorphism. Historically, multiple dispatching refers mainly to side effects (function behavior) rather than to the covariance of the return type. In the following section we will see how to implement covariant functions of multiple arguments.

Metaprogramming section

As we saw in the previous article, the main difficulty of the implementation of (single argument) covariant functions is the computation of the possible return types. The strategy back there was to iterate over all possible input types (given a certain variant input and an overload function set) and give a list of possible output types contained in a return variant. Additionally, since the model of variant we were using was implicitly that of an unordered sum type, we decided to remove the duplicate output types.

The case of multiple inputs is no different, except that the first step requires iterating over a “product set” of inputs (choosing one of several possibilities for each input). For example, if we have three input variants with three, two and two possible runtime types respectively, we have the product set (double, int, string)×(double, int)×(string, char), with 12 possible combinations (3×2×2) of inputs and (at most) 12 different return types. This can easily result in a combinatorial explosion of cases.
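The count itself is just the product of the number of alternatives of each input variant, which can be checked at compile time with a fold expression. A small sketch; the variable template combinations is a hypothetical name:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <variant>

// Number of input combinations a multiple dispatch over Vs... has to consider:
// the product of the number of alternatives of each variant
template<class... Vs>
constexpr std::size_t combinations = (std::size_t{1} * ... * std::variant_size_v<Vs>);

static_assert(combinations<
    std::variant<double, int, std::string>,
    std::variant<double, int>,
    std::variant<std::string, char>
> == 12);  // 3 x 2 x 2
```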

Such a combination of (i) iteration over input cases, (ii) return type calculation and (iii) duplicate removal must be performed during compilation, and therefore requires some knowledge of template metaprogramming, as provided by a library such as Boost.MPL (used here) or its more modern counterpart Boost.MP11.

Combinatorial input cases

Calculating product sets is not part of Boost.MPL but, fortunately, an implementation called combine_view credited to Andrea Rigoni exists. I am not going to describe it here for lack of space but it effectively allows the following compile-time calculation:

using product_set = combine_view<
    boost::mpl::vector<
        boost::mpl::list<double, int, std::string>,
        boost::mpl::list<double, int>,
        boost::mpl::list<std::string, char>
    >
>::type;
static_assert( boost::mpl::size<product_set>::value == 12 );
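combine_view itself is not reproduced here. As a hedged, Boost-free stand-in (not Andrea Rigoni’s implementation), the same compile-time Cartesian product can be sketched with plain variadic templates; type_list, push_front, concat, prepend_all, product and list_size are all hypothetical names:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <type_traits>

// Hypothetical minimal type list (stands in for boost::mpl::list)
template<class... Ts> struct type_list {};

// concat: join several type_lists into one
template<class... Ls> struct concat { using type = type_list<>; };
template<class... As> struct concat<type_list<As...>> { using type = type_list<As...>; };
template<class... As, class... Bs, class... Rest>
struct concat<type_list<As...>, type_list<Bs...>, Rest...>
    : concat<type_list<As..., Bs...>, Rest...> {};

// push_front: put H at the front of one combination
template<class H, class L> struct push_front;
template<class H, class... Ts> struct push_front<H, type_list<Ts...>> { using type = type_list<H, Ts...>; };

// prepend_all: put H at the front of every combination in a list of combinations
template<class H, class Combos> struct prepend_all;
template<class H, class... Cs> struct prepend_all<H, type_list<Cs...>> {
    using type = type_list<typename push_front<H, Cs>::type...>;
};

// product: all combinations picking one type from each input list
template<class... Ls> struct product { using type = type_list<type_list<>>; };  // no lists: one empty combination
template<class... Hs, class... Ls>
struct product<type_list<Hs...>, Ls...> {
    using tail = typename product<Ls...>::type;
    using type = typename concat<typename prepend_all<Hs, tail>::type...>::type;
};

template<class L> struct list_size;
template<class... Ts> struct list_size<type_list<Ts...>> : std::integral_constant<std::size_t, sizeof...(Ts)> {};

using product_set = product<
    type_list<double, int, std::string>,
    type_list<double, int>,
    type_list<std::string, char>
>::type;
static_assert(list_size<product_set>::value == 12);
```

Boost.MP11 offers the same operation ready-made as mp_product, which is probably preferable in production code.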

We could in principle fully utilize this explosion of combinations and customize each return-type case from the overload set, but it is likely that in practice the space of combinations will project into fewer types when applying a concrete set of possible functions, like in the example in the previous section.

Calculating return types

Once we have all the input combinations, we have to transform them into all the possible output types resulting from the application of the overload set. This is done by generalizing the apply metafunction to the list of multiple-argument combinations defined above:

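// bmp is an alias for the Boost.MPL namespace (e.g. namespace bmp = boost::mpl;)
template<class On, class Args>
struct applyn{
    // expand the MPL sequence Args into a pack of argument types via an index sequence
    template<class> struct aux;
    template<std::size_t... Is> struct aux<std::index_sequence<Is...>>{
        // type of calling an On object with one argument of each type in Args
        using type = decltype(std::declval<On>()(std::declval<typename bmp::at_c<Args, Is>::type>()...));
    };
    using type = typename aux<std::make_index_sequence<bmp::size<Args>::value>>::type;
};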

applyn takes an overload function type On and a sequence of types (one of the combinations of types above) and gives back the type returned by calling On with arguments of those types.
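The same idea can be sketched without Boost by storing a combination of argument types in a std::tuple and unpacking it with partial specialization (std::invoke_result_t would work equally well). apply_tuple, apply_tuple_t and sum_set are hypothetical names, and overload is the usual C++17 idiom:

```cpp
#include <cassert>
#include <tuple>
#include <type_traits>
#include <utility>

template<class... Fs> struct overload : Fs... { using Fs::operator()...; };
template<class... Fs> overload(Fs...) -> overload<Fs...>;

// Return type of calling On with the argument types stored in a std::tuple
template<class On, class Args> struct apply_tuple;
template<class On, class... Ts> struct apply_tuple<On, std::tuple<Ts...>> {
    using type = decltype(std::declval<On>()(std::declval<Ts>()...));
};
template<class On, class Args>
using apply_tuple_t = typename apply_tuple<On, Args>::type;

// Example overload set whose return type depends on both arguments
inline auto sum_set = overload{
    [](int a, int b){ return a + b; },
    [](int a, double b){ return a + b; },
    [](double a, int b){ return a + b; },
    [](double a, double b){ return a + b; }
};
using Sum = decltype(sum_set);

static_assert(std::is_same_v<apply_tuple_t<Sum, std::tuple<int, int>>, int>);
static_assert(std::is_same_v<apply_tuple_t<Sum, std::tuple<int, double>>, double>);
```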

Remove duplicates

Finally, we can use the fold/transform_view we applied in the “unary” covariant version in the previous article to make the result unique (remove duplicates if they exist):

template<class On, class... Seqs> 
using results_of_setn_t = 
    typename bmp::fold<
        typename bmp::transform_view<
            typename bmp::combine_view<
                bmp::vector<Seqs...>
            >::type
            ,
            applyn<On, bmp::_>
        >::type,
        bmp::set<>,
        bmp::insert<bmp::_1, bmp::_2>
    >::type
;
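For readers who want to experiment without Boost.MPL, a hedged standard-library sketch of the duplicate-removal step follows. It walks the list of result types and keeps each type only on its first appearance, then rebinds the survivors as a std::variant; unique_types, to_variant and unique_variant_t are hypothetical names.

```cpp
#include <cassert>
#include <tuple>
#include <type_traits>
#include <variant>

// unique_types: append each T to the accumulator tuple only if it is not
// already present (keeps the order of first appearance)
template<class Acc, class... Ts> struct unique_types { using type = Acc; };
template<class... As, class T, class... Ts>
struct unique_types<std::tuple<As...>, T, Ts...>
    : std::conditional_t<(std::is_same_v<T, As> || ...),
                         unique_types<std::tuple<As...>, Ts...>,
                         unique_types<std::tuple<As..., T>, Ts...>> {};

// Rebind the deduplicated tuple of types as a std::variant
template<class Tuple> struct to_variant;
template<class... Us> struct to_variant<std::tuple<Us...>> { using type = std::variant<Us...>; };

template<class... Ts>
using unique_variant_t = typename to_variant<typename unique_types<std::tuple<>, Ts...>::type>::type;

static_assert(std::is_same_v<unique_variant_t<int, double, double, int>, std::variant<int, double>>);
```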

Don’t forget non-variant arguments

std::visit is a very powerful function that, when applied to variant types, can select a function implementation from an overload set (or “visitor”). (As a matter of fact, it is technically the only fundamental function that can be statically applied to a raw std::variant type.) The main limitation to overcome here is that the overload set must have a single return type to be a valid visitor.

There is, however, another practical limitation: std::visit can only be applied to std::variant types. This is not a big deal when there is only one input argument, as the single argument can be converted into a trivial variant, or the visit protocol is not needed at all. However, it severely limits the applicability of std::visit in generic code with multiple arguments, since some input parameters may not be std::variant at all but have static types. For example, the following is a hard error with the standard std::visit, even when some_visitor could in principle apply to non-variant input.

std::visit(some_visitor, std::variant<double, int>(1.2), 42); // error: 42 (int) is not a variant

Non-variant arguments can always be transformed into variants, although that requires a copy and manual coding, and might have a non-zero runtime cost.

std::visit(some_visitor, std::variant<double, int>(1.2), std::variant<int>(42)); // ok, but not optimal
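The manual part of that workaround can at least be automated: wrap every non-variant argument in a one-alternative variant and pass real variants through untouched. This is a sketch of the copying approach just described, not of the pivot technique; is_variant, as_variant and visit_mixed are hypothetical names.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>
#include <variant>

template<class T> struct is_variant : std::false_type {};
template<class... Vs> struct is_variant<std::variant<Vs...>> : std::true_type {};

// Pass variants through untouched; copy anything else into a one-alternative variant
template<class T>
decltype(auto) as_variant(T&& t){
    if constexpr (is_variant<std::decay_t<T>>::value)
        return std::forward<T>(t);
    else
        return std::variant<std::decay_t<T>>(std::forward<T>(t));
}

// Hypothetical std::visit wrapper that also accepts non-variant arguments;
// the wrapped temporaries live until the end of the return statement
template<class F, class... Args>
decltype(auto) visit_mixed(F&& f, Args&&... args){
    return std::visit(std::forward<F>(f), as_variant(std::forward<Args>(args))...);
}
```

Each non-variant argument still costs a copy into a std::variant, which is the overhead the next approach avoids.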

A better alternative could be to create a new visit protocol that accepts non-variants. Recursive use of lambdas can help create a function that “pivots” over subsequent non-variant arguments.

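#include <type_traits>
#include <utility>
#include <variant>

// Trait telling std::variant arguments apart from plain (static) ones
template<class T> struct is_variant : std::false_type{};
template<class... Vs> struct is_variant<std::variant<Vs...>> : std::true_type{};

// Base case: every argument has been bound, so just call the function
template<class V>
auto pivot(V&& w){return std::forward<V>(w)();}

// Non-variant first argument: bind it and pivot over the remaining ones
template<class V, class T, class... Ts, std::enable_if_t<!is_variant<std::decay_t<T>>::value, int> = 0>
auto pivot(V&& w, T&& t, Ts&&... ts){
    return pivot(
        [&](auto&&... ts2){return std::forward<V>(w)(std::forward<T>(t), std::forward<decltype(ts2)>(ts2)...);}, 
        std::forward<Ts>(ts)...
    );
}

// Variant first argument: visit it, then pivot over the active alternative and the rest
template<class V, class Var, class... Ts, std::enable_if_t<is_variant<std::decay_t<Var>>::value, int> = 0>
auto pivot(V&& w, Var&& v, Ts&&... ts){
    return std::visit(
        [&](auto&& vv){return pivot(std::forward<V>(w), std::forward<decltype(vv)>(vv), std::forward<Ts>(ts)...);}, 
        std::forward<Var>(v)
    );
}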

pivot is a natural generalization of std::visit for mixed variant and non-variant input and it is a drop-in replacement for std::visit.

pivot(some_visitor, std::variant<double, int>(1.2), 42); // ok (for a reasonable some_visitor)

This is a nice-to-have feature that later facilitates applying generic covariant functions to arguments that are not variants.

Results

Finally, we put it all together by adding to our wrapper of overload sets the capability to make covariant functions of multiple arguments:

template<class... Fs>
struct covariant : overload<Fs...>{
    covariant(Fs... fs) : overload<Fs...>(fs...){}
    template<class... Ts, typename = decltype(overload<Fs...>::operator()(std::declval<Ts>()...))> 
    decltype(auto) call(Ts&&... ts) const{
        if constexpr(std::is_same<decltype(overload<Fs...>::operator()(std::forward<Ts>(ts)...)), void>{})
            return overload<Fs...>::operator()(std::forward<Ts>(ts)...), std::monostate{};
        else
            return overload<Fs...>::operator()(std::forward<Ts>(ts)...);
    }
    template<
        class... Variants,
        class Ret = detail::variant_of_set_t<
            detail::results_of_setn_t<
                overload<Fs...> const&, 
                detail::variant_types_list_t<Variants>...
            >
        >
    >
    Ret operator()(Variants const&... vs){
        return pivot([&](auto&&... es)->Ret{return call(es...);}, vs...);
    }
};

Notes: we are using pivot, defined earlier, as a replacement for std::visit that allows variant and non-variant input. We adopt the convention that detail::variant_types_list_t<T> is mpl::list<T> when T is not a std::variant. Other names were defined in the previous article.
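Since the code above depends on Boost.MPL, combine_view and the previous article’s helpers, here is a hedged, self-contained sketch of the same idea restricted to exactly two variant arguments, using only the standard library. It is not the [Boost.]Covariant implementation: overload is the usual C++17 idiom, and unique_types, to_variant, dedup_variant, row, result_variant and covariant_call are hypothetical names.

```cpp
#include <cassert>
#include <complex>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>

template<class... Fs> struct overload : Fs... { using Fs::operator()...; };
template<class... Fs> overload(Fs...) -> overload<Fs...>;

// Order-preserving duplicate removal over a std::tuple of types
template<class Acc, class... Ts> struct unique_types { using type = Acc; };
template<class... As, class T, class... Ts>
struct unique_types<std::tuple<As...>, T, Ts...>
    : std::conditional_t<(std::is_same_v<T, As> || ...),
                         unique_types<std::tuple<As...>, Ts...>,
                         unique_types<std::tuple<As..., T>, Ts...>> {};

template<class Tuple> struct to_variant;
template<class... Us> struct to_variant<std::tuple<Us...>> { using type = std::variant<Us...>; };

template<class Tuple> struct dedup_variant;
template<class... Rs> struct dedup_variant<std::tuple<Rs...>>
    : to_variant<typename unique_types<std::tuple<>, Rs...>::type> {};

// Results of F for a fixed first alternative A against every second alternative Bs...
template<class F, class A, class... Bs> struct row {
    using type = std::tuple<std::invoke_result_t<F const&, A const&, Bs const&>...>;
};

// All (A, B) results concatenated, then deduplicated into a variant
template<class F, class V1, class V2> struct result_variant;
template<class F, class... As, class... Bs>
struct result_variant<F, std::variant<As...>, std::variant<Bs...>> {
    using all = decltype(std::tuple_cat(std::declval<typename row<F, As, Bs...>::type>()...));
    using type = typename dedup_variant<all>::type;
};

// Two-argument covariant call: std::visit does the dispatch, Ret is computed
template<class F, class V1, class V2>
auto covariant_call(F const& f, V1 const& v1, V2 const& v2){
    using Ret = typename result_variant<F, V1, V2>::type;
    return std::visit([&](auto const& a, auto const& b) -> Ret { return f(a, b); }, v1, v2);
}

// Usage: the sum with the int/complex cases handled separately
using complex = std::complex<double>;
using number = std::variant<int, double, complex>;
inline auto sum = overload{
    [](auto n1, auto n2){ return n1 + n2; },
    [](int i1, complex c2){ return double(i1) + c2; },
    [](complex c1, int i2){ return c1 + double(i2); }
};
```

Here covariant_call(sum, number{3.14}, number{complex{1., 2.}}) returns a std::variant<int, double, complex> holding a complex. Generalizing to any number of mixed arguments is where a product such as combine_view (or Boost.MP11’s mp_product) and a pivot-like visit come in.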

Usage

In this example, a custom defined covariant function takes two “arithmetic numbers”, which can be int, double or complex, and gives the result of the sum (in the most natural domain). Since the function has two arguments and each argument has three cases, there are at most 9 overloads.

using complex = std::complex<double>;
std::variant<int, double, complex> v1 = 3.14;
std::variant<int, double, complex> v2 = complex{1., 2.};
auto sum_covariant = covariant(
    [](int     i1, int     i2){return i1 + i2;}, 
    [](int     i1, double  d2){return i1 + d2;},
    [](int     i1, complex c2){return double(i1) + c2;},
    [](double  d1, int     i2){return d1 + i2;},
    [](double  d1, double  d2){return d1 + d2;},
    [](double  d1, complex c2){return d1 + c2;},
    [](complex c1, int     i2){return c1 + double(i2);},
    [](complex c1, double  d2){return c1 + d2;},
    [](complex c1, complex c2){return c1 + c2;}
);
auto w = sum_covariant(v1, v2);

This implementation of covariant functions allows converting any overload set into a function that is covariant on multiple arguments, where the combinatorial return cases are handled automatically. Individual arguments can be variant or non-variant.

The implementation of the overload set code is still left to the user, which seems to demand the implementation of a combinatorial set of functions (9 C++ functions or lambdas above). While the combinatorial explosion of return types is handled by the metaprogramming part of the library, the combinatorial explosion of the function overloads is not. Fortunately, template functions can help here if we find patterns in the overload set.

Taming combinatorial overload sets

Although variant types don’t define any hierarchy of the underlying types, it is likely that these alternative types in the std::variant fulfill common concepts. The idea to overcome this combinatorial explosion is to find common code in the different implementations and delegate it to templates (or lambdas with deduced arguments), effectively combining generic programming with runtime dispatching.

A first property we can use to simplify the code is to notice that the operation is symmetric.
We can symmetrize the operation and therefore save the body of 3 (out of 9) functions, by adding a layer of overloads.

auto symmetric_sum_aux = overload(
    [](int     i1, int     i2){return i1 + i2;}, 
    [](int     i1, double  d2){return i1 + d2;},
    [](int     i1, complex c2){return double(i1) + c2;},
    [](double  d1, double  d2){return d1 + d2;},
    [](complex c1, double  d2){return c1 + d2;},
    [](complex c1, complex c2){return c1 + c2;}
); // only 6 non-symmetric cases implemented
auto sum_covariant = covariant(
    symmetric_sum_aux, [&](auto t1, auto t2){return symmetric_sum_aux(t2, t1);}
);

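Another route for simplification is to observe that all but 2 of the original functions share the same code, so those 2 can be handled as exceptional cases. Note above that the code is mostly common, except for the summation of an integer and a complex number, which is not handled directly because of a quirk in the Standard Library: the mixed-type operator+ of std::complex<T> is a template that only accepts T itself, so std::complex<double> + int does not compile.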

std::variant<int, double, complex> v1 = 3.14;
std::variant<int, double, complex> v2 = complex{1., 2.};
auto sum_covariant = covariant(
    [](auto    n1, auto    n2){return n1 + n2;}, 
    [](int     i1, complex c2){return double(i1) + c2;},
    [](complex c1, int     i2){return c1 + double(i2);}
);
auto result = sum_covariant(v1, v2);
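The Standard Library quirk behind the two special cases can be checked at compile time with the detection idiom. A small sketch; can_add is a hypothetical name:

```cpp
#include <cassert>
#include <complex>
#include <type_traits>
#include <utility>

// Detection idiom: is `a + b` a valid expression for types A and B?
template<class A, class B, class = void>
struct can_add : std::false_type {};
template<class A, class B>
struct can_add<A, B, std::void_t<decltype(std::declval<A>() + std::declval<B>())>>
    : std::true_type {};

using complex = std::complex<double>;
static_assert(can_add<complex, double>::value);   // complex<T> + T with T = double
static_assert(!can_add<complex, int>::value);     // T cannot be deduced as both double and int
static_assert(can_add<int, double>::value);       // built-in arithmetic conversions apply
```

Because the generic lambda would hit exactly this failure for the (int, complex) and (complex, int) combinations, those two cases get their own overloads, which win overload resolution as non-templates.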

Although not always possible, an ideal covariant function might be an overload set composed of a single template function, generally a single template lambda (a lambda with all auto parameters). This is sometimes possible when all the alternative types inside the input variants share common concepts (for example, all are arithmetic).

Discussion

Variant variables and covariant functions should be used only when true runtime alternatives arise, e.g. when type information is read from a configuration file or provided by the user at runtime. Variant types implement a value-based runtime polymorphism, while covariant functions allow operating on, and returning, these variants. The advantage of using std::variant is that, when variants are necessary at all, they can tie runtime decisions to specific objects in our program and can remain isolated from other static code in the program. Covariant functions keep this abstraction of runtime decisions encapsulated, allowing a clear demarcation between static and dynamic code. Despite the example code presented here, which is only for illustration purposes, I wouldn’t recommend the use of variants for low-level numeric information, for example a large array of numeric variants such as std::vector<std::variant<int, double, complex>>; if at all, a variant of arrays, std::variant<std::vector<int>, std::vector<double>, std::vector<complex>>, would be preferable. The higher the level at which the variant code lives, the easier it is to avoid the runtime penalty.
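A short sketch of the recommended variant-of-arrays layout: a single runtime decision selects the whole array, and the loop inside each branch is fully static, whereas a std::vector of variants would dispatch once per element. The alias array_variant and the function total are hypothetical.

```cpp
#include <cassert>
#include <variant>
#include <vector>

using array_variant = std::variant<std::vector<int>, std::vector<double>>;  // hypothetical alias

// One visit for the whole array; no per-element dispatch inside the loop
double total(array_variant const& arr){
    return std::visit([](auto const& vec){
        double s = 0;
        for(auto x : vec) s += x;   // element type is known statically here
        return s;
    }, arr);
}
```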

Reinventing interpreted code in C++

Applying covariant functions, just like applying visitors, has definite performance costs in terms of repeated runtime decisions. The number of decision points grows with the number of arguments, and the number of possible execution paths grows exponentially with it, as the number of combinations does. Therefore, there is a cost in using this design excessively. In fact, abusing this technique could lead to a situation similar to that of interpreted languages, in which all runtime (dynamic) types must be checked almost constantly when operating on a variable.

Another corner-cutting situation that can arise in runtime function application involves leaving certain cases undefined (either because certain combinations do not make sense or because the implementation is not yet ready). Undefined cases can simply “do nothing”, throw exceptions or even terminate.

auto sum_covariant = covariant(
    [](auto    n1, auto    n2){return n1 + n2;}, 
    [](int       , complex   ){throw std::runtime_error{"not implemented1"};},
    [](complex   , int       ){throw std::runtime_error{"not implemented2"};}
);
auto result = sum_covariant(v1, v2);

This is where the maintenance cost of runtime types becomes evident. If one or more cases are left explicitly unimplemented, the code can fail at runtime (just as with some interpreted languages) and can be logically problematic by generating a null variant state (e.g. std::monostate).

Related work

A prototype implementation of the code described here is available at the [Boost.]Covariant library.

While writing this article over the last year, I became aware of similar efforts in the argot library and of an example in the Boost.MP11 library. Argot is a library with the more ambitious goal of applying function overloads to variants (with the same effect as here), tuples (producing corresponding tuples of output) and optionals (propagating the null state). Boost.MP11 is a modern replacement and extension of Boost.MPL; for example, it has a product-generating function called mp_product_q (similar to combine_view above).
