Let's Implement Overloading/Multiple-Dispatch

unworthyEnzyme, Aug 20, Dev Community

A couple of years ago, I came across a language called Julia. Its multiple dispatch feature was very interesting; I wanted to know how it worked under the hood, but I didn't have the knowledge to do that yet. So here I am, finally giving it a try. Now that I have an implementation, I've realized there is nothing tying this algorithm to runtime dispatch; I think it could be used in a language with static dispatch as well. If you're interested in learning about multiple dispatch, I left some links at the end of the post. So I guess this post is really about selecting the most specific function for a given set of arguments in a language with subtyping. Ok, let's get started.

Describing the Problem

I'm assuming you're familiar with the concepts of subtyping and function overloading. Briefly, subtyping is a way to relate types to each other: if we have types A and B, and B is a subtype of A, we can say that B is an A, so we can use a B anywhere we need an A. Function overloading is a way to define multiple functions with the same name but different parameter types. The function that gets called is the one that matches the arguments "best".

Let's say we have a subtyping hierarchy like this:

abstract type Number
struct Complex <: Number
struct Real <: Complex

And functions like this:

function foo(x::Number, y::Number)
    println("Number, Number")
end

function foo(x::Complex, y::Complex)
    println("Complex, Complex")
end

function foo(x::Real, y::Real)
    println("Real, Real")
end

We want to select the most specific function for the given arguments. For example, foo(Real(), Real()) should call the last definition, and foo(Complex(), Number()) should call the first definition, because we can't pass a Number where a Complex is needed. The catch is that for foo(Real(), Real()), calling the second (or even the first) definition would also work just fine. So we don't just need to find the methods that accept the arguments; we also need a way to rank them.

Modeling Subtyping

Let's start by modeling subtyping first. I want something like this:

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)
string = Type.new("String", any)

puts real.is?(number) # true
puts real.is?(complex) # true
puts real.is?(real) # true
puts real.is?(string) # false
puts real.is?(any) # true
puts string.is?(any) # true

All types are subtypes of Any, which has no supertype. This type is also called the top type. You'll notice that a type is a subtype of itself. Why? Going back to the definition, a type is a subtype of another if it can be used in place of the other. If a function takes a Real, we can pass it a Real, so Real is a subtype of Real. It also means Any is a supertype of all types, including itself, which I think is quite nice.

We could implement this API like this:

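class Type
  attr_reader :name, :supertype

  # `supertype` is nil only for the top type (Any).
  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  # True if self is `type` or a (transitive) subtype of it.
  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  # Value equality: two types are equal when their names match.
  # The is_a? guard keeps comparisons with nil or non-types from crashing.
  def ==(type)
    type.is_a?(Type) && @name == type.name
  end
end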

We already talked about every type being a subtype of itself, so the first return statement does that check. The second one handles reaching the top of the hierarchy without finding `type`, for example when we call any.is?(real): Any has no supertype, so the answer is false. The last line recursively walks up the subtyping hierarchy, asking our supertype the same question.
I also overrode the == method so types are compared by name (value equality) instead of by object identity.
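One caveat with overriding only `==`: Ruby's `Hash` and `Array#uniq` compare elements with `eql?` and `hash`, not `==`. So two `Type` objects with the same name would still be different hash keys. A minimal sketch, assuming you'd want to use types as hash keys (for example, to cache distances), keeps `eql?` and `hash` consistent with name-based equality:

```ruby
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  # Value equality: two types are equal when their names match.
  def ==(other)
    other.is_a?(Type) && @name == other.name
  end

  # Hash lookups and Array#uniq use eql? and hash, so keep them in sync with ==.
  alias_method :eql?, :==

  def hash
    @name.hash
  end
end

any = Type.new("Any")
real_a = Type.new("Real", any)
real_b = Type.new("Real", any) # a separate object with the same name

puts real_a == real_b                # true
puts({ real_a => "cached" }[real_b]) # cached
puts [real_a, real_b].uniq.length    # 1
```

Without the `eql?`/`hash` pair, the hash lookup above would return `nil` and `uniq` would keep both objects.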

Modeling Function Signatures

A signature is just a list of types.

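class Signature
  attr_reader :types

  def initialize(types)
    @types = types
  end

  # Two signatures are equal when their types match position by position.
  def ==(signature)
    signature.is_a?(Signature) && @types == signature.types
  end

  # Human-readable form, e.g. "(Real, Real)".
  def to_s
    "(#{@types.map(&:name).join(', ')})"
  end
end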

We need a way to know if it is legal to call a signature with a given list of argument types. Let's look at a simple case:

function f(x::Real) end

We obviously shouldn't be able to call this function with f(Complex()) because Complex is not a subtype of Real.
This generalizes to multiple arguments as well: given f(x::Real, y::Real), we should be able to call f(Real(), Real()) but not f(Complex(), Real()). No argument position is special. So the gist of it is: we can call a function whose signature is (s1, s2, ..., sn) with argument types (t1, t2, ..., tn) if every ti is a subtype of the corresponding si.
The implementation is quite simple:

class Signature
  # Other stuff...

  # `self` is the call signature (the argument types) and `signature` is a
  # method's declared parameter types. The call conforms when:
  def conforms?(signature)
    # 1. both have the same number of types, and
    return false if signature.types.length != @types.length
    # 2. each argument type is a subtype of the corresponding parameter type.
    @types.zip(signature.types).all? { |a, b| a.is?(b) }
  end
end
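To see the rule in action, here is a small self-contained check. It repeats the `Type` and `Signature` definitions so it runs on its own, and uses the same Real <: Complex <: Number <: Any chain as before. The receiver of `conforms?` is the call's argument types; the argument is a method's parameter types:

```ruby
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  # Walk up the supertype chain looking for `type`.
  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  def ==(other)
    other.is_a?(Type) && @name == other.name
  end
end

class Signature
  attr_reader :types

  def initialize(types)
    @types = types
  end

  # True when every type in this (call) signature is a subtype of the
  # corresponding type in `signature` (a method's declared parameters).
  def conforms?(signature)
    return false if signature.types.length != @types.length
    @types.zip(signature.types).all? { |a, b| a.is?(b) }
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

method_sig = Signature.new([real, real]) # like f(x::Real, y::Real)

puts Signature.new([real, real]).conforms?(method_sig)    # true
puts Signature.new([complex, real]).conforms?(method_sig) # false: a Complex is not a Real
puts Signature.new([real]).conforms?(method_sig)          # false: wrong number of arguments
```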

Ranking Conforming Signatures

Now we come to the meat of the problem. How do we select the most specific function for a given list of argument types? We need a way to rank them. Let's look at the simplest case again.

function f(x::Number) end
function f(x::Complex) end
f(Real())

If I asked you which function should be called, you would say the second one, right? Why? Well, in the subtype hierarchy Real <: Complex <: Number <: Any, the closest candidate to Real is Complex, so you choose that. This is the main idea behind the ranking algorithm: pick the method whose parameter types are closest to the argument types in the hierarchy.
We need a way to get the distance from the given type:

class Type
  # Other stuff...

  # Number of supertype steps from self up to `type`.
  def distance(type)
    # Distance is only defined along a single supertype chain, not across
    # branches of the tree (e.g. Real and String), so reject unrelated types.
    raise "Not a subtype" unless is?(type)
    # If we are at the same type, we are 0 distance away.
    return 0 if self == type
    # Otherwise we are one step further from `type` than our supertype is.
    # Example:
    #   real.distance(any) = 1 + complex.distance(any)
    #                      = 1 + 1 + number.distance(any)
    #                      = 1 + 1 + 1 + any.distance(any)
    #                      = 1 + 1 + 1 + 0
    #                      = 3
    1 + @supertype.distance(type)
  end
end
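A quick self-contained check of `distance`, with the `Type` class repeated so the snippet runs on its own. It prints the distance from Real to each type on its supertype chain and shows that asking for the distance to a type on another branch (String) raises:

```ruby
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  def ==(other)
    other.is_a?(Type) && @name == other.name
  end

  def distance(type)
    raise "Not a subtype" unless is?(type)
    return 0 if self == type
    1 + @supertype.distance(type)
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)
string = Type.new("String", any)

# Distance from Real to each type on its supertype chain.
[real, complex, number, any].each do |t|
  puts "Real -> #{t.name}: #{real.distance(t)}"
end
# Real -> Real: 0
# Real -> Complex: 1
# Real -> Number: 2
# Real -> Any: 3

# String is on a different branch, so there is no chain to measure.
begin
  real.distance(string)
rescue RuntimeError => e
  puts "Real -> String: #{e.message}" # Real -> String: Not a subtype
end
```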

We can already rank functions with a single argument this way. One way to generalize this to multiple arguments is to compute the distance between each argument type and the corresponding parameter type, and sum them up.

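class Signature
  # Other stuff...

  # Sum of per-argument distances. Assumes the call already conforms to
  # `signature`; otherwise Type#distance raises.
  def distance(signature)
    @types.zip(signature.types).sum { |a, b| a.distance(b) }
  end
end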

If you're interested in the math: a signature is like an n-dimensional point, and the distance between two points is the sum of the distances between corresponding coordinates (types). That is the Manhattan distance, by the way. Strictly speaking it isn't a full metric here, because Type#distance is only defined from a type up to one of its supertypes (it raises otherwise), but the intuition carries over. You could even use Euclidean distance if you want, though it can rank candidates differently; I think Manhattan distance suffices.
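The choice of distance does change results. A small sketch, assuming the same Real <: Complex <: Number <: Any chain and two hypothetical candidates for a call with (Real, Real): (Real, Any) and (Complex, Number). Their per-argument distances are [0, 3] and [1, 2], so they tie under Manhattan distance (an ambiguous call), while Euclidean distance prefers (Complex, Number). The helper `per_arg` is made up for this example:

```ruby
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  def ==(other)
    other.is_a?(Type) && @name == other.name
  end

  def distance(type)
    raise "Not a subtype" unless is?(type)
    return 0 if self == type
    1 + @supertype.distance(type)
  end
end

# Hypothetical helper: per-argument distances from the call to a candidate.
def per_arg(call, params)
  call.zip(params).map { |a, p| a.distance(p) }
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

call = [real, real]
candidates = {
  "(Real, Any)" => [real, any],
  "(Complex, Number)" => [complex, number]
}

candidates.each do |label, params|
  d = per_arg(call, params)
  manhattan = d.sum
  euclidean = Math.sqrt(d.sum { |x| x * x })
  puts "#{label}: manhattan=#{manhattan}, euclidean=#{euclidean.round(2)}"
end
# (Real, Any): manhattan=3, euclidean=3.0
# (Complex, Number): manhattan=3, euclidean=2.24
```

Euclidean distance penalizes one large per-argument gap more than several small ones, which is why it breaks this particular tie.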

If we put this to work:

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

f1 = Signature.new([number, number])
f2 = Signature.new([complex, complex])
f3 = Signature.new([real, real])

call_signature = Signature.new([real, real])
puts [f1, f2, f3].min_by { |f| call_signature.distance(f) } == f3 # true

A function isn't just a signature, though; it also has a name.

class Function
  attr_accessor :name, :signature

  def initialize(name, signature)
    @name = name
    @signature = signature
  end

  def to_s
    "#{@name}#{@signature}"
  end

  def ==(other)
    @name == other.name && @signature == other.signature
  end
end

And we'll have a table of functions that contains all the definitions and gives the most specific one for a given call.

class FunctionTable
  attr_accessor :functions

  def initialize
    @functions = []
  end

  def add(function)
    raise "Function already exists" if @functions.include?(function)
    @functions << function
  end

  def find(function)
    call = function.signature
    # Find all the functions with the same name whose signature the call conforms to.
    candidates = @functions.select { |m| m.name == function.name && call.conforms?(m.signature) }
    # No candidates means no method matches this call.
    return nil if candidates.empty?

    # Sort them by distance from closest to furthest.
    sorted_by_distance = candidates.sort_by { |m| call.distance(m.signature) }

    # The first one is the closest, but others may be at the same distance.
    min_distance = call.distance(sorted_by_distance.first.signature)
    closest_functions = sorted_by_distance.take_while { |m| call.distance(m.signature) == min_distance }
    # Join with to_s so the message uses the readable Function#to_s form.
    raise "Ambiguous function call between #{closest_functions.map(&:to_s).join(', ')}" if closest_functions.length > 1
    closest_functions.first
  end
end

The most interesting method is find.

  1. It finds all the functions with the same name and conforming signature.
  2. Sorts them by distance.
  3. Finds the closest one.
  4. If more than one function is at the minimum distance, it raises an error. I chose to raise an error, but another option would be to return the second-closest method. (I wonder if there is a metric where two different points can never be the same distance from a third point.)
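Step 4 is not just theoretical. A minimal sketch, assuming two hypothetical methods f(x::Number, y::Real) and f(x::Real, y::Number) called with (Real, Real): both conform, and both sit at a summed distance of 2 from the call, so `find` would raise the ambiguity error. The snippet repeats just enough of `Type` and `Signature` to run on its own:

```ruby
class Type
  attr_reader :name, :supertype

  def initialize(name, supertype = nil)
    @name = name
    @supertype = supertype
  end

  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end

  def ==(other)
    other.is_a?(Type) && @name == other.name
  end

  def distance(type)
    raise "Not a subtype" unless is?(type)
    return 0 if self == type
    1 + @supertype.distance(type)
  end
end

class Signature
  attr_reader :types

  def initialize(types)
    @types = types
  end

  def conforms?(signature)
    return false if signature.types.length != @types.length
    @types.zip(signature.types).all? { |a, b| a.is?(b) }
  end

  def distance(signature)
    @types.zip(signature.types).sum { |a, b| a.distance(b) }
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

call = Signature.new([real, real])
number_real = Signature.new([number, real]) # hypothetical f(x::Number, y::Real)
real_number = Signature.new([real, number]) # hypothetical f(x::Real, y::Number)

[number_real, real_number].each do |m|
  puts "conforms=#{call.conforms?(m)} distance=#{call.distance(m)}"
end
# conforms=true distance=2
# conforms=true distance=2
```

Each candidate is more specific in one argument and less specific in the other, so no numeric ranking of this kind can prefer one without an arbitrary tie-break.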

If we put it all together:

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)

f1 = Function.new("f", Signature.new([number, number]))
f2 = Function.new("f", Signature.new([complex, complex]))
f3 = Function.new("f", Signature.new([real, real]))

table = FunctionTable.new
table.add(f1)
table.add(f2)
table.add(f3)

call_signature = Signature.new([real, real])
puts table.find(Function.new("f", call_signature)) == f3 # true
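A couple more calls against the same table, as a self-contained sketch that repeats compact versions of all four classes. A mixed call (Real, Complex) skips f(Real, Real), which doesn't accept it, and picks f(Complex, Complex) at distance 1 over f(Number, Number) at distance 3. The String type is added only for this example, and a call that matches nothing gets `nil`, which is what `find` returns when there are no candidates:

```ruby
class Type
  attr_reader :name, :supertype
  def initialize(name, supertype = nil)
    @name, @supertype = name, supertype
  end
  def is?(type)
    return true if type == self
    return false if @supertype.nil?
    @supertype.is?(type)
  end
  def ==(other)
    other.is_a?(Type) && @name == other.name
  end
  def distance(type)
    raise "Not a subtype" unless is?(type)
    return 0 if self == type
    1 + @supertype.distance(type)
  end
end

class Signature
  attr_reader :types
  def initialize(types)
    @types = types
  end
  def ==(other)
    other.is_a?(Signature) && @types == other.types
  end
  def conforms?(signature)
    signature.types.length == @types.length &&
      @types.zip(signature.types).all? { |a, b| a.is?(b) }
  end
  def distance(signature)
    @types.zip(signature.types).sum { |a, b| a.distance(b) }
  end
  def to_s
    "(#{@types.map(&:name).join(', ')})"
  end
end

class Function
  attr_accessor :name, :signature
  def initialize(name, signature)
    @name, @signature = name, signature
  end
  def ==(other)
    other.is_a?(Function) && @name == other.name && @signature == other.signature
  end
  def to_s
    "#{@name}#{@signature}"
  end
end

class FunctionTable
  def initialize
    @functions = []
  end
  def add(function)
    raise "Function already exists" if @functions.include?(function)
    @functions << function
  end
  # Closest conforming function, nil if none conforms, raises on a tie.
  def find(call)
    candidates = @functions.select { |m| m.name == call.name && call.signature.conforms?(m.signature) }
    return nil if candidates.empty?
    min = candidates.map { |m| call.signature.distance(m.signature) }.min
    closest = candidates.select { |m| call.signature.distance(m.signature) == min }
    raise "Ambiguous function call between #{closest.join(', ')}" if closest.length > 1
    closest.first
  end
end

any = Type.new("Any")
number = Type.new("Number", any)
complex = Type.new("Complex", number)
real = Type.new("Real", complex)
string = Type.new("String", any)

table = FunctionTable.new
table.add(Function.new("f", Signature.new([number, number])))
table.add(Function.new("f", Signature.new([complex, complex])))
table.add(Function.new("f", Signature.new([real, real])))

puts table.find(Function.new("f", Signature.new([real, complex]))) # f(Complex, Complex)
p table.find(Function.new("f", Signature.new([real, string])))     # nil
```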

Full Code

Further Reading
