Refactoring Session #2b: Matrix Calculation – Extract Class

Today I’ll pick up where I left off last week in refactoring @vaughncato’s inverse matrix multiplication function.

Last week I mostly covered code smells that were only loosely related to the algorithm used in the function. The one notable exception was the name of the function itself, which now describes what the function, or rather the algorithm, does, at least after the renaming. This time I’ll focus mainly on the algorithm itself.

As usual, you can follow each step I take on GitHub. Here’s the code I’ll refactor today. It is what was left last time, except for a few cleanups and a renaming suggested in last week’s comments. For brevity, I have left out all the helper functions that I won’t touch, since they are not part of the central algorithm:

#include <vector>
#include <cmath>
#include <cassert>
#include <iostream>
#include <algorithm>

using std::vector;
using std::cout;

class Matrix {
  typedef vector<float> Row;
  vector<Row> values;
public:
  Matrix(std::initializer_list<vector<float>> matrixValues)
    : values{matrixValues}
  {}

  int rows() const {
    return values.size();
  }
  int cols() const {
    return values[0].size();
  }
  Row& operator[](std::size_t index) {
    return values[index];
  }
  Row const& operator[](std::size_t index) const {
    return values[index];
  }
};

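typedef vector<float> Vector;

// Helper functions product(), print_matrix(), print_vector() and is_near()
// are left out for brevity.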

// Solve y=m*x for x
Vector gaussJordanElimination(Matrix m, Vector y) {
  int n = m.rows();
  assert(n==m.cols());
  vector<int> ref(n);

  for (int i=0;i<n;++i) {
    ref[i] = i;
  }

  for (int row=0; row<n; ++row) {
    // Find a row that has a non-zero value in the current column
    {
      int i = row;
      for (;;++i) {
        assert(i<n);
        if (m[i][row]!=0) {
          break;
        }
      }
      std::swap(m[i], m[row]);
      std::swap(y[i], y[row]);
      std::swap(ref[i], ref[row]);
    }
    {
      // Normalize row to have diagonal element be 1.0
      float v = m[row][row];
      for (int j=row;j<n;++j) {
        m[row][j] /= v;
      }
      y[row] /= v;
    }
    // Make all other rows have zero in this column
    for (int j=0;j<n;++j) {
      if (j!=row) {
        float v = m[j][row];
        for (int k=row;k<n;++k) {
          m[j][k] -= m[row][k]*v;
        }
        y[j] -= y[row]*v;
      }
    }
  }
  for (int i=0;i<n;++i) {
    std::swap(y[i], y[ref[i]]);
  }
  return y;
}

int main() {
  Matrix m = {
    {1.1, 2.4, 3.7},
    {1.2, 2.5, 4.8},
    {2.3, 3.6, 5.9},
  };

  Vector y = {0.5,1.2,2.3};

  Vector x = gaussJordanElimination(m, y);

  Vector mx = product(m,x);

  print_matrix("m",m);
  print_vector("y",y);
  print_vector("x",x);
  print_vector("m*x",mx);

  float tolerance = 1e-5;

  for (int i=0, n=y.size(); i!=n; ++i) {
    assert(is_near(mx[i],y[i],tolerance));
  }
}
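The listing calls four helpers that were left out for brevity: `product`, `print_matrix`, `print_vector` and `is_near`. Their signatures aren't shown in the post. The following is therefore only a minimal sketch of the two helpers the final check depends on. It assumes that `product(m, x)` returns the matrix-vector product as a `Vector` and that `is_near(a, b, tolerance)` compares two floats against an absolute tolerance. It repeats the read-only part of `Matrix` so it compiles on its own.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <initializer_list>
#include <vector>

using std::vector;

// The post's Matrix class, trimmed to the read-only parts the helpers need.
class Matrix {
  typedef vector<float> Row;
  vector<Row> values;
public:
  Matrix(std::initializer_list<vector<float>> matrixValues)
    : values{matrixValues}
  {}
  int rows() const { return static_cast<int>(values.size()); }
  int cols() const { return static_cast<int>(values[0].size()); }
  Row const& operator[](std::size_t index) const { return values[index]; }
};

typedef vector<float> Vector;

// Hypothetical helper (assumed signature): returns the product m*x.
Vector product(Matrix const& m, Vector const& x) {
  assert(static_cast<int>(x.size()) == m.cols());
  Vector result(m.rows(), 0.0f);
  for (int i = 0; i < m.rows(); ++i) {
    for (int j = 0; j < m.cols(); ++j) {
      result[i] += m[i][j] * x[j];
    }
  }
  return result;
}

// Hypothetical helper (assumed signature): true if a and b differ by at
// most `tolerance`.
bool is_near(float a, float b, float tolerance) {
  return std::fabs(a - b) <= tolerance;
}
```

With these in place, the loop at the end of `main` checks that `m*x` reproduces `y` entry by entry. That is how the example verifies the computed solution.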

Before we begin

To understand the algorithm and what it does, you might want to take a quick look at the Wikipedia page on Gauss-Jordan elimination. If you look closely at the code, you can spot the three elementary row operations:

- swapping two rows
- multiplying a row by a nonzero scalar
- adding a multiple of one row to another
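For reference, here is a minimal, self-contained sketch of the three elementary row operations on an augmented system: a matrix `A` with its right-hand side `b` kept alongside. The names `AugmentedSystem`, `swapRows`, `scaleRow` and `addRowMultiple` are my own, chosen for illustration, and are not taken from the original code. Each operation is applied both to the matrix row and to the corresponding entry of `b`, so the solution of `A*x = b` does not change.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

using Row = std::vector<float>;

// Augmented system: coefficient rows A and right-hand side b.
struct AugmentedSystem {
  std::vector<Row> A;
  std::vector<float> b;
};

// Operation 1: swap two rows (equations).
void swapRows(AugmentedSystem& s, int i, int j) {
  std::swap(s.A[i], s.A[j]);
  std::swap(s.b[i], s.b[j]);
}

// Operation 2: multiply a row by a nonzero scalar.
void scaleRow(AugmentedSystem& s, int i, float factor) {
  assert(factor != 0.0f);
  for (float& entry : s.A[i]) {
    entry *= factor;
  }
  s.b[i] *= factor;
}

// Operation 3: add `factor` times row `source` to row `target`.
void addRowMultiple(AugmentedSystem& s, int target, int source, float factor) {
  for (std::size_t k = 0; k < s.A[target].size(); ++k) {
    s.A[target][k] += factor * s.A[source][k];
  }
  s.b[target] += factor * s.b[source];
}
```

Applying `swapRows(s, 0, 1)`, `scaleRow(s, 1, 0.5f)` and `addRowMultiple(s, 0, 1, -1.0f)` in that order to `A = {{0, 2}, {1, 1}}`, `b = {4, 3}` turns `A` into the identity and leaves `b = {1, 2}`, which is the solution.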

One of those operations is swapping rows. The code keeps track of these swaps so that it can restore the original order in the result. That’s what `ref` is for: it’s filled with the numbers 0 through n-1, swapped alongside the rows, and later used to reorder `y`. Since it just contains the indices of the tracked rows, I renamed it to `rowIndices`.
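Here is a minimal sketch of that bookkeeping in isolation, using only the standard library. The index vector starts as 0 through n-1 and is swapped whenever two rows are swapped. Afterwards, `rowIndices[k]` names the original row that now sits at position `k`. The helper names `initialRowIndices` and `swapTracked` are hypothetical.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Fill 0, 1, ..., n-1, as the original loop does for `ref`.
std::vector<int> initialRowIndices(int n) {
  assert(n >= 0);
  std::vector<int> rowIndices(n);
  for (int i = 0; i < n; ++i) {
    rowIndices[i] = i;
  }
  return rowIndices;
}

// Swap rows i and j and record the swap in `rowIndices`.
void swapTracked(std::vector<std::vector<float>>& rows,
                 std::vector<int>& rowIndices, int i, int j) {
  std::swap(rows[i], rows[j]);
  std::swap(rowIndices[i], rowIndices[j]);
}
```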

While we’re at it: `n` is not a very telling name. The variable contains the number of rows, so `rowCount` seems fitting.

Gauss-Jordan Matrix

[Figure: GaussJordanMatrix]

This is a common notation for a matrix and a vector together when you want to perform Gauss-Jordan elimination on them. The operations are always applied to both simultaneously. You can observe this in the code as well, where every operation on `m` is also done on `y`.
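Written out with the values from `main`, the matrix `m` and the vector `y` form the following augmented matrix. The vertical bar only separates `m` from `y`, and each row operation acts on a whole row, including the entry to the right of the bar:

$$
\left[\begin{array}{ccc|c}
1.1 & 2.4 & 3.7 & 0.5 \\
1.2 & 2.5 & 4.8 & 1.2 \\
2.3 & 3.6 & 5.9 & 2.3
\end{array}\right]
$$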

It only makes sense to put the two into their own data structure. Since the `rowIndices` vector is also used in those transformations, I put all three into the same structure:

struct GaussJordanMatrix {
  Matrix m;
  Vector y;
  vector<int> rowIndices;
};


Vector gaussJordanElimination(Matrix m, Vector y) {
  GaussJordanMatrix gaussJordan{std::move(m), std::move(y), {}};
  //... access gaussJordan.m etc.
}

The initialization of the `rowIndices` vector is only an implementation detail. It belongs in a constructor of our new structure. In fact, the whole `rowIndices` vector is only an implementation detail of the algorithm. While we’re at it, let’s replace the manual initialization loop with a standard algorithm, `std::iota` from `<numeric>`:

struct GaussJordanMatrix {
  //...

  GaussJordanMatrix(Matrix matrix, Vector vector)
    : m{std::move(matrix)}, y{std::move(vector)}, rowIndices{}
  { 
    rowIndices.resize(m.rows());
    std::iota(std::begin(rowIndices), std::end(rowIndices), 0);
  }
};

Vector gaussJordanElimination(Matrix m, Vector y) {
  GaussJordanMatrix gaussJordan{std::move(m), std::move(y)};
  //... access gaussJordan.m etc.
}
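`std::iota` writes consecutive values, starting at the one you pass in, so it produces exactly what the manual loop did. Here is a minimal stand-alone comparison of both versions; the function names are only for illustration.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// The original manual loop: ref[i] = i.
std::vector<int> indicesByLoop(int n) {
  assert(n >= 0);
  std::vector<int> ref(n);
  for (int i = 0; i < n; ++i) {
    ref[i] = i;
  }
  return ref;
}

// The replacement: std::iota fills 0, 1, ..., n-1.
std::vector<int> indicesByIota(int n) {
  assert(n >= 0);
  std::vector<int> rowIndices(n);
  std::iota(std::begin(rowIndices), std::end(rowIndices), 0);
  return rowIndices;
}
```

Because `m` is declared before `rowIndices`, the constructor could also size the vector in its member initializer list with `rowIndices(m.rows())` instead of calling `resize` in the body. Mind the parentheses: `rowIndices{m.rows()}` would create a one-element vector containing the row count.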

Factoring out methods

Now, what’s next? The central function still does a lot. As I mentioned earlier, we can identify the elementary row operations, such as swapping two rows. The original author was also kind enough to write scope blocks with comments describing what those blocks do. This is a sign that these blocks should be functions. Since we now have our data structure, that’s where those functions should go.

I’ll start at the top with the row count. Calculating it is straightforward, but the assert in the central function does not belong there, so let’s move it into the constructor of our new class.

struct GaussJordanMatrix {
  //...

  GaussJordanMatrix(Matrix matrix, Vector vector)
    : m{std::move(matrix)}, y{std::move(vector)}, rowIndices{}
  { 
    assert(rowCount()==m.cols());

    rowIndices.resize(rowCount());
    std::iota(std::begin(rowIndices), std::end(rowIndices), 0);
  }

  int rowCount() const { return m.rows(); }
};

You might wonder why I didn’t do the earlier renaming of `n` to `rowCount` and the extraction of the function in one step. That is because they are independent steps. In a refactoring session you often take small steps that could be done together, but smaller steps give you more safety.

Sometimes you will even take steps that completely undo something you have done earlier. This is not a bad thing if that earlier step helped you reason about the code you are working with.

The next step is pretty straightforward: finding a row with a nonzero value in a given column should be a separate function. While at it, I did some renaming:

struct GaussJordanMatrix {
  //...
  int indexOfRowWithNonzeroColumn(int columnIndex) {
    for (int rowIndex = columnIndex; rowIndex < rowCount(); ++rowIndex) {
      if (m[rowIndex][columnIndex]!=0) {
        return rowIndex;
      }
    }
    assert(false);
    return -1;
  }
};
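The same search can also be written with a standard algorithm instead of a raw loop. This is only a sketch on a plain `std::vector<std::vector<float>>`, a stand-alone variant of the member function above rather than the post's code. Unlike the member, it returns -1 instead of asserting when no such row exists.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

using Rows = std::vector<std::vector<float>>;

// Index of the first row at or below `columnIndex` whose entry in that
// column is nonzero, or -1 if there is none.
int indexOfRowWithNonzeroColumn(Rows const& m, int columnIndex) {
  assert(columnIndex >= 0 && columnIndex <= static_cast<int>(m.size()));
  auto first = std::next(m.begin(), columnIndex);
  auto it = std::find_if(first, m.end(),
                         [columnIndex](std::vector<float> const& row) {
                           return row[columnIndex] != 0;
                         });
  return it == m.end() ? -1 : static_cast<int>(std::distance(m.begin(), it));
}
```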

Then we can factor out the operation “swap rows”, followed by “normalize row”. The latter is the “multiply row by scalar” operation, where the scalar is the reciprocal of the row’s diagonal element.
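Here is the “normalize row” step on its own: dividing a row, and its entry in `y`, by the row’s diagonal element so that the diagonal becomes 1. This is a minimal sketch as a free function on plain vectors, a stand-alone variant of the class’s `normalizeRow`. The diagonal element has to be copied before the loop, because the loop overwrites it.

```cpp
#include <cassert>
#include <vector>

// Divide row `rowIndex` of m, and y[rowIndex], by the diagonal element,
// so that m[rowIndex][rowIndex] becomes 1.
void normalizeRow(std::vector<std::vector<float>>& m, std::vector<float>& y,
                  int rowIndex) {
  // Copy the pivot first: the loop below overwrites m[rowIndex][rowIndex].
  float const diagonalElement = m[rowIndex][rowIndex];
  assert(diagonalElement != 0.0f);
  for (float& entry : m[rowIndex]) {
    entry /= diagonalElement;
  }
  y[rowIndex] /= diagonalElement;
}
```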

Returning the vector part of our Gauss-Jordan structure in the original order is another function to be factored out. After that, I split the remaining inner loop into two functions. One is the subtraction of a scalar multiple of one row from another row. The other is the loop that calls it and, as the comment points out, uses the subtraction to zero out the current column in all other rows.
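The elimination step can be sketched the same way. For every other row, subtract `factor` times the pivot row, choosing `factor` as that row’s entry in the pivot column. This assumes the pivot row has already been normalized so that its diagonal element is 1. The names mirror `subtractRow` and `subtractToZeroInColumn` from the final listing, but these are illustrative free functions on plain vectors.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Rows = std::vector<std::vector<float>>;

// Row fromRowIndex -= factor * row rowIndex (in both m and y).
void subtractRow(Rows& m, std::vector<float>& y,
                 int rowIndex, float factor, int fromRowIndex) {
  for (std::size_t k = 0; k < m[fromRowIndex].size(); ++k) {
    m[fromRowIndex][k] -= m[rowIndex][k] * factor;
  }
  y[fromRowIndex] -= y[rowIndex] * factor;
}

// Zero out column `masterRowIndex` in every row except the master row.
// Precondition: the master row is normalized (its diagonal element is 1).
void subtractToZeroInColumn(Rows& m, std::vector<float>& y, int masterRowIndex) {
  assert(m[masterRowIndex][masterRowIndex] == 1.0f);
  for (int rowIndex = 0; rowIndex < static_cast<int>(m.size()); ++rowIndex) {
    if (rowIndex != masterRowIndex) {
      // Copy the factor before subtractRow modifies this row.
      float factor = m[rowIndex][masterRowIndex];
      subtractRow(m, y, masterRowIndex, factor, rowIndex);
    }
  }
}
```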

What remains is a little cleanup to remove unnecessary scopes and comments. The central function now looks small and descriptive. It pretty much lists the steps we need for a Gauss-Jordan elimination:

// Solve y=m*x for x
Vector gaussJordanElimination(Matrix m, Vector y) {
  GaussJordanMatrix gaussJordan{std::move(m), std::move(y)};
  int rowCount = gaussJordan.rowCount();

  for (int row=0; row<rowCount; ++row) {
    int i = gaussJordan.indexOfRowWithNonzeroColumn(row);
    gaussJordan.swapRows(row,i);
    gaussJordan.normalizeRow(row);
    gaussJordan.subtractToZeroInColumn(row);
  }
  return gaussJordan.getVectorInOriginalOrder();
}

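The last step I did in this session was to turn our former struct into a class, since it is no longer just a bundle of data but has functionality. We also don’t need access to the data members any more, so we should make them private. In C++, the only difference between `struct` and `class` is the default access (public vs. private), so with `class` the data members are private unless declared otherwise.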

Here is the code after this session, again without the helper functions:

#include <vector>
#include <cmath>
#include <cassert>
#include <iostream>
#include <algorithm>
#include <numeric>

using std::vector;
using std::cout;

class Matrix {
  typedef vector<float> Row;
  vector<Row> values;
public:
  Matrix(std::initializer_list<vector<float>> matrixValues)
    : values{matrixValues}
  {}

  int rows() const {
    return values.size();
  }
  int cols() const {
    return values[0].size();
  }
  Row& operator[](std::size_t index) {
    return values[index];
  }
  Row const& operator[](std::size_t index) const {
    return values[index];
  }
};

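typedef vector<float> Vector;

// Helper functions product(), print_matrix(), print_vector() and is_near()
// are left out for brevity.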

class GaussJordanMatrix {
  Matrix m;
  Vector y;
  vector<int> rowIndices;

public:
  GaussJordanMatrix(Matrix matrix, Vector vector)
    : m{std::move(matrix)}, y{std::move(vector)}, rowIndices{}
  { 
    assert(rowCount()==m.cols());

    rowIndices.resize(rowCount());
    std::iota(std::begin(rowIndices), std::end(rowIndices), 0);
  }

  int rowCount() const {
    return m.rows();
  }

  int indexOfRowWithNonzeroColumn(int columnIndex) {
    for (int rowIndex = columnIndex; rowIndex < rowCount(); ++rowIndex) {
      if (m[rowIndex][columnIndex]!=0) {
        return rowIndex;
      }
    }
    assert(false);
    return -1;
  }

  void swapRows(int i, int j) {
    std::swap(m[i], m[j]);
    std::swap(y[i], y[j]);
    std::swap(rowIndices[i], rowIndices[j]);
  }

  void normalizeRow(int rowIndex) {
    auto& row = m[rowIndex];
    auto diagonalElement = row[rowIndex];
    for (auto& rowEntry : row) {
      rowEntry /= diagonalElement;
    }
    y[rowIndex] /= diagonalElement;
  }

  void subtractRow(int rowIndex, float factor, int fromRowIndex) {
    auto const& row = m[rowIndex];
    auto& fromRow = m[fromRowIndex];
    for (int k=0;k<rowCount();++k) {
      fromRow[k] -= row[k]*factor;
    }
    y[fromRowIndex] -= y[rowIndex]*factor;
  }

  void subtractToZeroInColumn(int masterRowIndex) {
    for (int rowIndex=0;rowIndex<rowCount();++rowIndex) {
      if (rowIndex!=masterRowIndex) {
        float factor = m[rowIndex][masterRowIndex];
        subtractRow(masterRowIndex, factor, rowIndex);
      }
    }
  }

  Vector getVectorInOriginalOrder() {
    Vector v = y;
    for (int i=0;i<rowCount();++i) {
      std::swap(v[i], v[rowIndices[i]]);
    }
    return v;
  }
};

// Solve y=m*x for x
Vector gaussJordanElimination(Matrix m, Vector y) {
  GaussJordanMatrix gaussJordan{std::move(m), std::move(y)};
  int rowCount = gaussJordan.rowCount();

  for (int row=0; row<rowCount; ++row) {
    int i = gaussJordan.indexOfRowWithNonzeroColumn(row);
    gaussJordan.swapRows(row,i);
    gaussJordan.normalizeRow(row);
    gaussJordan.subtractToZeroInColumn(row);
  }
  return gaussJordan.getVectorInOriginalOrder();
}

int main() {
  Matrix m = {
    {1.1, 2.4, 3.7},
    {1.2, 2.5, 4.8},
    {2.3, 3.6, 5.9},
  };

  Vector y = {0.5,1.2,2.3};

  Vector x = gaussJordanElimination(m, y);

  Vector mx = product(m,x);

  print_matrix("m",m);
  print_vector("y",y);
  print_vector("x",x);
  print_vector("m*x",mx);

  float tolerance = 1e-5;

  for (int i=0, n=y.size(); i!=n; ++i) {
    assert(is_near(mx[i],y[i],tolerance));
  }
}

Conclusion

It took me some time to sit down and start refactoring this code. The main reason was that it was hard to get to the bottom of what it did, especially with those shortened variable names. I think it is a little easier to grasp now, even though there are still issues that could be worked on.

This is another lesson we can take from this session: You’re probably never done improving the code. It’s important to find and reduce the pain points and to know when it’s good enough – at least for now.


7 Comments




  1. alfC

    Next suggested step: remove the naked `for` loops and replace them with STL algorithms, like `transform` and `rotate`.

    1. Arne Mertz

      I’d probably refactor the data structures like `Matrix` first, but loops to algorithms definitely should be on every list of refactoring patterns.

  2. Vaughn Cato

    That’s looking pretty good. It seems like the gaussJordanElimination function is now suffering from feature envy. Maybe take one more step?

    1. Arne Mertz

      Hi Vaughn,
      glad you like it 🙂
      You are right, I should have put the function into the `GaussJordanMatrix` class as a last step.

