Sometimes you need to iterate through multiple lists as if they were one. This happens, for example, when implementing a tree traversal algorithm. Thanks to tail recursion in Kotlin, a trivial implementation could look as follows:
tailrec fun countAll(nodes: List<Node>, acc: Int): Int {
    return when {
        nodes.isEmpty() -> acc
        else -> countAll(nodes.first().children + nodes.subList(1, nodes.size), acc + 1)
    }
}
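You start by invoking countAll on a list containing the root node (e.g. countAll(listOf(root), 0), with the accumulator at 0 because every visited node, the root included, adds 1 to it) and the function keeps adding each visited node's children to the list of pending nodes until all of them have been traversed and counted.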
The problem with this approach is that each call spends most of its time on the concatenation: the + operator builds a new list and copies every element of both operands into it.
As you can see from the flame graph, the majority of the time is spent constructing the ArrayList instance and invoking addAll.
The reason we are concatenating lists here is only to be able to keep iterating through our collection of nodes as more items get added.
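To get a feel for how much copying this implies, here is a minimal sketch that replays the list-based traversal and tallies how many elements each concatenation copies. The Node class and the copiedElements helper are assumptions made for illustration; they are not taken from the original playground.

```kotlin
// Hypothetical minimal tree node, assumed for illustration only.
class Node(val children: List<Node> = emptyList())

// Hypothetical helper: replays the list-based traversal and tallies how many
// elements the + operator copies into a freshly allocated list at each step.
fun copiedElements(root: Node): Long {
    var pending: List<Node> = listOf(root)
    var copied = 0L
    while (pending.isNotEmpty()) {
        val head = pending.first()
        val rest = pending.subList(1, pending.size)
        // List + List allocates a new list and copies both operands into it.
        copied += head.children.size + rest.size
        pending = head.children + rest
    }
    return copied
}

fun main() {
    // A root with 1,000 leaf children: 1,001 nodes in total.
    val root = Node(List(1_000) { Node() })
    println(copiedElements(root)) // prints 500500
}
```

For a root with n leaf children the tally is n(n+1)/2, so in this shape of tree the copying grows quadratically with the width even though there are only n + 1 nodes to count.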
Let's make it better
An alternative approach that avoids the overhead of concatenating the lists is to use an iterator that can cross list boundaries.
If we had one, we could rewrite our countAll function as follows:
tailrec fun countAll(nodes: Iterator<Node>, acc: Int): Int {
    return when {
        !nodes.hasNext() -> acc
        else -> {
            val node = nodes.next()
            countAll(nodes + node.children.iterator(), acc + 1)
        }
    }
}
🏃♂️ Time to code!
In our ideal implementation, the plus operator takes care of storing the iterator over the children list without eagerly copying all the items into a new list.
Let's take a look at the iterator implementation:
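// java.util.ArrayDeque is assumed here: store.first relies on its getFirst() accessor
import java.util.ArrayDeque

class ConcatIterator<T>(iterator: Iterator<T>) : Iterator<T> {
    // FIFO queue of the iterators that still have elements to yield
    private val store = ArrayDeque<Iterator<T>>()

    init {
        if (iterator.hasNext())
            store.add(iterator)
    }

    override fun hasNext(): Boolean = when {
        store.isEmpty() -> false
        else -> store.first.hasNext()
    }

    override fun next(): T {
        val t = store.first.next()
        // Drop the head iterator as soon as it is exhausted
        if (!store.first.hasNext())
            store.removeFirst()
        return t
    }

    // Enqueues the iterator in place and returns this instance; no elements are copied
    operator fun plus(iterator: Iterator<T>): ConcatIterator<T> {
        if (iterator.hasNext())
            store.add(iterator)
        return this
    }
}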
As the iterators get concatenated, they are enqueued in the store. In this case, the store uses Java's ArrayDeque, a lightweight non-concurrent implementation of a Queue that performs the majority of its operations in amortized constant time.
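The queue behaviour the store relies on can be sketched in isolation. The drain helper below is hypothetical and only mirrors how the head of the queue is consumed and discarded; it assumes java.util.ArrayDeque, whose add appends at the tail and whose first accessor peeks at the head.

```kotlin
import java.util.ArrayDeque

// Hypothetical helper: drains a queue of iterators front to back,
// mirroring the way ConcatIterator consumes its store.
fun <T> drain(store: ArrayDeque<Iterator<T>>): List<T> {
    val out = mutableListOf<T>()
    while (store.isNotEmpty()) {
        val head = store.first // peek at the head without removing it
        if (head.hasNext()) out.add(head.next()) else store.removeFirst()
    }
    return out
}

fun main() {
    val store = ArrayDeque<Iterator<Int>>()
    store.add(listOf(1, 2).iterator()) // add() appends at the tail
    store.add(listOf(3).iterator())
    println(drain(store)) // prints [1, 2, 3]
}
```

Only the iterators themselves are stored, so enqueuing a list of any size costs a single amortized constant-time add.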
The last touch
The final addition, which makes working with the iterator a little more comfortable, is a plus operator defined as an extension on Iterator.
operator fun <T> Iterator<T>.plus(iterator: Iterator<T>): ConcatIterator<T> =
    when {
        this is ConcatIterator<T> -> this.plus(iterator)
        iterator is ConcatIterator<T> -> iterator.plus(this)
        else -> ConcatIterator(this).plus(iterator)
    }
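This extension function lets us easily concatenate two existing iterators by producing a ConcatIterator. The nice thing about it is that it reuses an existing ConcatIterator instance if either of the two operands is one. Note that when only the right-hand operand is a ConcatIterator, the left-hand iterator is enqueued after it, so the operand order is not preserved in that case; this makes no difference when counting.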
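Putting the pieces together, the following sketch runs the iterator-based countAll on a tiny tree. The Node class and the sample tree are assumptions made for illustration; ConcatIterator, plus and countAll follow the code above in condensed form, and java.util.ArrayDeque is assumed for the store.

```kotlin
import java.util.ArrayDeque

// Hypothetical minimal node type, assumed for illustration.
class Node(val children: List<Node> = emptyList())

// Condensed version of the ConcatIterator shown above.
class ConcatIterator<T>(iterator: Iterator<T>) : Iterator<T> {
    private val store = ArrayDeque<Iterator<T>>()

    init {
        if (iterator.hasNext()) store.add(iterator)
    }

    override fun hasNext(): Boolean = !store.isEmpty() && store.first.hasNext()

    override fun next(): T {
        val t = store.first.next()
        if (!store.first.hasNext()) store.removeFirst()
        return t
    }

    operator fun plus(iterator: Iterator<T>): ConcatIterator<T> {
        if (iterator.hasNext()) store.add(iterator)
        return this
    }
}

operator fun <T> Iterator<T>.plus(iterator: Iterator<T>): ConcatIterator<T> =
    when {
        this is ConcatIterator<T> -> this.plus(iterator)
        iterator is ConcatIterator<T> -> iterator.plus(this)
        else -> ConcatIterator(this).plus(iterator)
    }

tailrec fun countAll(nodes: Iterator<Node>, acc: Int): Int =
    if (!nodes.hasNext()) acc
    else {
        val node = nodes.next()
        countAll(nodes + node.children.iterator(), acc + 1)
    }

fun main() {
    // Root with two children; the first child has three leaves: 6 nodes in total.
    val root = Node(listOf(Node(List(3) { Node() }), Node()))
    // The accumulator starts at 0 so that each visited node is counted exactly once.
    println(countAll(listOf(root).iterator(), 0)) // prints 6
}
```

Because children are enqueued behind the iterators already in the store, the nodes are visited level by level here, which is fine for counting but worth keeping in mind if the visiting order matters.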
Let the numbers speak
Now that we have implemented our alternative version of the countAll function, let's see how it performs.
I've tested my assumptions using a little Kotlin playground that I've created to experiment with trees. You can find it here.
The following results come from testing the two implementations against a tree with 65201277 nodes.
Total count (countAll without ConcatIterator): 65201277 nodes (9227 ms)
Total count (countAll with ConcatIterator): 65201277 nodes (1288 ms)
As you can see, the ConcatIterator version is about 7 times faster (9227 ms vs. 1288 ms). We no longer incur the overhead of concatenating lists, so the majority of the computation is spent performing the counting.
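For readers who want to reproduce this kind of measurement, here is a rough sketch of a timing helper that prints results in the same format as above. The timedCount name and the stand-in workload are hypothetical and not taken from the playground; a single run like this is also sensitive to JIT warm-up and garbage collection, so treat the figures it prints as indicative only.

```kotlin
import kotlin.system.measureTimeMillis

// Hypothetical helper: runs a counting function once and reports the
// result and the elapsed wall-clock time.
fun timedCount(label: String, count: () -> Int): Int {
    var result = 0
    val elapsed = measureTimeMillis { result = count() }
    println("Total count ($label): $result nodes ($elapsed ms)")
    return result
}

fun main() {
    // Stand-in workload; in practice the lambda would call countAll on a real tree.
    timedCount("stand-in") { (1..1_000_000).count { it % 2 == 0 } }
}
```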
Conclusion
I hope you enjoyed this post. Let me know how you approach this kind of problem and whether there are better ways of achieving this in Kotlin.
Cheers!