February 19, 2026

Symbolic Math

When I first used Mathematica (and later sympy), it seemed like magic. Instantly solving math problems? Amazing. How? Well, let's try to do something relatively simple: calculate a derivative. Given a mathematical expression, we want to symbolically calculate a derivative. I say 'symbolically' because we are going to be manipulating symbols, not doing any numerical methods.

Consider an expression like '3*x + 2'. We need a way to convert this from a string into some other data structure that accounts for the variable x, the constants 2 and 3, and the addition and multiplication operations. We can then define some rules for how to modify this structure to take a derivative. Similar to interpreting a programming language, we can parse a mathematical expression into an abstract syntax tree (AST). Our simple expression becomes something like:


      +
     / \
    *   2
   / \
  3   x
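Python does the same thing with its own source code. As an illustration using the standard library's ast module (not part of this post's code), parsing '3*x + 2' yields an addition node whose left child is the multiplication node, which is the same shape as the diagram:

```python
import ast

# mode="eval" parses a single expression rather than a whole module.
tree = ast.parse("3*x + 2", mode="eval")
root = tree.body  # the top-level '+' node

print(ast.dump(root))
print(type(root.op).__name__)       # Add
print(type(root.left.op).__name__)  # Mult
```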

Tokenize

So first, we need to convert an input string into a list of tokens. We can define an Enum for each type of token we want to consider:

from enum import Enum

class TokenType(Enum):
    LEFT_PAREN = "("
    RIGHT_PAREN = ")"
    PLUS = "+"
    MINUS = "-"
    MULT = "*"
    DIV = "/"
    EXP = "^"
    CONST = "const"
    VAR = "var"
    LN = "ln"

The tokens for parens and operators are simple enough. But we also need to pick out numbers and variables. We can then write a tokenize function that moves through an input string and recognizes each of these tokens. Here, I am just going to assume that purely numeric strings are integer constants, and any run of alphabetic characters is a variable name. With a bit of extra work, we could also look for decimal numbers or impose other limitations on variable names. We also just skip whitespace. Note that the tokenizer never emits LN (typing ln would just produce a variable named ln); that token type exists for the derivative code later, which uses it to build natural-log nodes.

def tokenize(stream):
    if not stream: return []
    if stream[0].isspace(): return tokenize(stream[1:])
    if stream[0] == "(": return [(TokenType.LEFT_PAREN, '(')] + tokenize(stream[1:])
    elif stream[0] == ")": return [(TokenType.RIGHT_PAREN, ')')] + tokenize(stream[1:])
    elif stream[0] == "+": return [(TokenType.PLUS, '+')] + tokenize(stream[1:])
    elif stream[0] == "-": return [(TokenType.MINUS, '-')] + tokenize(stream[1:])
    elif stream[0] == "*": return [(TokenType.MULT, '*')] + tokenize(stream[1:])
    elif stream[0] == "/": return [(TokenType.DIV, '/')] + tokenize(stream[1:])
    elif stream[0] == "^": return [(TokenType.EXP, '^')] + tokenize(stream[1:])
    elif stream[0] in '0123456789':
        i = 0
        while i < len(stream) and stream[i] in '0123456789': i += 1
        return [(TokenType.CONST, stream[:i])] + tokenize(stream[i:])
    elif stream[0].isalpha():
        i = 0
        while i < len(stream) and stream[i].isalpha(): i += 1
        return [(TokenType.VAR, stream[:i])] + tokenize(stream[i:])
    else:
        raise ValueError(f"Unknown token: {stream[:10]!r}")
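The recursive version is easy to read, but every character adds a stack frame, so a long enough input will hit Python's default recursion limit. As an alternative sketch (my own, not the post's code), an iterative tokenizer built on the re module avoids that and also picks up the decimal numbers mentioned above. It assumes the same TokenType enum and the same (type, text) tuples; SYMBOLS and TOKEN_RE are names I made up for this example:

```python
import re
from enum import Enum

class TokenType(Enum):
    LEFT_PAREN = "("
    RIGHT_PAREN = ")"
    PLUS = "+"
    MINUS = "-"
    MULT = "*"
    DIV = "/"
    EXP = "^"
    CONST = "const"
    VAR = "var"
    LN = "ln"

# Single-character operators and parens map straight to their token type.
SYMBOLS = {t.value: t for t in TokenType if len(t.value) == 1}

# Alternatives: whitespace (skipped), numbers with an optional decimal part,
# alphabetic variable names, and single-character symbols.
TOKEN_RE = re.compile(r"\s+|(?P<num>\d+(?:\.\d+)?)|(?P<var>[A-Za-z]+)|(?P<sym>[()+\-*/^])")

def tokenize(stream):
    tokens = []
    pos = 0
    while pos < len(stream):
        m = TOKEN_RE.match(stream, pos)
        if m is None:
            raise ValueError(f"Unknown token: {stream[pos:pos + 10]!r}")
        if m.group("num"):
            tokens.append((TokenType.CONST, m.group("num")))
        elif m.group("var"):
            tokens.append((TokenType.VAR, m.group("var")))
        elif m.group("sym"):
            tokens.append((SYMBOLS[m.group("sym")], m.group("sym")))
        # Otherwise the match was whitespace: emit nothing.
        pos = m.end()
    return tokens

print(tokenize("3.5*x + 2"))
```

Every alternative consumes at least one character, so the loop always makes progress, and anything the pattern can't match raises a ValueError.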

Ok, so now we have a list of tokens. How do we convert this into a tree structure? Well, we need to define some different types of nodes. Most of this is based on the fantastic Crafting Interpreters, so if you want a thorough, well-written explanation of recursive descent parsing, Backus–Naur form, and context-free grammars, go read Chapter 6. But my caveman interpretation is as follows.

Parsing

We define a hierarchy of rules. At the very top is expr: a general expression that can include everything. At the very bottom are primaries: constants, variables, or anything grouped in parens. In between, there is one rule per operator level: add/subtract, multiply/divide, exponent, and unary minus. Each rule takes its operands from the rule below it, so the order in which we stack these rules is what determines precedence. That is, PEMDAS.

expr -> term
term -> factor ((+ | -) factor)*
factor -> expo ((* | /) expo)*
expo -> unary (^ unary)*
unary -> - unary | primary
primary -> const | var | (expr)

I define a few classes to serve as different types of AST nodes. Binary represents a binary operation (add, multiply, exponent, etc.) with an operator and left and right operands. Unary represents a unary operation: negation in the parser, and the natural log in the derivative code later. Const and Var are leaf nodes representing a constant or a variable.

class Binary:
    def __init__(self, left, operator, right):
        self.left = left
        self.right = right
        self.operator = operator
    def __str__(self):
        return f"({str(self.left)} {self.operator.value} {str(self.right)})"

class Unary:
    def __init__(self, right, operator):
        self.right = right
        self.operator = operator
    def __str__(self):
        return f"{self.operator.value}{self.right}"

class Const:
    def __init__(self, value):
        self.value = value
    def __str__(self):
        return f"{self.value}"

class Var:
    def __init__(self, name):
        self.name = name
    def __str__(self):
        return f"{self.name}"
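To see how these classes fit together, here is the tree from the diagram above built by hand, with no tokenizer or parser involved. The block repeats the node classes and a trimmed TokenType enum so it runs on its own:

```python
from enum import Enum

class TokenType(Enum):
    PLUS = "+"
    MULT = "*"

class Binary:
    def __init__(self, left, operator, right):
        self.left = left
        self.right = right
        self.operator = operator
    def __str__(self):
        return f"({self.left} {self.operator.value} {self.right})"

class Const:
    def __init__(self, value):
        self.value = value
    def __str__(self):
        return f"{self.value}"

class Var:
    def __init__(self, name):
        self.name = name
    def __str__(self):
        return f"{self.name}"

# 3*x + 2: the '+' node is the root, and its left child is the '*' node.
tree = Binary(Binary(Const(3), TokenType.MULT, Var("x")), TokenType.PLUS, Const(2))
print(tree)  # ((3 * x) + 2)
```

The fully parenthesized output from __str__ makes the tree's grouping visible, which is handy when checking the parser.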

We create a parser class to hold some state: our list of tokens and the current index. It also gets a match function that checks whether the current token matches any of a set of types we are looking for.

class Parser:
    def __init__(self, tokens):
        self.tokens = tokens
        self.current = 0

    def match(self, *types):
        if self.current >= len(self.tokens): return False
        for tok in types:
            if self.tokens[self.current][0] == tok:
                self.current += 1
                return True
        return False
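The parser code in this post also calls three helpers that aren't shown: previous(), consume(), and parse(). Here is a minimal sketch of what they might look like. The bodies are my guesses, not the post's code: I'm assuming parse() calls expression() and rejects leftover tokens, and expression() is replaced by a hypothetical one-constant stub so the block runs on its own:

```python
from enum import Enum

class TokenType(Enum):
    CONST = "const"

class Parser:
    def __init__(self, tokens):
        self.tokens = tokens
        self.current = 0

    def match(self, *types):
        if self.current >= len(self.tokens): return False
        for tok in types:
            if self.tokens[self.current][0] == tok:
                self.current += 1
                return True
        return False

    def previous(self):
        # The token that the last successful match() stepped over.
        return self.tokens[self.current - 1]

    def consume(self, token_type):
        # Like match(), but the token is required: a missing one is a parse error.
        if not self.match(token_type):
            raise ValueError(f"Expected {token_type.value!r} at token index {self.current}")
        return self.previous()

    def parse(self):
        # Parse one full expression, then make sure no tokens are left over.
        expr = self.expression()
        if self.current != len(self.tokens):
            raise ValueError(f"Unexpected token {self.tokens[self.current][1]!r}")
        return expr

    def expression(self):
        # Hypothetical stub for this demo only: accept a single constant.
        return self.consume(TokenType.CONST)[1]

p = Parser([(TokenType.CONST, "7")])
print(p.parse())  # 7
```

Checking for leftover tokens in parse() matters: without it, an input like "3 4" would quietly parse as just 3.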

Now, the tricky part. We create a function for each rule in the grammar. We start at the top with expression(). This is easy and just calls term().

def expression(self):
    return self.term()

Now term() is looking for addition or subtraction. But before checking for those operators, it moves further down the grammar (or up the precedence order) by calling factor() to get the left operand, which takes care of any higher-precedence operations first. The match function checks whether the current token matches any of the provided types and, if so, advances the current pointer. So if we find a + or - operator, we call factor() again for the right operand and combine the two into a Binary node. Because this happens in a while loop, an expression like 1 - 2 + 3 groups left to right as ((1 - 2) + 3).

def term(self):
    expr = self.factor()
    while self.match(TokenType.PLUS, TokenType.MINUS):
        operator = self.previous()[0]
        right = self.factor()
        expr = Binary(expr, operator, right)
    return expr

We have similar functions for factor(), expo(), unary(), and primary(). We hit bottom at primary(), where we check for a constant, variable, or parentheses:

def primary(self):
    if self.match(TokenType.CONST):
        return Const(self.previous()[1])
    elif self.match(TokenType.VAR):
        return Var(self.previous()[1])
    elif self.match(TokenType.LEFT_PAREN):
        expr = self.expression()
        self.consume(TokenType.RIGHT_PAREN)
        return expr
    raise ValueError(f"Unexpected token at index {self.current}")
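For completeness, here is one way to write factor(), expo(), and unary(), in the order listed above: factor, then expo, then unary, then primary. This is my sketch, not the post's code. It repeats the token types, node classes, and parser helpers so the block runs on its own, and it takes a hand-written token list instead of calling tokenize():

```python
from enum import Enum

class TokenType(Enum):
    LEFT_PAREN = "("
    RIGHT_PAREN = ")"
    PLUS = "+"
    MINUS = "-"
    MULT = "*"
    DIV = "/"
    EXP = "^"
    CONST = "const"
    VAR = "var"

class Binary:
    def __init__(self, left, operator, right):
        self.left, self.operator, self.right = left, operator, right
    def __str__(self):
        return f"({self.left} {self.operator.value} {self.right})"

class Unary:
    def __init__(self, right, operator):
        self.right, self.operator = right, operator
    def __str__(self):
        return f"{self.operator.value}{self.right}"

class Const:
    def __init__(self, value):
        self.value = value
    def __str__(self):
        return f"{self.value}"

class Var:
    def __init__(self, name):
        self.name = name
    def __str__(self):
        return f"{self.name}"

class Parser:
    def __init__(self, tokens):
        self.tokens, self.current = tokens, 0

    def match(self, *types):
        # Advance past the current token if its type is one of `types`.
        if self.current < len(self.tokens) and self.tokens[self.current][0] in types:
            self.current += 1
            return True
        return False

    def previous(self):
        return self.tokens[self.current - 1]

    def expression(self):
        return self.term()

    def term(self):  # + and -: lowest precedence
        expr = self.factor()
        while self.match(TokenType.PLUS, TokenType.MINUS):
            operator = self.previous()[0]
            expr = Binary(expr, operator, self.factor())
        return expr

    def factor(self):  # * and /: bind tighter than + and -
        expr = self.expo()
        while self.match(TokenType.MULT, TokenType.DIV):
            operator = self.previous()[0]
            expr = Binary(expr, operator, self.expo())
        return expr

    def expo(self):  # ^: binds tighter than * and /
        expr = self.unary()
        while self.match(TokenType.EXP):
            expr = Binary(expr, TokenType.EXP, self.unary())
        return expr

    def unary(self):  # prefix minus, e.g. -x or --x
        if self.match(TokenType.MINUS):
            return Unary(self.unary(), TokenType.MINUS)
        return self.primary()

    def primary(self):
        if self.match(TokenType.CONST):
            return Const(self.previous()[1])
        if self.match(TokenType.VAR):
            return Var(self.previous()[1])
        if self.match(TokenType.LEFT_PAREN):
            expr = self.expression()
            if not self.match(TokenType.RIGHT_PAREN):
                raise ValueError("Expected ')'")
            return expr
        raise ValueError(f"Unexpected token at index {self.current}")

T = TokenType
# Hand-written tokens for 3 * x ^ 2 + 1
tokens = [(T.CONST, "3"), (T.MULT, "*"), (T.VAR, "x"), (T.EXP, "^"),
          (T.CONST, "2"), (T.PLUS, "+"), (T.CONST, "1")]
print(Parser(tokens).expression())  # ((3 * (x ^ 2)) + 1)
```

Two design choices worth flagging: the while loop in expo() makes ^ left-associative, so 2^3^2 parses as (2^3)^2, whereas math notation usually reads it right to left; and putting unary below expo means -x^2 parses as (-x)^2. Having expo() recurse into itself for the right operand, and moving unary above expo, would change both.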

So for the list of tokens we get from '3 * x + 2', expression() calls term(), which calls factor(), which calls all the way down to primary() and matches our Const(3). This bubbles back up into factor(), which finds the multiplication operator and calls back down to primary() to get Var(x) as the right operand of TokenType.MULT. That Binary expression is then used as the left operand of the + operator in term(), which calls back down to primary() to find the last Const(2). Hopefully that makes sense. It can be hard to follow the flow of operations through these nested function calls.

Calculating a Derivative

This gives us our AST that represents some expression. Now we can implement a derivative function. This function will take an AST node and recursively apply differentiation rules. First, the easy ones: constants and variables. The derivative of a constant is zero. The derivative of Var(x) is 1 if we are taking the derivative with respect to x, or 0 otherwise:

def derivative(expr, var_name):
    if type(expr) is Const:
        return Const(0)
    elif type(expr) is Var:
        if expr.name == var_name:
            return Const(1)
        else:
            return Const(0)
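The function as shown doesn't cover Unary nodes, even though the parser produces them for negation and the exponent rule below produces them for ln. As a sketch (my extension, not from the post), both fit the same recursive pattern: the derivative of \(-f\) is \(-f'\), and by the chain rule the derivative of \(\ln f\) is \(f'/f\). The node classes and a trimmed TokenType enum are repeated so this runs alone:

```python
from enum import Enum

class TokenType(Enum):
    MINUS = "-"
    DIV = "/"
    LN = "ln"

class Binary:
    def __init__(self, left, operator, right):
        self.left, self.operator, self.right = left, operator, right
    def __str__(self):
        return f"({self.left} {self.operator.value} {self.right})"

class Unary:
    def __init__(self, right, operator):
        self.right, self.operator = right, operator
    def __str__(self):
        return f"{self.operator.value}{self.right}"

class Const:
    def __init__(self, value):
        self.value = value
    def __str__(self):
        return f"{self.value}"

class Var:
    def __init__(self, name):
        self.name = name
    def __str__(self):
        return f"{self.name}"

def derivative(expr, var_name):
    if type(expr) is Const:
        return Const(0)
    elif type(expr) is Var:
        return Const(1) if expr.name == var_name else Const(0)
    elif type(expr) is Unary:
        inner = derivative(expr.right, var_name)
        if expr.operator == TokenType.MINUS:
            # d(-f) = -(f')
            return Unary(inner, TokenType.MINUS)
        elif expr.operator == TokenType.LN:
            # chain rule: d(ln f) = f' / f
            return Binary(inner, TokenType.DIV, expr.right)
    raise NotImplementedError(f"No derivative rule for {expr}")

print(derivative(Unary(Var("x"), TokenType.LN), "x"))     # (1 / x)
print(derivative(Unary(Var("x"), TokenType.MINUS), "x"))  # -1
```

Raising at the end, instead of silently returning None, makes it obvious when a node type has no rule yet.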

The real meat of this function comes with handling binary expressions. Here, we need to use the sum, product, and quotient rules:

if type(expr) is Binary:
    if expr.operator == TokenType.MINUS or expr.operator == TokenType.PLUS:
        # sum rule: (f ± g)' = f' ± g'
        return Binary(derivative(expr.left, var_name), expr.operator, derivative(expr.right, var_name))
    elif expr.operator == TokenType.MULT:
        # product rule: (f * g)' = f * g' + f' * g
        first_term = Binary(expr.left, TokenType.MULT, derivative(expr.right, var_name))
        second_term = Binary(derivative(expr.left, var_name), TokenType.MULT, expr.right)
        return Binary(first_term, TokenType.PLUS, second_term)
    elif expr.operator == TokenType.DIV:
        # quotient rule: "low d high minus high d low over low low"
        first = Binary(expr.right, TokenType.MULT, derivative(expr.left, var_name))  # low d_high
        second = Binary(expr.left, TokenType.MULT, derivative(expr.right, var_name))  # high d_low
        numerator = Binary(first, TokenType.MINUS, second)
        denominator = Binary(expr.right, TokenType.MULT, expr.right)
        return Binary(numerator, TokenType.DIV, denominator)

For the sum rule, we return a new Binary node that adds (or subtracts) the derivatives of both operands. For multiplication, we return the sum of two products. And for division, we build a numerator and denominator following the familiar quotient rule.
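Written out as equations, the three rules the code builds are below, with \(f\) as the left operand and \(g\) as the right. The last line is what the product rule produces for \(3x\) before any cleanup; the \(3 \cdot 1\) and \(0 \cdot x\) terms are exactly the kind of leftovers that show up in the output later:

```latex
\begin{aligned}
\frac{d}{dx}(f \pm g) &= f' \pm g' \\
\frac{d}{dx}(f g) &= f g' + f' g \\
\frac{d}{dx}\frac{f}{g} &= \frac{g f' - f g'}{g \cdot g} \\[4pt]
\frac{d}{dx}(3 \cdot x) &= 3 \cdot 1 + 0 \cdot x = 3
\end{aligned}
```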

For exponents, the most general rule is a little more complicated. The derivative of f(x)^g(x) ends up being \(f^{g}(g'\ln{f} + f'\frac{g}{f})\). You can show for yourself that for simple polynomials, this reduces to the well-known power rule, \(\frac{d}{dx}x^{n} = nx^{n-1}\). But since we are dealing with plain AST nodes, I wanted this to be as general as possible:

elif expr.operator == TokenType.EXP:
    g_prime_ln_f = Binary(derivative(expr.right, var_name), TokenType.MULT, Unary(expr.left, TokenType.LN))
    f_prime_g = Binary(derivative(expr.left, var_name), TokenType.MULT, expr.right)
    f_prime_g_over_f = Binary(f_prime_g, TokenType.DIV, expr.left)
    combined = Binary(g_prime_ln_f, TokenType.PLUS, f_prime_g_over_f)
    return Binary(expr, TokenType.MULT, combined)
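Here is where the general rule comes from, and how it collapses to the power rule when the exponent is a constant \(n\) (so \(g' = 0\)). Rewriting \(f^g\) as an exponential assumes \(f > 0\) so that \(\ln f\) is defined; for \(x^2\) the \(\ln\) term is multiplied by zero anyway, which is why it disappears once the zeros are simplified away:

```latex
\begin{aligned}
f^{g} &= e^{g \ln f} \\
\frac{d}{dx} f^{g} &= e^{g \ln f}\,\frac{d}{dx}\left(g \ln f\right)
  = f^{g}\left(g' \ln f + f' \frac{g}{f}\right) \\
\text{with } g = n,\ g' = 0: \quad
\frac{d}{dx} f^{n} &= f^{n} \cdot n \frac{f'}{f} = n f^{n-1} f' \\
\text{with } f = x,\ f' = 1: \quad
\frac{d}{dx} x^{n} &= n x^{n-1}
\end{aligned}
```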

Now let's try putting it all together and see what we get:

t = tokenize("x^2 + 1")
p = Parser(t)
AST = p.parse()
print(AST)   # prints ((x ^ 2) + 1)
deriv = derivative(AST, 'x')
print(deriv) # prints (((x ^ 2) * ((0 * lnx) + ((1 * 2) / x))) + 0)
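One way to sanity-check a result like this, sketched here as my own addition, is to evaluate the derivative tree at a point and compare it against a central finite-difference estimate of the original expression. The evaluate() function and BINARY_OPS table are hypothetical names; both trees are written out by hand to match the printed output above, so the block doesn't depend on the parser:

```python
import math
import operator as op
from enum import Enum

class TokenType(Enum):
    PLUS = "+"
    MINUS = "-"
    MULT = "*"
    DIV = "/"
    EXP = "^"
    LN = "ln"

class Binary:
    def __init__(self, left, operator, right):
        self.left, self.operator, self.right = left, operator, right

class Unary:
    def __init__(self, right, operator):
        self.right, self.operator = right, operator

class Const:
    def __init__(self, value):
        self.value = value

class Var:
    def __init__(self, name):
        self.name = name

# Python functions for each binary operator token.
BINARY_OPS = {
    TokenType.PLUS: op.add,
    TokenType.MINUS: op.sub,
    TokenType.MULT: op.mul,
    TokenType.DIV: op.truediv,
    TokenType.EXP: op.pow,
}

def evaluate(expr, env):
    """Numerically evaluate an AST, looking up variable values in env."""
    if type(expr) is Const:
        return float(expr.value)  # handles both Const(0) and Const("2")
    if type(expr) is Var:
        return env[expr.name]
    if type(expr) is Unary:
        inner = evaluate(expr.right, env)
        return -inner if expr.operator == TokenType.MINUS else math.log(inner)
    return BINARY_OPS[expr.operator](evaluate(expr.left, env), evaluate(expr.right, env))

T = TokenType
x = Var("x")
f = Binary(Binary(x, T.EXP, Const("2")), T.PLUS, Const("1"))  # x^2 + 1
# Hand-built copy of (((x ^ 2) * ((0 * lnx) + ((1 * 2) / x))) + 0)
df = Binary(
    Binary(Binary(x, T.EXP, Const("2")), T.MULT,
           Binary(Binary(Const(0), T.MULT, Unary(x, T.LN)), T.PLUS,
                  Binary(Binary(Const(1), T.MULT, Const("2")), T.DIV, x))),
    T.PLUS, Const(0))

at, h = 3.0, 1e-6
symbolic = evaluate(df, {"x": at})
numeric = (evaluate(f, {"x": at + h}) - evaluate(f, {"x": at - h})) / (2 * h)
print(symbolic, numeric)  # both approximately 6
```

One quirk of keeping the general exponent rule: even though the ln term is multiplied by zero, evaluating this tree at x <= 0 still calls math.log and raises a ValueError.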

Ok, that's technically correct (the best kind of correct), but not the simplest representation. We're carrying around a lot of zeroes we can get rid of. Completely simplifying any expression is probably a huge problem all on its own. For this little project, I just want to deal with adding zero to something, multiplying by zero, or multiplying by 1. Let's try a simplify function to handle these cases:

def simplify(ast):
    if type(ast) is Binary:
        if ast.operator == TokenType.PLUS:
            if isconst0(ast.left):
                return simplify(ast.right)
            elif isconst0(ast.right):
                return simplify(ast.left)
        elif ast.operator == TokenType.MULT:
            if isconst0(ast.left) or isconst0(ast.right):
                return Const(0)
            if isconst1(ast.left):
                return simplify(ast.right)
            elif isconst1(ast.right):
                return simplify(ast.left)
        return Binary(simplify(ast.left), ast.operator, simplify(ast.right))
    return ast

I chose to define isconst0 and isconst1 functions here, which just check whether the node is a Const with value 0 or 1. You might have tried something like if ast.left == Const(0), but since we never defined the equality operator for our Const class, Python falls back to checking whether these are exactly the same object (they are not), not whether they are both constants with the same value.
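The post doesn't show isconst0 and isconst1, so here is one possible version. One wrinkle worth handling: the parser stores constants as the matched text (Const('2')), while derivative() creates integer constants (Const(0), Const(1)), so comparing value == 0 directly would miss Const('0'). Converting with float() covers both; const_equals is a hypothetical helper name, and this assumes Const values are always ints or numeric strings:

```python
class Const:
    def __init__(self, value):
        self.value = value

def const_equals(node, number):
    # True only for Const nodes whose value (int or numeric string) equals number.
    return type(node) is Const and float(node.value) == number

def isconst0(node):
    return const_equals(node, 0)

def isconst1(node):
    return const_equals(node, 1)

print(isconst0(Const(0)), isconst0(Const("0")), isconst1(Const("2")))
```

Defining __eq__ on Const would be the other route, though Python then sets __hash__ to None, making the nodes unhashable unless you also define __hash__.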

After a few calls to this simplify function, we can get our final result down to: ((x ^ 2) * (2 / x)). Not perfect, but good enough for a blog post, which is already getting kind of long. I have a new appreciation for how complex symbolic computation can get, and I still think Mathematica and sympy are kind of magical. Full code for this post is on GitHub.
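The "few calls" can be automated by re-running simplify until the printed form stops changing. This is a sketch assuming the simplify() shown above plus float()-based isconst0/isconst1 helpers; simplify_fully is a hypothetical name, and comparing str() output is a cheap stand-in for real structural equality. The derivative tree is written by hand to match the earlier printed output so the block runs on its own:

```python
from enum import Enum

class TokenType(Enum):
    PLUS = "+"
    MULT = "*"
    DIV = "/"
    EXP = "^"
    LN = "ln"

class Binary:
    def __init__(self, left, operator, right):
        self.left, self.operator, self.right = left, operator, right
    def __str__(self):
        return f"({self.left} {self.operator.value} {self.right})"

class Unary:
    def __init__(self, right, operator):
        self.right, self.operator = right, operator
    def __str__(self):
        return f"{self.operator.value}{self.right}"

class Const:
    def __init__(self, value):
        self.value = value
    def __str__(self):
        return f"{self.value}"

class Var:
    def __init__(self, name):
        self.name = name
    def __str__(self):
        return f"{self.name}"

def isconst0(node):
    return type(node) is Const and float(node.value) == 0

def isconst1(node):
    return type(node) is Const and float(node.value) == 1

def simplify(ast):
    if type(ast) is Binary:
        if ast.operator == TokenType.PLUS:
            if isconst0(ast.left):
                return simplify(ast.right)
            elif isconst0(ast.right):
                return simplify(ast.left)
        elif ast.operator == TokenType.MULT:
            if isconst0(ast.left) or isconst0(ast.right):
                return Const(0)
            if isconst1(ast.left):
                return simplify(ast.right)
            elif isconst1(ast.right):
                return simplify(ast.left)
        return Binary(simplify(ast.left), ast.operator, simplify(ast.right))
    return ast

def simplify_fully(ast, max_passes=100):
    # Re-run simplify() until a pass changes nothing (a fixed point).
    for _ in range(max_passes):
        simpler = simplify(ast)
        if str(simpler) == str(ast):
            return simpler
        ast = simpler
    return ast

T = TokenType
x = Var("x")
# Hand-built copy of (((x ^ 2) * ((0 * lnx) + ((1 * 2) / x))) + 0)
deriv = Binary(
    Binary(Binary(x, T.EXP, Const("2")), T.MULT,
           Binary(Binary(Const(0), T.MULT, Unary(x, T.LN)), T.PLUS,
                  Binary(Binary(Const(1), T.MULT, Const("2")), T.DIV, x))),
    T.PLUS, Const(0))

print(simplify(deriv))        # ((x ^ 2) * (0 + (2 / x)))
print(simplify_fully(deriv))  # ((x ^ 2) * (2 / x))
```

A single pass leaves a 0 + behind because simplify only collapses 0 * lnx to 0 after the enclosing + node has already been checked; the second pass catches it.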