Monday, November 24, 2008

Refactoring Imperative Code To Functional Code

I've been refactoring Scala literally for days. It's fantastic how much I've learned over the last year. I knew a little bit about functional programming from doing Lisp in college, but a year and a half ago I couldn't have given you the definition of it.

I decided to tackle some of the most obvious (and ugly) imperative/procedural code in my CPU simulator and turn it into an elegant, functional style.

I'll paste the original code here, then explain it a bit, then show the refactorings one step at a time. In explaining, I will assume that you know at least a bit about how anonymous functions work in Scala.


class EightBitAdder( a: EightBitNumber,
                     b: EightBitNumber,
                     carryIn: PowerSource ) {

  private val fullAdders: Array[FullAdder] = createFullAdders( a, b, carryIn )
  private val output: EightBitNumber = createOutput()

  def getOutput: EightBitNumber = output
  def getCarryOut: PowerSource = fullAdders(7).getCarryOut


  private def createFullAdders( a: EightBitNumber,
                                b: EightBitNumber,
                                carryIn: PowerSource ): Array[FullAdder] = {
    val fullAdders = new Array[FullAdder](8)
    fullAdders(0) = new FullAdder( a(0), b(0), carryIn );
    for( i <- 1 until 8 ){
      fullAdders(i) = new FullAdder(a(i),b(i), fullAdders(i-1).getCarryOut);
    }
    fullAdders
  }

  private def createOutput(): EightBitNumber = {
    val out = new Array[PowerSource](8)

    var count = 0;
    fullAdders.foreach( p => {
      out(count) = p.getSumOut
      count = count + 1
    }
    )
    new EightBitNumber(out)
  }
}


This code was designed to build an 8 bit adder out of two 8 bit numbers (numbers in this particular case aren't much more than arrays of bits) and a carry in bit. This is your typical, ordinary, everyday 8 bit adder that you might have seen in 1950. The adder's job is simple: output the sum of the two input numbers. As with any 8 bit adder, there are 9 outputs: 8 standard output bits, plus the carry out/overflow bit.

The adder accomplishes addition in a standard fashion, by chaining 8 full adders (one bit adders) together. The first full adder in the chain uses the given carry in bit as its carry in bit, and subsequent adders in the chain use the carry out of the previous full adder as the carry in. The carry out of the entire adder is simply the carry out of the last full adder in the chain. Here is a picture of one:



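For anyone who wants to poke at the chaining idea without the simulator, here is a minimal sketch. It assumes bits are plain Booleans rather than PowerSource objects, and the names FullAdderSketch, fullAdd and addTwoBits are made up for illustration; they are not part of the simulator.

```scala
// Illustrative stand-in: bits are plain Booleans here, not the simulator's PowerSource.
object FullAdderSketch {

  // One full adder: returns (sumOut, carryOut) for two input bits and a carry in.
  def fullAdd(a: Boolean, b: Boolean, carryIn: Boolean): (Boolean, Boolean) = {
    val sum = a ^ b ^ carryIn
    val carryOut = (a && b) || (carryIn && (a ^ b))
    (sum, carryOut)
  }

  // Chain two full adders by hand to add two 2 bit numbers (bit 0 is the least significant).
  def addTwoBits(a0: Boolean, a1: Boolean,
                 b0: Boolean, b1: Boolean,
                 carryIn: Boolean): (Boolean, Boolean, Boolean) = {
    val (s0, c0) = fullAdd(a0, b0, carryIn)
    val (s1, c1) = fullAdd(a1, b1, c0) // carry out of stage 0 feeds stage 1
    (s0, s1, c1)
  }

  def main(args: Array[String]): Unit = {
    // 3 (binary 11) + 1 (binary 01) = 4 (binary 100): both sum bits are 0, carry out is 1
    println(addTwoBits(a0 = true, a1 = true, b0 = true, b1 = false, carryIn = false))
  }
}
```

The 8 bit adder is the same thing with eight stages instead of two: each stage takes the previous stage's carry out as its carry in.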
Now that we have all the background out of the way, I'll start the refactorings. I'm actually going to go a little bit in reverse order, refactoring the createOutput() method first, because it is substantially easier to refactor than the createFullAdders method.

Simple Refactoring


Let's take another look at createOutput:

private def createOutput(): EightBitNumber = {
  val out = new Array[PowerSource](8)

  var count = 0;
  fullAdders.foreach( p => {
    out(count) = p.getSumOut
    count = count + 1
  }
  )
  new EightBitNumber(out)
}

This method is returning an EightBitNumber, which is really just an array holding the 8 output bits. The code is terribly imperative, and terribly terrible. The sad thing is, I actually wrote this code, and I'm not just making up a bad example :( Anyway, its overall strategy is pretty clear.

  1. Create an empty, length 8 array
  2. Create a counter variable for indexing into the array
  3. Loop over all the full adders (those are already created by the time this method gets called, and we'll see that in a bit)
  4. For each adder: put the full adder's sum out into the array, and increment the counter.
  5. Finally, create an 8 bit number object using the array.

The first 4 steps above are there simply to create the array to pass to the EightBitNumber object. Now let's take a look at the refactored code:

private val output = new EightBitNumber(fullAdders.map(fa => fa.getSumOut))

Wow! That looks a lot easier - 12 lines down to 1! But...some people might not know what it does, so I'll do my best to explain. The map method, which is a method on all Seq objects (short for Sequence; Array is a subclass), "Returns the list resulting from applying the given function f to each element of this list." An example should help.

Given a=List(1,2,3,4,5) then a.map( i => i * 10 ) returns List(10,20,30,40,50). i * 10 was applied to each element "i" in the list.

In the CPU simulator code above, the call to map builds an Array containing the sumOut of each full adder, using the function fa => fa.getSumOut.
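To see that the two versions really do the same job, here is a small runnable sketch. The Adder case class is a hypothetical stand-in for FullAdder (it only carries a Boolean sum out), and the method names are invented for the comparison.

```scala
object MapSketch {
  // Hypothetical stand-in for the simulator's FullAdder; only the sum output matters here.
  case class Adder(getSumOut: Boolean)

  // Imperative version: allocate, index, and fill by hand.
  def sumsImperative(adders: Array[Adder]): Array[Boolean] = {
    val out = new Array[Boolean](adders.length)
    var count = 0
    adders.foreach { p =>
      out(count) = p.getSumOut
      count = count + 1
    }
    out
  }

  // Functional version: map builds the new array for us.
  def sumsWithMap(adders: Array[Adder]): Array[Boolean] =
    adders.map(fa => fa.getSumOut)

  def main(args: Array[String]): Unit = {
    val adders = Array(Adder(true), Adder(false), Adder(true))
    println(sumsImperative(adders).toList)        // List(true, false, true)
    println(sumsWithMap(adders).toList)           // List(true, false, true)
    println(List(1, 2, 3, 4, 5).map(i => i * 10)) // List(10, 20, 30, 40, 50)
  }
}
```

The counter and the pre-sized array disappear because map already knows how to walk the input and collect the results in order.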

Slightly More Difficult Refactoring


With the easier part out of the way, I tackled the more difficult createFullAdders. Let's review the original implementation again.

private def createFullAdders( a: EightBitNumber,
                              b: EightBitNumber,
                              carryIn: PowerSource ): Array[FullAdder] = {
  val fullAdders = new Array[FullAdder](8)
  fullAdders(0) = new FullAdder( a(0), b(0), carryIn );
  for( i <- 1 until 8 ){
    fullAdders(i) = new FullAdder(a(i),b(i), fullAdders(i-1).getCarryOut);
  }
  fullAdders
}

This method is responsible for creating all the full adders, and chaining them together. The createOutput method was only responsible for getting all the outputs off of the full adders created here.

Similar to the last method, this method uses an imperative style, creating an array, populating it, and finally returning it. It's quite a bit trickier though because of the chaining. You can't simply use the map function because there's no context in map. Here, each new full adder needs to know about the preceding full adder. This guy is going to be a bear to explain, but let me just go ahead and dump the code on you:

val (fullAdders, carryOut) =
  (as zip bs).foldLeft((List[FullAdder](),carryIn)){
    case ((current, carry), (a, b)) =>
      val adder = new FullAdder(a, b, carry)
      (current ::: List(adder),adder.carryOut)
  }


With James Iry's help, we've made this code about as simple as possible. With the first pass, I wasn't sure if the functional code was more readable than the imperative, but after he helped me clean it up, I'm positive. Now, if you aren't familiar with some of the concepts contained in that code, you might be thinking, "What, are you $%^&ing nuts?" But I'm convinced that after you get used to reading this, it's so much easier to read, so much less error prone, and so much more natural, that you'll never go back. Honestly, after leaving this project alone for almost a year and coming back to it and finding the imperative code, I almost threw up in my mouth a little.

Okay, now I'll try to explain these concepts, and most likely fail miserably.


  1. First, and simplest, is zip. This one is pretty easy. Taken right from the Scaladoc - zip:

    "Returns a list formed from this list and the specified list 'that' by associating each element of the former with the element at the same position in the latter. If one of the two lists is longer than the other, its remaining elements are ignored."


    I think a few examples will explain perfectly.

    Given a=List(1,2,3) and b=List('a','b','c') then a.zip(b) will return List((1,a), (2,b), (3,c)).
    Given a=List(1,2,3) and b=List('a','b','c','x','y','z') then a.zip(b) will return List((1,a), (2,b), (3,c)), as the remaining elements in b are ignored.

    The code in the cpu simulator zips as and bs, which associates the appropriate input bits together. Take a quick look back at the picture to see that this returns ((a0,b0),(a1,b1),(a2,b2),(a3,b3),(a4,b4),(a5,b5),(a6,b6),(a7,b7)).

  2. Next is foldLeft, which is a bit more complicated. Once again, from the Scaladoc - foldLeft:

    "Combines the elements of this list together using the binary function f, from left to right, and starting with the value z."



    This one I've written up separately, because it was so long. You can find it at http://jackcoughonsoftware.blogspot.com/2008/11/foldleft-in-scala-little-schemer-style.html


  3. Next is pattern matching, but I have to go to bed again! At least I've made some progress :)
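The three pieces in the list above can be tried out on plain numbers before going back to adders. This is only a sketch with made-up names (FoldSketch, runningTotals and so on); the last fold has the same shape as the adder code, with a (list so far, carried value) tuple as the accumulator and a case pattern that takes the arguments apart.

```scala
object FoldSketch {
  // zip pairs elements by position; extra elements in the longer list are dropped.
  val zipped: List[(Int, Char)] = List(1, 2, 3).zip(List('a', 'b', 'c', 'x'))

  // foldLeft threads an accumulator through the list, left to right, starting from 0.
  val total: Int = List(1, 2, 3, 4).foldLeft(0)((acc, x) => acc + x)

  // With a tuple accumulator, a case pattern destructures both arguments by shape:
  // ((results so far, running total), next element).
  val (runningTotals, finalTotal) =
    List(1, 2, 3, 4).foldLeft((List[Int](), 0)) {
      case ((results, running), x) =>
        val next = running + x
        (results ::: List(next), next)
    }

  def main(args: Array[String]): Unit = {
    println(zipped)        // List((1,a), (2,b), (3,c))
    println(total)         // 10
    println(runningTotals) // List(1, 3, 6, 10)
    println(finalTotal)    // 10
  }
}
```

In the adder code, the running total plays the role of the carry: each step produces a new element for the list and the value that the next step needs.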



Revised Code


Here is the finished product, which no longer hard-codes 8, but instead creates adder chains of N, depending on the length of the inputs. Overall, I think it's a vast improvement over the original.


class AdderChain(as: Number, bs: Number, carryIn: PowerSource) {

  if( as.size != bs.size ) error("numbers must be the same size")

  val (fullAdders, carryOut) =
    (as zip bs).foldLeft((List[FullAdder](),carryIn)){
      case ((current, carry), (a, b)) =>
        val adder = new FullAdder(a, b, carry)
        (current ::: List(adder),adder.carryOut)
    }

  val output = new Number(fullAdders.map(fa => fa.sumOut))
}
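The class above depends on the simulator's Number, FullAdder and PowerSource types, so it won't compile on its own. Here is a self-contained sketch of the same fold that will. Everything simulator-specific is replaced by assumptions: bits are Booleans, numbers are List[Boolean] with the least significant bit first, require stands in for the error call, and toBits, fromBits and add are hypothetical helpers that exist only to check the arithmetic.

```scala
object AdderChainSketch {
  // Hypothetical stand-in for the simulator's FullAdder, with Boolean bits instead of PowerSource.
  class FullAdder(a: Boolean, b: Boolean, carryIn: Boolean) {
    val sumOut: Boolean = a ^ b ^ carryIn
    val carryOut: Boolean = (a && b) || (carryIn && (a ^ b))
  }

  // Same shape as the fold in the post, over lists of bits (least significant bit first).
  class AdderChain(as: List[Boolean], bs: List[Boolean], carryIn: Boolean) {
    require(as.size == bs.size, "numbers must be the same size")

    val (fullAdders, carryOut) =
      as.zip(bs).foldLeft((List[FullAdder](), carryIn)) {
        case ((current, carry), (a, b)) =>
          val adder = new FullAdder(a, b, carry)
          (current ::: List(adder), adder.carryOut)
      }

    val output: List[Boolean] = fullAdders.map(fa => fa.sumOut)
  }

  // Demo helpers only: convert between Int and a fixed-width bit list.
  def toBits(n: Int, width: Int): List[Boolean] =
    (0 until width).toList.map(i => ((n >> i) & 1) == 1)

  def fromBits(bits: List[Boolean]): Int =
    bits.zipWithIndex.map { case (bit, i) => if (bit) 1 << i else 0 }.sum

  // Adds two numbers through the chain; returns (sum modulo 2^width, carry out).
  def add(x: Int, y: Int, width: Int): (Int, Boolean) = {
    val chain = new AdderChain(toBits(x, width), toBits(y, width), false)
    (fromBits(chain.output), chain.carryOut)
  }

  def main(args: Array[String]): Unit = {
    println(add(5, 3, 8))     // (8,false)
    println(add(200, 100, 8)) // (44,true)
  }
}
```

One design note: current ::: List(adder) copies the list built so far on every step. That is harmless for 8 bits, but for long chains it is usually cheaper to prepend with :: inside the fold and reverse the list once at the end.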

4 comments:

  1. I haven't completely tested this, but this should illustrate some ideas even if the details are a little off. Here I use tuples and pattern matching a lot more. And in my fold, I start with an empty list plus the carryIn as my base then just walk through the list. That removes the need for treating the head of the zip as special.

    http://pastebin.com/m2fe34037

  2. Thanks James! This is absolutely the right approach. I've updated the code, run the tests, and everything is great. Unfortunately though, IntelliJ is puking on the new code...I'm not sure why.

    I'm going to update the post tonight to use the new code. Here it is:


    class AdderChain(as: Number, bs: Number, carryIn: PowerSource) {

      if( as.numberOfBits != bs.numberOfBits ) error("numbers must both be the same size")

      private val (fullAdders, carryOut) =
        (as zip bs).foldLeft((List[FullAdder](),carryIn)){
          case ((current, carry), (a, b)) =>
            val adder = new FullAdder(a, b, carry)
            (current ::: List(adder),adder.getCarryOut)
        }

      private val output = new Number(fullAdders.map(fa => fa.getSumOut))
      def getOutput: Number = output
      def getCarryOut: PowerSource = carryOut
    }

  3. Loving the posts Jack. I am from a Java background and watching your posts is both entertaining (as it occasionally matches my own embarrassing code) and very interesting to see if you make the same decisions (or different) from me.

    Thanks a ton

  4. Jesse, thanks a lot! I really appreciate the vote of confidence. Glad to see you're enjoying it. Now that I'm back from the dead, I'll continue.
