Functions of Variants are Covariant

Today I am happy to present a guest post by Alfredo Correa about covariant visitors for std::variant. Alfredo works at Lawrence Livermore National Laboratory, where he uses C++ to develop physics simulation tools.

Introduction

Sum types have a range of values that is the sum of the ranges of their parts. std::variant is the model representation of sum types in C++.

For example, std::variant<int, double> can hold an integer value (int state) or a double value (double state). The use of variant types provides support for polymorphism while maintaining value semantics.

There are only a few intrinsic functions that can be applied directly to an std::variant instance in C++; basically, only functions that probe or extract their current type state and value. Simple C++ functions over its component states cannot be applied directly to the variant since the type information needs to be probed before calling the corresponding function over the correct type.
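To make the point concrete, here is a minimal sketch of what "probing before calling" looks like without a visitor. The helper names active_state, manual_next and example_probe are hypothetical and only serve the illustration; the standard calls used are std::holds_alternative, std::get_if and std::get.

```cpp
#include <cassert>
#include <string>
#include <variant>

// Hypothetical helper: names the active state of a variant<int, double>.
inline std::string active_state(std::variant<int, double> const& v) {
  if (std::holds_alternative<int>(v)) return "int";
  return "double";
}

// Hypothetical helper: adds one, but only after probing the state by hand.
inline std::variant<int, double> manual_next(std::variant<int, double> const& v) {
  if (int const* i = std::get_if<int>(&v)) return *i + 1;  // int state
  return std::get<double>(v) + 1.0;                         // double state
}

// Usage: the state has to be known before the value can be extracted.
inline void example_probe() {
  std::variant<int, double> v = 3;
  assert(active_state(v) == "int");
  assert(std::get<int>(manual_next(v)) == 4);
}
```

Every new state added to the variant would require another branch in each such helper, which is the bookkeeping that visitors take over.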

Specific C++ functions can be applied through visitors. However, standard visitors are static and non-covariant, stopping polymorphism from propagating through function application.

A basic explanation of variants and their visitors can be found here.

(Non-covariant) Functions

The free function std::visit defines a protocol that can be used to apply a corresponding concrete function from a set of overloads or cases. The concrete functions are grouped as a visitor function. A visitor over a variant type is a callable object that is overloaded or can be applied to all the states of the variant.

Visitor class

The prototypical visitor class has several overloads of operator().

struct print_visitor {
  std::ostream& operator()(int a){
    return std::cout << "(int)" << a;
  }
  std::ostream& operator()(double a){
    return std::cout << "(double)" << a;
  }
};

std::variant<double, int> v = 3.14;
std::visit(print_visitor{}, v); // prints "(double)3.14"

The overload can include template functions which can exploit common syntax within the variant set.
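As a sketch of that idea, the visitor below mixes one template overload, which covers every state that can be streamed as a number, with an ordinary overload for std::string. The names describe_visitor, describe and example_describe are hypothetical.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <variant>

// Hypothetical visitor: the template overload handles int and double through
// their common streaming syntax; the non-template overload is preferred by
// overload resolution when the state is a std::string.
struct describe_visitor {
  template <class T>
  std::string operator()(T const& x) const {
    std::ostringstream os;
    os << "number " << x;
    return os.str();
  }
  std::string operator()(std::string const& s) const {
    return "text " + s;
  }
};

inline std::string describe(std::variant<int, double, std::string> const& v) {
  return std::visit(describe_visitor{}, v);
}

// Usage: all overloads return the same type, std::string.
inline void example_describe() {
  assert(describe(7) == "number 7");
  assert(describe(std::string("hi")) == "text hi");
}
```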

On-the-fly visitors

If the function is generic enough and can be used on all the possible variant types, the auxiliary function can be replaced by a single lambda and called on-the-fly:

std::visit([](auto const& e)->std::ostream&{return std::cout << e;}, v) << '\n';

If a single generic lambda is incapable of handling all the cases, a set of lambdas can be grouped automatically by an overload object. std::overload is a proposed addition to the standard library but can be implemented easily:

template <class ...Fs>
struct overload : Fs... {
  template <class ...Ts>
  overload(Ts&& ...ts) : Fs{std::forward<Ts>(ts)}...
  {} 
  using Fs::operator()...;
};
template <class ...Ts>
overload(Ts&&...) -> overload<std::remove_reference_t<Ts>...>;
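A short usage sketch follows; the overload definition is repeated only so that the block stands alone, and the helpers kind_of and example_kind_of are hypothetical names chosen for the illustration.

```cpp
#include <cassert>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>

// Same overload helper as in the text, repeated for self-containment.
template <class... Fs>
struct overload : Fs... {
  template <class... Ts>
  overload(Ts&&... ts) : Fs{std::forward<Ts>(ts)}... {}
  using Fs::operator()...;
};
template <class... Ts>
overload(Ts&&...) -> overload<std::remove_reference_t<Ts>...>;

// Hypothetical helper: one lambda per state, grouped on the fly.
inline std::string kind_of(std::variant<int, double, std::string> const& v) {
  return std::visit(overload(
    [](int) { return std::string("int"); },
    [](double) { return std::string("double"); },
    [](std::string const&) { return std::string("string"); }
  ), v);
}

inline void example_kind_of() {
  assert(kind_of(1) == "int");
  assert(kind_of(std::string("x")) == "string");
}
```

Note that all three lambdas return std::string, so the grouped visitor still has a single, common return type.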

Visitors can be applied to multiple arguments as well. A basic description of on-the-fly visitors can be found here.
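For instance, std::visit accepts several variants after the visitor and dispatches on the combination of their states. The sketch below uses a hypothetical alias number and hypothetical helpers sum_as_double and example_sum; a single generic lambda covers all four (int or double) by (int or double) combinations, with double as the common return type.

```cpp
#include <cassert>
#include <variant>

using number = std::variant<int, double>;  // hypothetical alias

// Hypothetical helper: adds two variants, whatever their current states.
inline double sum_as_double(number const& a, number const& b) {
  return std::visit(
    [](auto x, auto y) -> double {
      return static_cast<double>(x) + static_cast<double>(y);
    },
    a, b);
}

inline void example_sum() {
  assert(sum_as_double(1, 2.5) == 3.5);  // int state plus double state
  assert(sum_as_double(2, 3) == 5.0);    // int state plus int state
}
```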

Scope and Restrictions

The free function std::visit and the visitor define a function that can be applied over a variant set.

For a visitor to be valid, it needs to cover all the possible cases of the variant, either by overload resolution, template instantiation or implicit type conversion. Each overload is in principle independent (although it probably makes sense that they all implement a common conceptual operation), but all the overloads have to return a common type. This common return type is what makes the functions that can be applied non-covariant.

For example, this is not a valid visitor for std::variant<int, double>, because the return type is not common to all cases (int and double).

// invalid visitor, (what type is d?)
auto d = std::visit([](auto n){ return n + 1; }, std::variant<int, double>(3.14));
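The mismatch can be checked at compile time without calling std::visit at all. In this sketch the lambda is given the hypothetical name plus_one, and int_result and double_result are hypothetical aliases for the type it returns in each state.

```cpp
#include <type_traits>
#include <utility>

// The generic lambda from the example, named so it can be inspected.
inline constexpr auto plus_one = [](auto n) { return n + 1; };

using int_result    = decltype(plus_one(std::declval<int>()));
using double_result = decltype(plus_one(std::declval<double>()));

// Each state produces a different return type ...
static_assert(std::is_same_v<int_result, int>);
static_assert(std::is_same_v<double_result, double>);
// ... so there is no single type that std::visit could return.
static_assert(!std::is_same_v<int_result, double_result>);
```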

Covariant functions

A covariant function is one in which the runtime case of the return type varies with the case of the input type(s). Covariant functions are a special case of overloads that return a single variant type based on the concrete return types of the overload functions.

The example above is not a valid visitor and therefore cannot be applied directly to a variant. However, it can be converted into a valid visitor, and therefore also into a covariant function, by manually predicting the possible states of the result and wrapping the result in a common variant.

auto next = [](auto n)->std::variant<int, double>{ return n + 1; };
auto d = std::visit(next, std::variant<int, double>(3.14));
assert( std::get<double>(d) == 3.14 + 1.0 );

In this case, the possible output types of the overload set are int and double, therefore the type of d is std::variant<int, double>.

The problem with this approach is that it doesn’t scale well for more complicated cases; one has to manually keep track of the possible return types of the overload and enforce the variant return type. If the overload is more complicated, it may become harder to account for all the result types. Besides, it is not clear how to handle the special void case. For example:

auto next_overload = overload(
  [](int a)->std::variant<int, double, void??>{ return a + 1; },
  [](double a)->std::variant<int, double, void??>{ return a + 1; },
  [](char)->std::variant<int, double, void??>{ return int(0); },
  [](std::string)->std::variant<int, double, void??>{}
);
std::visit(next_overload, v);

Note that the number of possible return types of the overload is equal to or smaller than the number of states of the original variant. If the function takes more than one (variant) argument, the number of possible output states multiplies.

Additionally, since void is not a regular type, a variant containing a void type is invalid. Therefore, overloads that return void should be handled separately. A possible design choice is to map the void return to a special monostate, which effectively behaves like a regular void.
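A minimal sketch of that design choice, assuming a hypothetical helper named invoke_regular: it calls the function and passes the result through, except that a void result is replaced by std::monostate. Because the helper returns by value, reference results are copied; that is a simplification of this sketch.

```cpp
#include <cassert>
#include <functional>
#include <type_traits>
#include <utility>
#include <variant>

// Hypothetical helper: calls f(args...) and returns its result, except that a
// void result is replaced by std::monostate so it can be stored in a variant.
template <class F, class... Args>
auto invoke_regular(F&& f, Args&&... args) {
  using R = std::invoke_result_t<F, Args...>;
  if constexpr (std::is_void_v<R>) {
    std::invoke(std::forward<F>(f), std::forward<Args>(args)...);
    return std::monostate{};
  } else {
    return std::invoke(std::forward<F>(f), std::forward<Args>(args)...);
  }
}

// Usage: both results fit into the same variant type.
inline void example_monostate() {
  std::variant<int, std::monostate> a = invoke_regular([](int i) { return i + 1; }, 41);
  std::variant<int, std::monostate> b = invoke_regular([](int) {}, 41);
  assert(std::get<int>(a) == 42);
  assert(std::holds_alternative<std::monostate>(b));
}
```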

Automatic covariant return type deduction

Given an overload set (or callable function) Overload and a list of input types ListInput, in principle, it is possible to deduce the set of all the possible return types ResultSet from all the possible inputs and later construct a variant Out for it.

This is, of course, something that can be computed at compile time, for example with a (meta)function results_of_set_t:

using Result = results_of_set_t<Overload, ListInput>;

where ListInput is extracted from a variant type:

using ListInput = variant_types_list_t<std::variant<...>>;

From the set of results a new variant type can be made:

using new_variant = variant_of_set_t<Result>;
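To give a feel for the three steps together, here is a rough sketch that relies only on the standard library (C++17). It is not the implementation used in this article: the names regular_t, push_unique, unique_variant and covariant_result_t are hypothetical, and duplicates are removed with a simple linear search instead of a type set.

```cpp
#include <type_traits>
#include <variant>

// void cannot be stored in a variant, so it is mapped to std::monostate.
template <class T>
using regular_t = std::conditional_t<std::is_void_v<T>, std::monostate, T>;

// Append T to std::variant<Vs...> unless it is already one of the Vs.
template <class Variant, class T> struct push_unique;
template <class... Vs, class T>
struct push_unique<std::variant<Vs...>, T> {
  using type = std::conditional_t<(std::is_same_v<Vs, T> || ...),
                                  std::variant<Vs...>,
                                  std::variant<Vs..., T>>;
};

// Fold push_unique over a list of types, starting from an accumulator.
template <class Acc, class... Ts>
struct unique_variant { using type = Acc; };
template <class Acc, class T, class... Ts>
struct unique_variant<Acc, T, Ts...>
  : unique_variant<typename push_unique<Acc, T>::type, Ts...> {};

// Variant of the distinct results of calling F on each input alternative.
template <class F, class Variant> struct covariant_result;
template <class F, class... Ts>
struct covariant_result<F, std::variant<Ts...>> {
  using type = typename unique_variant<
      std::variant<>, regular_t<std::invoke_result_t<F, Ts>>...>::type;
};
template <class F, class Variant>
using covariant_result_t = typename covariant_result<F, Variant>::type;
```

With a generic lambda such as [](auto n){ return n + 1; }, an input of std::variant<int, char, double> gives std::variant<int, double> under this sketch, because char + 1 has type int and the duplicate is dropped.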

Implementation

This type deduction requires a certain degree of metaprogramming. Different techniques can be used to implement the type deduction above.

Here we use the Boost Metaprogramming Library, Boost.MPL. The code is not particularly obvious but is simple from the perspective of functional programming:

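#include <boost/mpl/fold.hpp>
#include <boost/mpl/insert.hpp>
#include <boost/mpl/list.hpp>
#include <boost/mpl/placeholders.hpp>
#include <boost/mpl/set.hpp>
#include <boost/mpl/transform_view.hpp>
#include <type_traits>
#include <utility>
#include <variant>

namespace bmp = boost::mpl;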

template<class, class> struct variant_push_back;
template<class... Vs, class T>
struct variant_push_back<std::variant<Vs...>, T> {
  using type = std::variant<
    std::conditional_t<
      std::is_same<T, void>::value
      , std::monostate 
      , T
    >,
    Vs...
  >;
};

template<class Set> 
using variant_of_set_t = 
  typename bmp::fold<
    Set,
    std::variant<>,
    variant_push_back<bmp::_1, bmp::_2>
  >::type
;
template<class F, class T>
struct apply {
  using type = decltype(std::declval<F>()(std::declval<T>()));
};
template<class O, class Seq> 
using results_of_set_t = 
  typename bmp::fold<
    typename bmp::transform_view<
      Seq,
      apply<O, bmp::_1>
    >::type,
    bmp::set<>,
    bmp::insert<bmp::_1, bmp::_2>
  >::type
;
template<class T> struct variant_types_list;
template<class... Ts>
struct variant_types_list<std::variant<Ts...>> {
  using type = bmp::list<Ts...>;
};
template<class T> using variant_types_list_t = typename variant_types_list<T>::type;

Once we are past this metaprogramming interlude, it is relatively straightforward to define a covariant wrapper class that generalizes the concept of an overload in order to produce a covariant function. Note that the main complication is to handle the void return case.

template<class... Fs>
struct covariant : overload<Fs...> {
  covariant(Fs... fs) : overload<Fs...>(fs...) {}
  template<class... Ts, typename = decltype(overload<Fs...>::operator()(std::declval<Ts>()...))> 
  decltype(auto) call(Ts&&... ts) const{
    if constexpr(std::is_same<decltype(overload<Fs...>::operator()(std::forward<Ts>(ts)...)), void>::value) {
      overload<Fs...>::operator()(std::forward<Ts>(ts)...);
      return std::monostate{};
    } else {
      return overload<Fs...>::operator()(std::forward<Ts>(ts)...);
    }
  }
  template<class... Ts, class Ret = variant_of_set_t<results_of_set_t<overload<Fs...> const&, variant_types_list_t<std::variant<Ts...>>>>>
  Ret operator()(std::variant<Ts...> const& v) {
    return std::visit([&](auto&& e)->Ret{ return call(e); }, v);
  }
};

template<class... Fs> covariant(Fs... f) -> covariant<Fs...>;

Result and Conclusion

In the same way that visitors can be applied to variant types and return a single type, a covariant function can return a new variant type.

This example function gives the next element within the current state of the variant (e.g. double or int); for a non-numeric input (std::string) it returns nothing (void), which is turned into a std::monostate.

std::variant<int, double, std::string> v = 1.2;
auto d = covariant(
  [](int i){ return i + 1; },
  [](double d){ return d + 1; },
  [](auto const&){} // return void otherwise
)(v);
// d is a variant over int, double and std::monostate
// (the order of the alternatives depends on the type-set traversal)
assert( std::get<double>(d) == 1.2 + 1 );

For simplicity, the covariance defined here works only with respect to a single argument.
The power of this technique is that it scales to multiple variant arguments at the cost of a slightly more elaborate metaprogramming code. The concept of a covariant function simplifies the propagation of polymorphic values through the use of functions.

The reference implementation can be found in https://gitlab.com/correaa/boost-covariant.
