Serializing Eigen::Matrix using Cereal library

2020-07-13 23:43发布

问题:

UPDATED: I managed to get it to work after I googled around and read the doxygen comments in code. Problem was that I missed the cast before using resize() method and also not using std::ios::binary for the streams. If you want to do something similar, better check the answer by Azoth.

I am trying to serialize the Eigen::Matrix type using Cereal. This is what I have (loosely based on https://gist.github.com/mtao/5798888 and the types in cereal/types):

#include <cereal/cereal.hpp>
#include <cereal/archives/binary.hpp>
#include <Eigen/Dense>
#include <fstream>

namespace cereal
{
    // Writes an Eigen::Matrix to an output archive.
    // Enabled only when the archive can serialize BinaryData<_Scalar>.
    // Emitted sequence: size tag (rows * cols), rows, cols, then the raw
    // coefficient bytes taken from m.data().
    template <class Archive, class _Scalar, int _Rows, int _Cols, int _Options, int _MaxRows, int _MaxCols> inline
        typename std::enable_if<traits::is_output_serializable<BinaryData<_Scalar>, Archive>::value, void>::type
        save(Archive & ar, Eigen::Matrix<_Scalar, _Rows, _Cols, _Options, _MaxRows, _MaxCols> const & m)
    {
            // Dimensions are stored as plain int.
            int rows = m.rows();
            int cols = m.cols();
            // The element count is written as a size tag in addition to the
            // explicit rows/cols below.
            ar(make_size_tag(static_cast<size_type>(rows * cols)));
            ar(rows);
            ar(cols);
            // Payload length is expressed in bytes, not in elements.
            ar(binary_data(m.data(), rows * cols * sizeof(_Scalar)));
        }

    // Reads an Eigen::Matrix from an input archive, in the same order that
    // save() writes it: size tag, rows, cols, raw coefficient bytes.
    // Enabled only when the archive can deserialize BinaryData<_Scalar>.
    // NOTE(review): m is declared const yet is modified through const_cast;
    // calling this on an object that is genuinely const would be undefined
    // behavior.
    template <class Archive, class _Scalar, int _Rows, int _Cols, int _Options, int _MaxRows, int _MaxCols> inline
        typename std::enable_if<traits::is_input_serializable<BinaryData<_Scalar>, Archive>::value, void>::type
        load(Archive & ar, Eigen::Matrix<_Scalar, _Rows, _Cols, _Options, _MaxRows, _MaxCols> const & m)
    {
            // Element count, as stored by save() via the size tag.
            size_type size;
            ar(make_size_tag(size));

            int rows;
            int cols;
            ar(rows);
            ar(cols);

            // Strip constness so the matrix can be resized to the stored shape
            // before its buffer is filled.
            const_cast<Eigen::Matrix<_Scalar, _Rows, _Cols, _Options, _MaxRows, _MaxCols> &>(m).resize(rows, cols);

            // Read size * sizeof(_Scalar) bytes directly into the matrix storage.
            ar(binary_data(const_cast<_Scalar *>(m.data()), static_cast<std::size_t>(size * sizeof(_Scalar))));
        }
}

// Round-trip demo: serializes a random 10x3 matrix to "eigen.cereal",
// reads it back into a second matrix, and prints both to stdout.
int main() {
    Eigen::MatrixXd test = Eigen::MatrixXd::Random(10, 3);
    // Binary mode is requested on the stream that backs the binary archive.
    std::ofstream out = std::ofstream("eigen.cereal", std::ios::binary);
    cereal::BinaryOutputArchive archive_o(out);
    archive_o(test);

    std::cout << "test:" << std::endl << test << std::endl;

    // The file is closed explicitly before it is reopened for reading;
    // archive_o is still in scope at this point.
    out.close();

    // Default-constructed (empty) matrix that receives the loaded data.
    Eigen::MatrixXd test_loaded;
    std::ifstream in = std::ifstream("eigen.cereal", std::ios::binary);
    cereal::BinaryInputArchive archive_i(in);
    archive_i(test_loaded);

    std::cout << "test loaded:" << std::endl << test_loaded << std::endl;
}

回答1:

Your code is nearly correct but has a few mistakes:

You don't need to be making the size_tag since you are serializing the number of rows and columns explicitly. Generally cereal uses size_tag for resizable containers like vectors or lists. Even though the matrix can resize, it makes more sense just to serialize the rows and columns explicitly.

  1. Your load function should accept its parameter by non-const reference
  2. You shouldn't initialize the std::ofstream objects with `= std::ofstream(...)` (copy-initialization from a temporary); construct them directly instead
  3. It's better practice to let scoping and RAII handle closing/tearing down of the std::ofstream as well as cereal archives (the binary archive will flush its contents immediately, but in general cereal archives are only guaranteed to flush their contents on destruction)

Here's a version that compiles and produces correct output under g++ and clang++:

#include <cereal/cereal.hpp>
#include <cereal/archives/binary.hpp>
#include <Eigen/Dense>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <type_traits>

namespace cereal
{
  // Saves an Eigen::Matrix to a binary-capable archive.
  //
  // Layout: rows (int32), cols (int32), then rows * cols coefficients copied
  // verbatim from m.data(). Enabled only when the archive can serialize
  // BinaryData<_Scalar>.
  //
  // Throws cereal::Exception if a dimension does not fit the 32-bit header.
  template <class Archive, class _Scalar, int _Rows, int _Cols, int _Options, int _MaxRows, int _MaxCols> inline
    typename std::enable_if<traits::is_output_serializable<BinaryData<_Scalar>, Archive>::value, void>::type
    save(Archive & ar, Eigen::Matrix<_Scalar, _Rows, _Cols, _Options, _MaxRows, _MaxCols> const & m)
    {
      // The header format is fixed at 32 bits per dimension; refuse to write a
      // truncated value rather than narrowing Eigen::Index silently.
      const std::int32_t rows = static_cast<std::int32_t>(m.rows());
      const std::int32_t cols = static_cast<std::int32_t>(m.cols());
      if (rows != m.rows() || cols != m.cols())
        throw Exception("Eigen::Matrix dimensions do not fit the 32-bit archive header");

      ar(rows);
      ar(cols);

      // Compute the byte count in std::size_t so that rows * cols cannot
      // overflow 32-bit signed arithmetic.
      ar(binary_data(m.data(), static_cast<std::size_t>(m.size()) * sizeof(_Scalar)));
    }

  // Loads an Eigen::Matrix previously written by save().
  //
  // Reads rows and cols (int32 each), resizes m accordingly and then fills its
  // storage with the raw coefficient bytes. Enabled only when the archive can
  // deserialize BinaryData<_Scalar>.
  //
  // Throws cereal::Exception if the stored dimensions are negative or do not
  // match a dimension that is fixed at compile time.
  template <class Archive, class _Scalar, int _Rows, int _Cols, int _Options, int _MaxRows, int _MaxCols> inline
    typename std::enable_if<traits::is_input_serializable<BinaryData<_Scalar>, Archive>::value, void>::type
    load(Archive & ar, Eigen::Matrix<_Scalar, _Rows, _Cols, _Options, _MaxRows, _MaxCols> & m)
    {
      std::int32_t rows = 0;
      std::int32_t cols = 0;
      ar(rows);
      ar(cols);

      // The header comes from external data: validate it before it reaches
      // resize(), which does not accept negative sizes.
      if (rows < 0 || cols < 0)
        throw Exception("negative Eigen::Matrix dimension found in archive");

      // A dimension fixed at compile time cannot be changed by resize().
      if ((_Rows != Eigen::Dynamic && rows != _Rows) ||
          (_Cols != Eigen::Dynamic && cols != _Cols))
        throw Exception("archived Eigen::Matrix dimensions do not match the fixed-size target");

      m.resize(rows, cols);

      // After resize(), m.size() == rows * cols; widen before multiplying.
      ar(binary_data(m.data(), static_cast<std::size_t>(m.size()) * sizeof(_Scalar)));
    }
}

int main() {
  Eigen::MatrixXd test = Eigen::MatrixXd::Random(10, 3);

  {
    std::ofstream out("eigen.cereal", std::ios::binary);
    cereal::BinaryOutputArchive archive_o(out);
    archive_o(test);
  }

  std::cout << "test:" << std::endl << test << std::endl;

  Eigen::MatrixXd test_loaded;

  {
    std::ifstream in("eigen.cereal", std::ios::binary);
    cereal::BinaryInputArchive archive_i(in);
    archive_i(test_loaded);
  }

  std::cout << "test loaded:" << std::endl << test_loaded << std::endl;
}


回答2:

Based on @Azoth's answer (to whom I would like to give full credit, anyway), I improved the template a bit to

  • work also for Eigen::Array (rather than just Eigen::Matrix);
  • not serialize compile-time dimensions (that makes quite some storage difference for e.g. Eigen::Vector3f).

This is the result:

namespace cereal
{
  // Saves any plain Eigen object (Eigen::Matrix or Eigen::Array) to a
  // binary-capable archive.
  //
  // Only dimensions that are Eigen::Dynamic at compile time are written (as
  // Eigen::Index); they are followed by the raw coefficient bytes from
  // m.data(). Enabled only when the archive can serialize
  // BinaryData<typename Derived::Scalar>.
  template <class Archive, class Derived> inline
    typename std::enable_if<traits::is_output_serializable<BinaryData<typename Derived::Scalar>, Archive>::value, void>::type
    save(Archive & ar, Eigen::PlainObjectBase<Derived> const & m){
      using ArrT = Eigen::PlainObjectBase<Derived>;
      // Compile-time dimensions are implied by the type, so they are not stored.
      if(ArrT::RowsAtCompileTime==Eigen::Dynamic) ar(m.rows());
      if(ArrT::ColsAtCompileTime==Eigen::Dynamic) ar(m.cols());
      // Convert the element count to std::size_t before scaling to bytes so
      // the multiplication is done in unsigned arithmetic.
      ar(binary_data(m.data(),static_cast<std::size_t>(m.size())*sizeof(typename Derived::Scalar)));
    }

  // Loads a plain Eigen object previously written by save() for the same type.
  //
  // Dynamic dimensions are read from the archive; fixed ones are taken from
  // the type. The object is resized and its storage is then filled with the
  // raw coefficient bytes. Enabled only when the archive can deserialize
  // BinaryData<typename Derived::Scalar>.
  //
  // Throws cereal::Exception if a dimension read from the archive is negative.
  template <class Archive, class Derived> inline
    typename std::enable_if<traits::is_input_serializable<BinaryData<typename Derived::Scalar>, Archive>::value, void>::type
    load(Archive & ar, Eigen::PlainObjectBase<Derived> & m){
      using ArrT = Eigen::PlainObjectBase<Derived>;
      // Start from the compile-time sizes; Eigen::Dynamic marks a dimension
      // that must be read from the archive.
      Eigen::Index rows=ArrT::RowsAtCompileTime, cols=ArrT::ColsAtCompileTime;
      if(ArrT::RowsAtCompileTime==Eigen::Dynamic) ar(rows);
      if(ArrT::ColsAtCompileTime==Eigen::Dynamic) ar(cols);
      // The stored sizes come from external data: reject values that
      // resize() does not accept.
      if(rows<0 || cols<0) throw Exception("negative Eigen dimension found in archive");
      m.resize(rows,cols);
      // After resize(), m.size() == rows*cols.
      ar(binary_data(m.data(),static_cast<std::size_t>(m.size())*sizeof(typename Derived::Scalar)));
    }
}