Project Euler Problem 1, Beaten to Death in Scala

I have been trying to learn some Scala in my spare time, and like many I decided to use the Project Euler problems for practice. Problem 1 is simply to “Add all the natural numbers below one thousand that are multiples of 3 or 5.” Lots of folks have answered this already; the usual approach is to iterate through the numbers below 1000, find the ones we want using the modulo operator, and sum them, e.g. (from here):

val r = ( 1 until 1000 ).view.filter( n => n%3 == 0 || n%5 == 0 ).sum

Incidentally, the call to view prevents filter from constructing an intermediate collection of the matching numbers (the Range 1 until 1000 itself never stores its elements). The view allows you to iterate through the filtered numbers without actually manifesting a collection containing them. You can read more about views here, in the nice writeup about the slick rewrite of Scala’s collections API for 2.8.
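To make the laziness of a view concrete, here is a minimal sketch. The object name ViewDemo and the counter are made up for illustration; the counter simply records how often the mapping function is actually applied.

```scala
object ViewDemo {
  // Returns (elements computed before forcing, the first three results,
  // elements computed after forcing).
  def run(): (Int, List[Int], Int) = {
    var touched = 0
    // The mapping function records every element it is actually applied to.
    val squares = (1 until 1000).view.map { n => touched += 1; n * n }
    val before = touched                     // building the view computes nothing
    val firstThree = squares.take(3).toList  // forcing a few elements
    (before, firstThree, touched)
  }
}
```

Calling ViewDemo.run() should show that nothing is computed when the view is built, and that asking for three elements does not walk the whole range.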

What follows are three solutions for a more general statement of the problem, using different Scala language features (for my own edification). They grow increasingly efficient (Fast, Faster, Fastest), the last taking a trivial amount of time compared to the others. If you get bogged down reading this, skip to the Fastest solution — it’s probably the easiest and most interesting.

Fast (Streams)

To begin with, I wanted to solve the problem more generally, by writing a function that takes a limit and an arbitrary list of integers for which we want the multiples under the limit summed. Moreover, I wanted it to be efficient in the sense that if the numbers were large, as in “Add all the natural numbers below 1,000,000 that are multiples of 3137, 2953, or 2789,” we would not do way more calculations than necessary. There are only about a thousand multiples of those integers under a million, so if we examined every integer in the range we would be touching about a thousand times as many numbers as we really needed to.
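The "about a thousand" estimate is easy to check with integer division. This is only a sketch; CountMultiples and its method names are invented for the example.

```scala
object CountMultiples {
  // Number of positive multiples of n that are strictly below limit.
  def countBelow(limit: Int, n: Int): Int = (limit - 1) / n

  // Total over several numbers. This is an upper bound on the distinct count,
  // because a shared multiple would be counted once per number.
  def totalBelow(limit: Int, ns: List[Int]): Int =
    ns.map(countBelow(limit, _)).sum
}
```

CountMultiples.totalBelow(1000000, List(3137, 2953, 2789)) gives 318 + 338 + 358 = 1014. The pairwise least common multiples of these three numbers are all well above a million, so no multiple is counted twice here.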

Instead I wanted to try using Streams, i.e. to generate a Stream of multiples for each of the integers and then merge them together into a single Stream from which I could just takeWhile() the desired numbers and sum() them up. That way not only would we potentially touch far fewer numbers, but we would also do multiplications instead of more expensive modulo operations. Here’s what I came up with:

 
object SumMults1 extends App {

  val ints = args.toList.map(_.toInt)
  println( sumMults( ints.head, ints.tail ) )

  def sumMults( limit: Int, ints: List[Int] ) = {
    // Merge two infinite Streams of Ints, removing duplicates.
    def merge( xs: Stream[Int], ys: Stream[Int] ): Stream[Int] = {
      val x = xs.head
      val y = ys.head
      if      (x<y) x #:: merge( xs.tail, ys      )
      else if (x>y) y #:: merge( xs     , ys.tail )
      else  /* x=y */     merge( xs     , ys.tail )
    }
    val natNums = Stream from 1
    ints.map( n => natNums.map(n * _) ).reduceLeft(merge).takeWhile(_<limit).sum
  }
}

The last line is where it all comes together: it creates an infinite Stream of multiples for each of the integers, reduces them to a single merged Stream, takes those up to our limit, and returns the sum. If you have your head around functional programming, this is probably pretty clear and concise.
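To see what the merged Stream actually looks like, here is a small sketch that repeats the merge function from the listing above and peeks at the first few elements for 3 and 5. MergeDemo, multiples and firstMerged are names made up for the example. Note that Stream is deprecated in favor of LazyList as of Scala 2.13, so expect a deprecation warning on newer compilers.

```scala
object MergeDemo {
  // Same merge as in the listing above: combine two ascending infinite
  // Streams into one ascending Stream, dropping duplicates.
  def merge(xs: Stream[Int], ys: Stream[Int]): Stream[Int] = {
    val x = xs.head
    val y = ys.head
    if (x < y) x #:: merge(xs.tail, ys)
    else if (x > y) y #:: merge(xs, ys.tail)
    else merge(xs, ys.tail)
  }

  // The infinite Stream n, 2n, 3n, ...
  def multiples(n: Int): Stream[Int] = Stream.from(1).map(_ * n)

  // The first `count` elements of the merged multiples of 3 and 5.
  def firstMerged(count: Int): List[Int] =
    merge(multiples(3), multiples(5)).take(count).toList
}
```

MergeDemo.firstMerged(8) yields List(3, 5, 6, 9, 10, 12, 15, 18): the value 15 appears once even though both input Streams contain it.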

Anyway, it works fine, e.g.

% scala SumMults1 1000 3 5
233168

However, I found that if I used a large limit, I got an OutOfMemoryError. I puzzled over this for a while, thinking that perhaps #:: was not lazy under some circumstances, or that the third call to merge wasn’t being optimized as a tail call. I used the @tailrec annotation to verify that the third call to merge was being optimized (the first two aren’t, but that’s OK since they are evaluated lazily). I finally found the answer, by Rex Kerr, on StackOverflow:

Streams are lazily evaluated, but they store the results of their evaluations. Thus, unless you want to traverse through them multiple times, you should use Iterator instead.

So the problem was not that we were trying to evaluate too much of the Stream, but simply that we were saving everything we had previously taken from the Stream, which was quite a pile.

UPDATE: Recently the Scala community has begun an effort to flesh out the scaladoc with introductory text, and the results for Stream are very impressive indeed! There it explains that as long as something is holding onto the Stream object, the memoized objects cannot be garbage-collected.
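The memoization is easy to observe directly. In this sketch (MemoDemo and the counter are made up for illustration) the mapping function counts its own invocations, and a second pass over the same Stream does not invoke it again because the cells computed in the first pass are still being held.

```scala
object MemoDemo {
  // Returns (sum of the first pass, evaluations after the first pass,
  // evaluations after the second pass).
  def run(): (Int, Int, Int) = {
    var evaluations = 0
    // As long as `multiples` is referenced, every cell it has computed stays alive.
    val multiples = Stream.from(1).map { n => evaluations += 1; n * 3 }
    val firstSum = multiples.take(5).sum    // computes 3, 6, 9, 12, 15
    val afterFirst = evaluations
    val secondSum = multiples.take(5).sum   // served from the stored cells
    assert(firstSum == secondSum)
    (firstSum, afterFirst, evaluations)
  }
}
```

The two evaluation counts come out equal, which is convenient for repeated traversal and exactly the wrong behavior when all we want is one pass over millions of elements.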

Faster (Iterators)

So, how to do this with Iterators? It isn’t anywhere near as easy as just replacing Stream with Iterator, because there is no #:: operator with which to construct a new Iterator from existing Iterators. Well, I decided to go for broke, changing a bunch of things to make a much more efficient version:

object SumMults2 extends App {

  val ints = args.toList.map(_.toInt)
  println( sumMultiples( ints.head, ints.tail.distinct ) )

  def sumMultiples( limit: Int, ints: List[Int] ) = {
    val mults = mergeOf( ints.sorted.map(Series(_)) ).takeWhile(_<limit)
    mults.foldLeft( BigInt(0) )(_+_)
  }

  // A mutable generator for the multiples of step: step, 2*step, 3*step, ...
  case class Series( step:Int ) {
    var head = step
    def tail = { head += step; this }
  }

  // Return an Iterator that merges a sorted List of Series.
  def mergeOf( list: List[Series] ) = new Iterator[Int] {
    var l = list
    def hasNext = true
    def next() = {
      val l0   =  l.head
      val head = l0.head
      l = insert( l0.tail, l.tail )
      head
    }
  }

  // Insert a Series into a list of Series, maintaining the ordering.
  def insert( s: Series, l: List[Series] ): List[Series] = l match {
    case Nil => List(s)
    case _   => {
      val sHead  =  s.head
      val l0     =  l.head
      val l0Head = l0.head
      if      ( sHead < l0Head )   s :: l
      else if ( sHead > l0Head )  l0 :: insert(  s     , l.tail )
      else   /* sHead = l0Head */  s :: insert( l0.tail, l.tail )
    }
  }

}

This version uses a mutable Series class to generate the multiples of an integer by simple addition, and then creates an Iterator that merges a list of Series objects. The Iterator keeps the list sorted by the heads of its Series elements, so that the next value of the merge is always just the head of the first Series in the List.

This version is far faster than the other, and doesn’t chew up lots of memory. Also, this version uses BigInt to accumulate the results, so that we can use much larger limits:

% scala SumMults2 1000 3 5
233168
% scala SumMults2 1000000000 3137 1753 2789
623638337230623
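For comparison, here is a different way to get an Iterator-based merge that stays closer to the shape of the Stream version. This is not the approach used in SumMults2; it is an alternative sketch built on the standard library's BufferedIterator, which allows peeking at the next element without consuming it. The names IteratorMerge, merge and multiples are made up for the example, and it assumes a non-empty list of positive integers.

```scala
object IteratorMerge {
  // Merge two ascending iterators into one ascending iterator without duplicates.
  def merge(a: Iterator[Int], b: Iterator[Int]): Iterator[Int] = new Iterator[Int] {
    private val xs = a.buffered   // buffered iterators expose `head` for peeking
    private val ys = b.buffered
    def hasNext: Boolean = xs.hasNext || ys.hasNext
    def next(): Int =
      if (!ys.hasNext) xs.next()
      else if (!xs.hasNext) ys.next()
      else if (xs.head < ys.head) xs.next()
      else if (xs.head > ys.head) ys.next()
      else { ys.next(); xs.next() }   // equal heads: emit once, advance both
  }

  // The ascending multiples of n: n, 2n, 3n, ...
  def multiples(n: Int): Iterator[Int] = Iterator.iterate(n)(_ + n)

  // Sum of all numbers below limit that are multiples of any of ints.
  def sumMultiples(limit: Int, ints: List[Int]): BigInt =
    ints.distinct.map(multiples).reduceLeft(merge)
      .takeWhile(_ < limit)
      .foldLeft(BigInt(0))(_ + _)
}
```

IteratorMerge.sumMultiples(1000, List(3, 5)) gives 233168, the same answer as above. Nothing is memoized, so memory use stays flat, but each additional integer adds another layer of nested merge calls rather than one entry in a sorted list.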

Fastest (Math)

But there is a far faster approach still, one that makes the above versions look like lumbering brontosauri. Perhaps you have already noticed the crucial point, that (for example):

3 + 6 + 9 + … + 300

can be rewritten

3 * ( 1 + 2 + 3 + … + 100)

The parenthesized part — as Gauss supposedly knew as a young child — is just

100 * 101 / 2

So the sum we want is just three times that. We can use this mathematical relationship to calculate the sum without actually iterating through all of the numbers!
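As a minimal sketch, the closed form for a single number looks like this (ArithmeticSum and sumOfMultiples are names invented for the example; the limit is exclusive, as in the problem statement):

```scala
object ArithmeticSum {
  // Sum of the positive multiples of k that are strictly below limit:
  // k * (1 + 2 + ... + n) = k * n * (n + 1) / 2, where n = (limit - 1) / k.
  def sumOfMultiples(k: BigInt, limit: BigInt): BigInt = {
    val n = (limit - 1) / k
    k * n * (n + 1) / 2
  }
}
```

ArithmeticSum.sumOfMultiples(BigInt(3), BigInt(301)) reproduces the example: n is 100, so the result is 3 * 100 * 101 / 2 = 15150.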

However, it’s not quite that simple. If asked to sum the multiples of 3 or 5 less than 1000, we can use the above technique to sum up the multiples of 3 and the multiples of 5 and then add them together, but we will have counted the multiples of 15 twice; we have to subtract those out.

And that’s a little trickier than it sounds. First, what you really need to subtract out are the multiples of the least common multiple (lcm) of the two numbers, not their product. So, for example, if asked to sum the multiples of 6 or 15, we need to subtract off the multiples of 30 (not 90). The lcm of two numbers is their product divided by their greatest common divisor (gcd).
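For just two numbers, the correction can be sketched as follows. PairSum and sumEither are hypothetical names used only for this example, and the limit is exclusive.

```scala
object PairSum {
  def gcd(a: BigInt, b: BigInt): BigInt = if (b == 0) a else gcd(b, a % b)
  def lcm(a: BigInt, b: BigInt): BigInt = a * b / gcd(a, b)

  // Sum of the positive multiples of k strictly below limit (closed form).
  def sumOfMultiples(k: BigInt, limit: BigInt): BigInt = {
    val n = (limit - 1) / k
    k * n * (n + 1) / 2
  }

  // Multiples of a or b below limit: add both series, then remove the
  // overlap, which is exactly the multiples of lcm(a, b).
  def sumEither(a: BigInt, b: BigInt, limit: BigInt): BigInt =
    sumOfMultiples(a, limit) + sumOfMultiples(b, limit) -
      sumOfMultiples(lcm(a, b), limit)
}
```

With 3 and 5 below 1000 this is 166833 + 99500 - 33165 = 233168. With 6 and 15, lcm gives 30, so the overlap removed is 30 + 60 + 90 and not the multiples of 90.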

Also, we need to do this for an arbitrarily long list of numbers, so consider what happens if we are asked to sum the multiples of 4, 6, and 10:

  • First sum the multiples of 4.
  • Then add in the multiples of 6, but subtract the multiples of lcm(4,6) = 12.
  • Then add in the multiples of 10, but subtract the multiples of lcm(4,10) = 20 and the multiples of lcm(6,10) = 30.

But oops, now we have gone the other way, subtracting off the multiples that 20 and 30 have in common (60,120,…) twice, and our result is too low, so we’ll have to add those back in. And if there were multiple corrections at that level (i.e. if we were given a larger list of numbers), we’d have to subtract their elements in common, and so on for as many levels as the list requires. At every step we have to take care not to add or subtract the same numbers twice.
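Spelled out by hand for 4, 6, and 10, the bookkeeping looks like the sketch below. It hard-codes the three numbers on purpose so that each step is visible; ThreeWay and its method names are made up for the example, and the limit is exclusive.

```scala
object ThreeWay {
  // Sum of the positive multiples of k strictly below limit (closed form).
  def sumOfMultiples(k: BigInt, limit: BigInt): BigInt = {
    val n = (limit - 1) / k
    k * n * (n + 1) / 2
  }

  // Sum of the multiples of 4, 6, or 10 below limit, one step at a time.
  def sum4or6or10(limit: BigInt): BigInt = {
    def s(k: Int): BigInt = sumOfMultiples(BigInt(k), limit)
    val step1 = s(4)                           // multiples of 4
    val step2 = step1 + s(6) - s(12)           // add 6, remove lcm(4,6) = 12
    val step3 = step2 + s(10) - s(20) - s(30)  // add 10, remove lcm(4,10), lcm(6,10)
    step3 + s(60)   // multiples of lcm(20,30) = 60 were removed twice; restore once
  }
}
```

For a limit of 100 the terms are 1200 + 816 - 432 + 450 - 200 - 180 + 60 = 1714, which agrees with adding up the qualifying numbers directly.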

That sounds like a pain, but using recursion it’s pretty straightforward:

object SumMults3 extends App {

  val ints = args.toList.map( BigInt(_) )
  val limit = ints.head - 1
  println( sumMults( ints.tail ) )

  // The sums of all the multiples, up to a limit, of a list of numbers.
  def sumMults( ints: List[BigInt], priorInts: List[BigInt] = Nil ): BigInt =
    ints match {
      case Nil => 0
      case int::rest =>
        val n = limit / int    // # of elements in series to sum
        int * ( n * (n+1) / 2 ) -
          sumMults( priorInts.map( lcm(int,_) ).filter(_<=limit) ) +
          sumMults( rest, int::priorInts )
    }

  // Least Common Multiple and Greatest Common Divisor.
  def lcm( a: BigInt, b: BigInt ) = a*b / gcd(a,b)
  def gcd( a: BigInt, b: BigInt ): BigInt = if (b==0) a else gcd( b, a%b )

}

The first recursive invocation of sumMults takes care of adjusting for the error. Since it uses itself to calculate the error, any double-counting in the adjustment is also handled correctly.
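If you want to convince yourself that the recursion really does cancel every double count, one option is to compare it against a brute-force sum for small limits. The sketch below restates the same recursion with the inclusive limit passed as a parameter instead of read from an outer val; SumMultsCheck, closedForm, bruteForce and agree are names invented for this check.

```scala
object SumMultsCheck {
  def gcd(a: BigInt, b: BigInt): BigInt = if (b == 0) a else gcd(b, a % b)
  def lcm(a: BigInt, b: BigInt): BigInt = a * b / gcd(a, b)

  // Same recursion as sumMults above, with the inclusive limit made explicit.
  def closedForm(limit: BigInt, ints: List[BigInt], prior: List[BigInt] = Nil): BigInt =
    ints match {
      case Nil => BigInt(0)
      case int :: rest =>
        val n = limit / int
        int * (n * (n + 1) / 2) -
          closedForm(limit, prior.map(lcm(int, _)).filter(_ <= limit)) +
          closedForm(limit, rest, int :: prior)
    }

  // Reference answer obtained by testing every number from 1 to limit.
  def bruteForce(limit: Int, ints: List[Int]): BigInt =
    (1 to limit).filter(n => ints.exists(n % _ == 0)).map(BigInt(_)).sum

  // True when both methods produce the same sum.
  def agree(limit: Int, ints: List[Int]): Boolean =
    closedForm(BigInt(limit), ints.map(BigInt(_))) == bruteForce(limit, ints)
}
```

For example, SumMultsCheck.agree(999, List(4, 6, 10)) exercises the three-level correction described earlier, and a list such as List(3, 5, 7, 9) checks that a number which is itself a multiple of an earlier one contributes nothing extra.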

% scala SumMults3 1000 3 5
233168
% scala SumMults3 1000000000 3137 1753 2789
623638337230623

In this version we use BigInts for everything because at this point we are doing so few operations that it hardly matters. Even for very large numbers which would take the other solutions years to process (if they could even handle numbers that large), the program takes only slightly longer than it takes to initialize the JVM:

% time scala SumMults3 10000000000000000000000000 3 5 7 9
27142857142857142857142857857142857142857142857145
scala SumMults3 10000000000000000000000000 3 5 7 9 0.25s user 0.00s system 90% cpu 0.276 total

Gotta love math. And recursion.


One Comment

  1. Posted February 10, 2012 at 7:06 pm

    I really enjoyed this problem, particularly the Fastest solution. It reminds me that sometimes there is a huge payoff to understanding a problem thoroughly.
