A Pythonic Trie

A trie is a deceptively useful data structure. It's not something you'd necessarily come up with on your own, but it feels almost obvious once you see it in action. Here's the Trie that we'll be designing in the rest of this article:

from collections import defaultdict
from operator import getitem
from functools import reduce

class Trie(defaultdict):

    __repr__ = dict.__repr__

    def __init__(self):
        super().__init__(Trie)
        self.count = 0
        self.is_word = False

    def add(self, word):
        for c in word:
            (self := self[c]).count += 1
        self.is_word = True

    def exists(self, word):
        return reduce(getitem, word, self).is_word

    def startswith(self, prefix):
        return reduce(getitem, prefix, self).count > 0

Now let's get started.

Introductions to prefix trees, or Tries

Generally, a trie is a k-ary search tree in which each string or prefix is stored as a path running down from the root.

That is to say, the nodes along that path, read in order, spell out the string!

Here's an example from HackerEarth:

Example of a trie

As you can see, if you trace every possible route down from the root, you end up with all of the strings on the right.

The nice thing is that, given something like the single-case English alphabet, no node will ever have a branching factor above 26, since the alphabet only has 26 letters.

Large Inputs

As drawn in the diagram, though, a trie is a bit difficult to use. In an implementation we usually want each node to record whether it marks the end of a word.

Here are some ways to do so:

# If we've made a Trie object, set a flag on the final node
self.is_word = True
# A technique you often see for simple dict-implemented Tries:
# store a sentinel key such as "#" in the final node's dict
t["#"] = True
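
To make the second technique concrete, here is a minimal sketch of a plain-dict trie. It assumes "#" is the sentinel key and that no stored word contains that character; the names END, insert and contains are made up for this illustration.

```python
# Minimal sketch of a plain-dict trie.
# END is an assumed sentinel key marking "a word ends at this node".
END = "#"

def insert(trie, word):
    node = trie
    for ch in word:
        # setdefault returns the existing child or creates an empty one
        node = node.setdefault(ch, {})
    node[END] = True

def contains(trie, word):
    node = trie
    for ch in word:
        if ch not in node:
            return False
        node = node[ch]
    # Only a full word if the sentinel was stored on the final node
    return END in node

trie = {}
for w in ("car", "cat", "cart"):
    insert(trie, w)

print(contains(trie, "car"))  # True
print(contains(trie, "ca"))   # False: "ca" is only a prefix
```

The sentinel keeps everything inside ordinary dictionaries, at the cost of reserving one character that can never appear in a word.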

And since we might have a lot of overlap when storing words, we'd probably want to store the count of each node or letter in the Trie.

This is nice for large inputs where you want to know how many stored strings share a certain prefix.

Reaching a node means walking down the letters it descends from, which takes one step per character of the prefix. Once the search has arrived at the final node, reading its count is an O(1) constant time lookup, because there is nothing left to scan.

Pretty useful when you're trying to filter through a lot of strings, AND shared prefixes are only stored once, so you don't pay for duplicate letters!
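
As a rough illustration of those per-node counts, the sketch below uses a hypothetical node layout (a dict holding a count and its children) rather than the class we build later. The helper names make_node, add and prefix_count are assumptions for this example only.

```python
from collections import defaultdict

def make_node():
    # Hypothetical node layout: a count plus child nodes keyed by character
    return {"count": 0, "children": defaultdict(make_node)}

def add(root, word):
    node = root
    for ch in word:
        node = node["children"][ch]  # created on first access
        node["count"] += 1           # one more word passes through here

def prefix_count(root, prefix):
    node = root
    for ch in prefix:                # one step per character of the prefix
        if ch not in node["children"]:
            return 0
        node = node["children"][ch]
    return node["count"]             # a single read once we have arrived

root = make_node()
for w in ("car", "cart", "cat", "dog"):
    add(root, w)

print(prefix_count(root, "ca"))   # 3
print(prefix_count(root, "car"))  # 2
print(prefix_count(root, "x"))    # 0
```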

Implementation

There's a lot more to tries: modifications that are useful for certain classes of problems, for example, or storing things like binary numbers instead (think 0100, 0101, etc.).

But the purpose of this article is to introduce a nice template for Tries that makes the most of Python's syntactic sugar.

Here's something we might do in Python, that works well and is clear:

class Node:
    def __init__(self):
        self.is_word = False
        self.children = {}

class Trie():

    def __init__(self):
        self.root = Node()

    def insert(self, word):
        '''
        Iterate through each character and cascade through the
        trie.
        '''
        curr = self.root
        for char in word:
            if char not in curr.children:
                curr.children[char] = Node()  # Make a new node
            curr = curr.children[char]  # Step into the child node (new or existing)
        curr.is_word = True
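
To see this version in action, here is a self-contained sketch that repeats the two classes and adds a lookup. The search method is a hypothetical companion written for this example; it is not part of the listing above.

```python
class Node:
    def __init__(self):
        self.is_word = False
        self.children = {}

class Trie:
    def __init__(self):
        self.root = Node()

    def insert(self, word):
        curr = self.root
        for char in word:
            if char not in curr.children:
                curr.children[char] = Node()
            curr = curr.children[char]
        curr.is_word = True

    def search(self, word):
        # Hypothetical companion to insert: walk down, fail on a missing child
        curr = self.root
        for char in word:
            if char not in curr.children:
                return False
            curr = curr.children[char]
        return curr.is_word

trie = Trie()
trie.insert("tree")
trie.insert("trie")
print(trie.search("trie"))  # True
print(trie.search("tri"))   # False: a prefix, but never inserted as a word
```

Notice that both methods carry the same "is this child here yet?" check, which is the bookkeeping we'd like to get rid of.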

This is a good implementation, but we see that generating children requires us to make a node by hand and, subsequently, each of its child nodes in turn: a recursive structure.

And that's exactly where defaultdict comes in.

defaultdict

A Python container in the standard library's collections module that calls a default factory when a key is missing. At its core, it lets you skip checking whether something is in the dictionary first, and saves you from writing

dict.get(key, default)

A simple example is counting.

We could do this:

student_grades = [100, 100, 90, 80, 10, 20]
frequency = dict()

for student_grade in student_grades:
    if student_grade not in frequency:
        frequency[student_grade] = 1
    else:
        frequency[student_grade] += 1

Or we could do this:

from collections import defaultdict

student_grades = [100, 100, 90, 80, 10, 20]
frequency = defaultdict(int)  # Missing keys get int(), which is 0.

for grade in student_grades:
    frequency[grade] += 1

As you can see, it lets us avoid writing the same thing over and over and allows us to make some assumptions when writing our code.

Connecting back to tries

Since we know that we want our Trie object's children to automatically become Tries, we can inherit the default factory behavior by subclassing defaultdict when we design our class.


from collections import defaultdict

class Trie(defaultdict):

    def __init__(self):
        super().__init__(Trie)
        self.is_word = False
        self.count = 0

We've gone ahead and called the __init__ method of the class we're inheriting from. If you recall, passing in int gets us 0 as the default value. Here we pass in Trie itself, so a missing key gets a fresh Trie.

While it looks confusing that we're passing in the class we've just made, all we're really saying is: whenever a key is missing, give us a new Trie object with self.is_word and self.count set to our default values.
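
A quick way to convince yourself of this is to index a key that doesn't exist yet. The sketch below repeats the class so it runs on its own; the variable names root, child and deep are just for illustration.

```python
from collections import defaultdict

class Trie(defaultdict):
    def __init__(self):
        super().__init__(Trie)  # Trie is the factory for missing keys
        self.is_word = False
        self.count = 0

root = Trie()
child = root["a"]            # "a" is missing, so defaultdict calls Trie()
print(type(child).__name__)  # Trie
print(child.count, child.is_word)  # 0 False

deep = root["x"]["y"]["z"]   # every missing level is created on the way down
print(sorted(root))          # ['a', 'x']
```

Merely looking a key up is enough to create it, which is exactly what makes adding words so short.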

Now we need to add a method to add words or strings.


def add(self, word):
    for c in word:
        (self := self[c]).count += 1
    self.is_word = True

There are a few things happening here in the for loop, but it's quite simple if you take it step by step.

  • (self := self[c])
    • This is an assignment expression using the walrus operator, which has been in Python since 3.8. At its core, it assigns a value to a variable as part of a larger expression, so the value can be used straight away. That makes it great for situations where you have intermediate steps and want to store their results as you generate them.
    • In this case we index our Trie by the current character. Because Trie inherits from defaultdict, indexing works just like a regular dictionary, except that a missing key is created for us as a new child Trie.
    • And then we just update self to that child object (either the one we just made or the one that was already there).
    • Lastly, we increment its count by one.
  • self.is_word = True
    • Once the loop has consumed the last character, self refers to the node where the word ends. So we just set its is_word variable to True.
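
Putting those steps together, here is a small self-contained sketch that adds two words and inspects the counts along the shared path. The sample words are arbitrary.

```python
from collections import defaultdict

class Trie(defaultdict):
    def __init__(self):
        super().__init__(Trie)
        self.is_word = False
        self.count = 0

    def add(self, word):
        for c in word:
            # Step into (or create) the child, rebind self, bump its count
            (self := self[c]).count += 1
        self.is_word = True  # self is now the node where the word ends

root = Trie()
root.add("car")
root.add("cat")

print(root["c"].count)            # 2: both words pass through "c"
print(root["c"]["a"].count)       # 2
print(root["c"]["a"]["r"].count)  # 1
print(root["c"]["a"]["r"].is_word, root["c"]["a"].is_word)  # True False
```

Rebinding self only changes the local name inside add, so the caller's root still points at the top of the Trie.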

Getting functional

Now that we've made it easy to add words and generate child nodes, we can make our search method.

Generally, we just do our standard walk down the tree: one DFS-style path that follows a single child per character.

def search(self, word):
    t = self
    for c in word:
        if c in t:
            t = t[c]
        else:
            return False
    return t.is_word

The above is an iterative version that simply requires us to update where we're at in the Trie until we've exhausted all characters.

But we can get a little more exciting with it.

Using the reduce function that has now been relegated to the functools module in the standard library, we can capture the above in a single line.

If we recognize that we're always taking the next element of word and cascading to the next corresponding Trie node--we can represent this as a cumulative operation.

As a quick reminder on reduce, it takes up to three arguments:

  1. The function to apply.
  2. The iterable to reduce.
  3. An optional default value (the initializer), which is used as the starting point.

Here's an example from the Python docs:

reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) is the same as ((((1+2)+3)+4)+5)
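
To see what the third argument changes, here is a short sketch using operator.add in place of the lambda. The variable names are only for illustration.

```python
from functools import reduce
from operator import add

numbers = [1, 2, 3, 4, 5]

plain = reduce(add, numbers)         # ((((1+2)+3)+4)+5)
seeded = reduce(add, numbers, 100)   # (((((100+1)+2)+3)+4)+5)
empty = reduce(add, [], 0)           # nothing to fold, so we get the initializer back

print(plain, seeded, empty)  # 15 115 0
```

With an initializer, the first call to the function receives the initializer and the first element, rather than the first two elements.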

Now let's go ahead and define those 3 things for our search operation.

  1. The function to apply

    1. This might seem unclear, but what we're doing at each step of our search operation is getting the child node that corresponds to trie[character]. The function that we'll use to do so is operator.getitem.
    2. Put briefly, getitem(obj, key) is the function form of the obj[key] operation we use when accessing keys in dictionaries.
  2. The iterable to reduce

    1. Do we use word or self?
      1. Remember that the function we're applying is getitem, so the iterable should supply the keys we'll look up one after another. That means it's the word that we're looking for.
  3. An optional default value

    1. In our case, it's not optional, and since we've already used the function and word, this is going to be self.
    2. For further clarity, this is the initializer: the object that getitem is first called ON.
    3. This makes sense, as each getitem call returns a Trie node, which reduce then feeds into the next call until the characters of the word run out.
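
The three choices above can be checked on a plain nested dictionary before we wire them into the class. In this sketch, the nested dict and the walk helper are stand-ins invented for the example.

```python
from functools import reduce
from operator import getitem

# Stand-in for a trie: nested dicts keyed by character
nested = {"c": {"a": {"t": {"end": True}}}}

def walk(node, word):
    # The loop that reduce(getitem, word, node) performs for us
    for ch in word:
        node = getitem(node, ch)  # same as node[ch]
    return node

via_loop = walk(nested, "cat")
via_reduce = reduce(getitem, "cat", nested)  # function, iterable, initializer

print(via_reduce)              # {'end': True}
print(via_loop is via_reduce)  # True: both land on the same inner dict
```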

Final Design

Altogether we have:

from collections import defaultdict
from operator import getitem
from functools import reduce

class Trie(defaultdict):

    __repr__ = dict.__repr__

    def __init__(self):
        super().__init__(Trie)
        self.count = 0
        self.is_word = False

    def add(self, word):
        for c in word:
            (self := self[c]).count += 1
        self.is_word = True

    def exists(self, word):
        return reduce(getitem, word, self).is_word

    def startswith(self, prefix):
        return reduce(getitem, prefix, self).count > 0

And that's it! We now have a basic Trie that inherits from defaultdict to simplify child node creation, uses the walrus operator to simplify reassignment when adding characters, and uses functools.reduce alongside operator.getitem to capture the search process in a single line.

Optional Debugging Tip

You'll notice that I added __repr__ = dict.__repr__. This is a quick line you can add to borrow the representation of dictionaries, which makes it quite a bit easier to view your Trie.

It's certainly worth creating a custom repr function, but we'll leave that for another time.
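
For a feel of what the borrowed representation looks like, here is a small self-contained sketch with one word added. The word "hi" is arbitrary.

```python
from collections import defaultdict

class Trie(defaultdict):

    __repr__ = dict.__repr__  # print like a plain dict, at every level

    def __init__(self):
        super().__init__(Trie)
        self.count = 0
        self.is_word = False

    def add(self, word):
        for c in word:
            (self := self[c]).count += 1
        self.is_word = True

t = Trie()
t.add("hi")
print(t)  # {'h': {'i': {}}}
```

The output only shows the keys, so count and is_word stay hidden; that is one thing a custom repr could add.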

startswith

To determine whether a prefix is present rather than a whole word, we've once again used reduce, this time passing the prefix as the iterable and returning whether the count of the node for the prefix's last character is greater than zero.
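
One behavior worth knowing about: indexing a defaultdict creates missing keys, so looking up a word that isn't there with reduce and getitem leaves empty nodes behind. The answer is still correct, but the Trie grows. If that matters for your use case, a read-only lookup might look like the sketch below. The find and exists_readonly methods are hypothetical additions for this example, not part of the design above.

```python
from collections import defaultdict
from functools import reduce
from operator import getitem

class Trie(defaultdict):
    def __init__(self):
        super().__init__(Trie)
        self.count = 0
        self.is_word = False

    def add(self, word):
        for c in word:
            (self := self[c]).count += 1
        self.is_word = True

    def exists(self, word):
        return reduce(getitem, word, self).is_word

    def find(self, chars):
        # Hypothetical helper: dict.get never triggers the default factory
        node = self
        for c in chars:
            node = node.get(c)
            if node is None:
                return None
        return node

    def exists_readonly(self, word):
        node = self.find(word)
        return node is not None and node.is_word

t = Trie()
t.add("cat")

print(t.exists_readonly("dog"), "d" in t)  # False False: nothing was created
print(t.exists("dog"), "d" in t)           # False True: the lookup added "d"
```

The one-line version is hard to beat for brevity, so which one to keep is a trade-off between elegance and leaving the Trie untouched on a miss.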

The transformation process

Below is a Snappify presentation that lets you copy, paste, and see the transition we took from a single dictionary to the final Trie:
