O(1) sum_of_multiples() in Rust

I had been working mostly in Scala for a while, then took a diversion into Swift and Objective-C.  I wanted to learn another language after that, and had all but decided on Clojure.  But Rust kept nagging at me; there was something about it.  Perhaps I’ll blog more about that later.

So I watched some videos, then read the book, and then started the Rust track at Exercism.io.  Nice site.  One of the exercises is about generating the sum of all multiples under some limit for a given list of factors.  So for example, if the factors are 3, 5, and 7, then the sum of the multiples under 12 is 3 + 5 + 6 + 7 + 9 + 10 = 40.

A naive solution is pretty straightforward:

pub fn naive_sum_of_multiples(limit: u32, factors: &[u32]) -> u32 {
    (1..limit).filter( |&i|
        factors.iter().any(|&j| i % j == 0)
    ).sum()
}

It just iterates through all of the eligible numbers, picking out any that are evenly divisible by any of the factors, and sums them.  Rust’s iterators make that look pretty similar to a lot of other modern languages, like Scala.

But this solution is slow.  I mean really slow, at least compared to what’s possible.  As limit grows, the execution time increases linearly; in other words, this is O(n) with respect to the limit.  (We’re ignoring the very real impact of the number of factors on execution time.  In Exercism’s tests there are generally two factors, three at most, so for simplicity let’s focus on the limit here.)

Well, O(n) doesn’t sound so bad, does it?  But here’s the thing: it could be O(1).

What???  O(1)?  That would mean that you aren’t doing any more work if you increase the limit — and thus the numbers summed — by a factor of 100.  That can’t be right!  But yes, it is.

Here is the key observation.  The instant you see it, your brain will likely jump halfway to the solution.  The sum of (for example)

3 + 6 + 9 + 12 + 15 + 18 + 21

is just

3 * ( 1 + 2 + 3 + 4 + 5 + 6 + 7)

Got it yet?  Remember that the sum of the numbers from 1 to n is simply the value of n * (n + 1) / 2 — actually summing the numbers is unnecessary!
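To make that concrete, here is a minimal sketch of the closed form for a single factor.  The name sum_of_multiples_of is mine (it is not part of the Exercism API), it counts multiples up to and including the limit, and it assumes the factor is nonzero and the result fits in a u64.

```rust
/// Sum of all positive multiples of `factor` that are <= `limit`.
/// Hypothetical helper for illustration; assumes `factor > 0` and no overflow.
pub fn sum_of_multiples_of(factor: u64, limit: u64) -> u64 {
    // How many multiples of `factor` fit at or below the limit.
    let n = limit / factor;
    // factor * (1 + 2 + ... + n), using n * (n + 1) / 2 for the inner sum.
    factor * (n * (n + 1) / 2)
}
```

Calling sum_of_multiples_of(3, 21) gives 84, the same as adding 3 + 6 + ... + 21 by hand, and the amount of work is the same whether the limit is 21 or 21 million.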

That’s not all there is to it, though. If asked to sum the multiples of 3 and 5 less than 1000, we can use the above technique to sum up the multiples of 3 and the multiples of 5 and then add them together, but we will have counted the multiples of 15 twice; we have to subtract those out.
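The 3-and-5 case can be sketched directly.  This is only an illustration of the add-then-subtract idea for that one pair of factors; the helper names are made up, and sum_below counts multiples strictly under the limit, as in the Exercism problem.

```rust
/// Sum of positive multiples of `factor` strictly below `limit`.
/// Hypothetical helper; assumes `factor > 0` and no overflow.
fn sum_below(factor: u64, limit: u64) -> u64 {
    // saturating_sub keeps a limit of 0 from underflowing.
    let n = limit.saturating_sub(1) / factor;
    factor * (n * (n + 1) / 2)
}

/// Multiples of 3 or 5 below `limit`: add both sums, then remove the
/// multiples of 15, which were counted once in each.
pub fn sum_of_multiples_of_3_or_5(limit: u64) -> u64 {
    sum_below(3, limit) + sum_below(5, limit) - sum_below(15, limit)
}
```

For a limit of 1000 this works out to 166833 + 99500 - 33165 = 233168.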

And that’s a little trickier than it sounds. First, what you really need to subtract are the multiples of the least common multiple (lcm) of the two numbers, not of their product. So, for example, if asked to sum the multiples of 6 and 15, we need to subtract the multiples of 30 (not 90). The lcm of two numbers is their product divided by their greatest common divisor (gcd).
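A small sketch of that relationship, using Euclid’s algorithm for the gcd. One choice here is an assumption of mine rather than something the problem requires: the lcm divides before multiplying, which gives the same answer as product-over-gcd but keeps the intermediate value smaller.

```rust
/// Greatest common divisor by Euclid's algorithm.
pub fn gcd(a: u64, b: u64) -> u64 {
    if b == 0 { a } else { gcd(b, a % b) }
}

/// Least common multiple. Mathematically equal to a * b / gcd(a, b);
/// dividing first avoids overflowing on the product.
/// Assumes `a` and `b` are not both zero.
pub fn lcm(a: u64, b: u64) -> u64 {
    a / gcd(a, b) * b
}
```

Here gcd(6, 15) is 3, so lcm(6, 15) is 30, while the plain product would have been 90.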

Also, we need to do this for an arbitrarily long list of numbers, so consider what happens if we are asked to sum the multiples of 4, 6, and 10:

  • First sum the multiples of 4.
  • Then add in the multiples of 6, but subtract the multiples of lcm(4, 6) = 12.
  • Then add in the multiples of 10, but subtract the multiples of lcm(4, 10) = 20 and the multiples of lcm(6, 10) = 30.

But oops, now we have gone the other way: the numbers that are multiples of both 20 and 30 (60, 120, …) have been subtracted twice, so our result is too low and we’ll have to add those back in. And if there were multiple corrections at that level (i.e. if we were given a larger list of numbers), we’d have to subtract their elements in common, and so on, one level deeper for each additional factor. At every step we have to take care not to add or subtract the same numbers twice.
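Written out by hand for just these three factors, the bookkeeping looks like the sketch below. The function names are hypothetical, the lcm values are hard-coded from the steps above, and multiples are counted up to and including the limit.

```rust
/// Sum of positive multiples of `factor` that are <= `limit`.
/// Hypothetical helper; assumes `factor > 0` and no overflow.
fn s(factor: u64, limit: u64) -> u64 {
    let n = limit / factor;
    factor * (n * (n + 1) / 2)
}

/// Multiples of 4, 6, or 10 up to and including `limit`, term by term.
pub fn sum_multiples_of_4_6_10(limit: u64) -> u64 {
    // Add each factor's multiples on their own.
    let singles = s(4, limit) + s(6, limit) + s(10, limit);
    // Subtract the overlaps: lcm(4, 6) = 12, lcm(4, 10) = 20, lcm(6, 10) = 30.
    let pairs = s(12, limit) + s(20, limit) + s(30, limit);
    // Add back what was subtracted once too often: lcm(20, 30) = 60.
    let add_back = s(60, limit);
    // Add before subtracting so the unsigned arithmetic never dips below zero.
    singles + add_back - pairs
}
```

With a limit of 60 this gives 1020 + 60 - 390 = 690, which matches adding up the qualifying numbers one at a time.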

That sounds like a pain, but using recursion it’s actually fairly straightforward.  In the following Rust code, I’ve changed the API a bit from the Exercism problem.  First, the integers are u64, so that we can use much bigger limits.  And secondly, in this case we’ll sum all multiples up to and including limit.  It’s an arbitrary choice anyway, and doing it this way will save us a small step for clarity.

pub fn fast_sum_of_multiples(limit: u64, factors: &[u64]) -> u64 {
  // divide before multiplying so a*b cannot overflow for large factors
  fn lcm(a: u64, b: u64) -> u64 { a / gcd(a,b) * b }
  fn gcd(a: u64, b: u64) -> u64 { if b == 0 {a} else { gcd(b, a%b) } }
  fn sum_from_ix(i: usize, limit: u64, factors: &[u64]) -> u64 {
    if i == factors.len() {  // we've processed all factors
      0
    } else {
      let factor = factors[i];
      let n = limit / factor;  // # of multiples of factor to sum
      let sum_of_multiples_of_factor = factor * (n*(n+1)/2);
      let new_factors: Vec<_> = factors[..i].iter()
        .map(|&prev_factor| lcm(prev_factor, factor))
        .filter(|&factor| factor <= limit)
        .collect();
      let sum_of_previously_seen_multiples_of_factor =
        sum_from_ix(0, limit, &new_factors[..]);
      let sum_of_multiples_of_rest_of_factors =
        sum_from_ix(i+1, limit, factors);
      sum_of_multiples_of_factor
        - sum_of_previously_seen_multiples_of_factor
        + sum_of_multiples_of_rest_of_factors
    }
  }
  sum_from_ix(0, limit, factors)
}

This is not too far from what a Scala solution would look like, although I think the Scala version (below) is a bit more readable.  In part this owes to Scala’s persistent List data structure, which was a little cleaner to work with here than a Rust Vec.  Also, Scala’s implicits make it unnecessary to make two calls that are explicit in the Rust version: iter() and collect().  Oh, and in Rust a nested fn cannot close over variables from the enclosing function; only closures can do that, and we couldn’t use a closure because a Rust closure cannot simply call itself recursively.  That forced us to explicitly include limit in the argument list of sum_from_ix().  Then there are the ampersands here and there in closures, and the semicolons.  These are little things, but collectively noticeable.

def fastSumOfMultiples(limit: Long, factors: List[Long]): Long = {
  def lcm(a: Long, b: Long) = a / gcd(a,b) * b
  def gcd(a: Long, b: Long): Long = if (b == 0) a else gcd(b, a%b)
  def sumOfMults(factors: List[Long], prevFactors: List[Long] = Nil): Long =
    factors match {
      case Nil => 0
      case factor::rest =>
        val n = limit / factor  // # of multiples of factor to sum
        val sum_of_multiples_of_factor = factor * (n*(n+1)/2)
        val sum_of_previously_seen_multiples_of_factor =
          sumOfMults(prevFactors.map(lcm(_, factor)).filter(_ <= limit))
        val sum_of_multiples_of_rest_of_factors =
          sumOfMults(rest, factor::prevFactors)
        sum_of_multiples_of_factor -
           sum_of_previously_seen_multiples_of_factor +
           sum_of_multiples_of_rest_of_factors
    }
  sumOfMults(factors)
}

But really, the difference is not so stark.  And the Rust version does not need a garbage collector — amazing!

On to the next exercise….
