Emma Benjaminson, Mechanical Engineering Graduate Student. <a href="http://sassafras13.github.io/feed.xml">http://sassafras13.github.io/feed.xml</a> (generated by Jekyll, 2022-10-23).
<h1 id="kadanes-algorithm-post"><a href="http://sassafras13.github.io/KadanesAlgo">Kadane’s Algorithm</a> (2022-09-04)</h1>
<p>Today I was working through a dynamic programming problem in Leetcode when I encountered a reference to an algorithm that I had not heard of before: <strong>Kadane’s Algorithm</strong>. I wanted to write a quick post to explain this algorithm in more detail, and hopefully this will help me solve more Leetcode problems, too. I’ll start by describing the problem that Kadane’s Algorithm solves and some of the contexts where it appears in practice. Then I will describe a number of ways to solve it, including Kadane’s Algorithm.</p>
<h2 id="the-problem">The Problem</h2>
<p>Kadane’s Algorithm solves what is known as the <strong>maximum subarray problem</strong>: given an array of numbers, find the contiguous subarray with the largest sum. The numbers in the array can be positive, negative, or zero. As an example, if we are given the array [1]:</p>
<pre><code class="language-python3">[-2, 1, -3, 4, -1, 2, 1, -5, 4]
</code></pre>
<p>Then the solution is the subarray <code class="language-plaintext highlighter-rouge">[4, -1, 2, 1]</code>, whose sum is 6.</p>
<p>This might seem like an arbitrary, esoteric problem (which is how I feel about most of Leetcode’s questions, if I’m being honest) but this problem actually turns out to have many real-world applications. The problem was originally formulated by Ulf Grenander as a way of computing the <a href="https://sassafras13.github.io/InferencePGMs/">maximum likelihood estimate</a> for patterns in images. He wanted to find the maximum-sum 2D subarray within an image, and formulated this 1D version as a way of looking for the underlying structure in the problem. Grenander found an algorithm that ran in $\mathcal{O}(n^2)$ time (which is $n$ times slower than the fastest possible approach, which would only need to scan the array once and would therefore, theoretically, run in $\mathcal{O}(n)$ time). Another researcher, Michael Shamos, found an $\mathcal{O}(n \log (n))$ solution, and talked about this problem and his search for faster algorithms during a seminar he presented at Carnegie Mellon University. Jay Kadane, a professor in the Department of Statistics, heard the seminar and apparently developed a solution in $\mathcal{O}(n)$ time in under a minute [1, 2].</p>
<p>Apocryphal stories aside, solving the maximum subarray problem is useful to more than just finding patterns in images. For example, this kind of problem appears in genomic sequencing work where we are often trying to find specific regions in protein sequences that have certain properties, or we are looking for GC-rich regions in DNA strands. And it turns out that there is more than one way to solve the maximum subarray problem. We can solve it using divide-and-conquer strategies (like Shamos), using dynamic programming, or brute force [1].</p>
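<p>To make the comparison concrete, here is a quick sketch of the brute-force approach mentioned above: it computes the sum of every possible subarray and runs in $\mathcal{O}(n^2)$ time. This is my own illustration, not code from [1]:</p>
<pre><code class="language-python3">def max_subarray_brute(numbers):
    # try every starting index, extending the subarray one entry at a time
    best = float('-inf')
    for i in range(len(numbers)):
        total = 0
        for j in range(i, len(numbers)):
            total += numbers[j]   # running sum of numbers[i..j]
            best = max(best, total)
    return best

max_subarray_brute([-2, 1, -3, 4, -1, 2, 1, -5, 4])  # 6
</code></pre>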
<h2 id="kadanes-algorithm">Kadane’s Algorithm</h2>
<p>Before presenting Kadane’s Algorithm, it might help to consider some properties of this problem to become more familiar with it.*1 Specifically [1]:</p>
<ol>
<li>
<p>If the input array has all positive numbers, then the solution is easy - it’s just the sum of all the numbers in the array.</p>
</li>
<li>
<p>If the input array has all negative numbers, then we want to find the maximal single entry in the array and just return that value.</p>
</li>
<li>
<p>It is possible for there to be more than one subarray with the same maximum sum.</p>
</li>
</ol>
<p>Kadane’s Algorithm is an elegant approach to solving this problem, which works given all of these properties of the problem statement. We will consider here the formulation of the problem where empty subarrays are not allowed. In this case, we can write Kadane’s Algorithm as follows [1]:</p>
<pre><code class="language-python3">def max_subarray(numbers):
    best_sum = float('-inf')
    current_sum = 0
    for x in numbers:
        # either extend the previous subarray or start fresh at x
        current_sum = max(x, current_sum + x)
        best_sum = max(best_sum, current_sum)
    return best_sum
</code></pre>
<p>Kadane’s Algorithm is employing a simple form of dynamic programming, because it is building up a solution from optimal solutions to sub-problems. The sub-problem is finding the maximum sum for the subarray that we have seen at the previous step. That is, the <code class="language-plaintext highlighter-rouge">current_sum</code> is the optimal solution to the sub-problem of finding the max sum of the subarray that represents all the numbers we have seen up to now [1].</p>
<p>Rohit Singhal gave a very nice diagrammatic explanation of the dynamic programming aspect of this algorithm, as shown in Figure 1 [3]. If we consider the left hand side of Figure 1, we can see that we can find the optimal solution to the maximum subarray at <code class="language-plaintext highlighter-rouge">nums[4]</code> by finding the local maximum among all the possible subarrays. Once we know that the local maximum is 3, we carry that information forwards to solving the next sub-problem: the maximum subarray at <code class="language-plaintext highlighter-rouge">nums[5]</code>. At <code class="language-plaintext highlighter-rouge">nums[5]</code>, we don’t need to completely redo all the summation calculations for the sums of the subarrays - we can just add the value of <code class="language-plaintext highlighter-rouge">nums[5] = 2</code> to the local maximum of the previous step, that is, we can add 3 + 2 = 5.</p>
<p><img src="/images/2022-09-04-KadanesAlgo-fig1.png" alt="Fig 1" title="Figure 1" /> <br />
Figure 1 - Source [3]</p>
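<p>If you also want to know <em>which</em> subarray achieves the maximum sum, the same single pass can track the boundary indices. This extension is my own sketch rather than part of the algorithm as presented in [1]:</p>
<pre><code class="language-python3">def max_subarray_with_indices(numbers):
    best_sum = float('-inf')
    current_sum = 0
    start = best_start = best_end = 0
    for i, x in enumerate(numbers):
        if x > current_sum + x:
            # starting fresh at x beats extending the previous subarray
            current_sum = x
            start = i
        else:
            current_sum += x
        if current_sum > best_sum:
            best_sum = current_sum
            best_start, best_end = start, i
    return best_sum, best_start, best_end

max_subarray_with_indices([-2, 1, -3, 4, -1, 2, 1, -5, 4])  # (6, 3, 6)
</code></pre>
<p>For the example array, the returned indices 3 through 6 correspond to the subarray <code class="language-plaintext highlighter-rouge">[4, -1, 2, 1]</code>.</p>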
<h2 id="footnotes">Footnotes</h2>
<p>*1 One thing I am trying to get better at while solving Leetcode problems is taking the time to really think about the problem, and to play around with solving a couple numerical examples by hand, before coding up a solution. I think it’s probably better in the long run to think my way to the right answer instead of coding/debugging my way to a poorly-implemented answer.</p>
<h2 id="references">References</h2>
<p>[1] “Maximum subarray problem.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Maximum_subarray_problem">https://en.wikipedia.org/wiki/Maximum_subarray_problem</a> Visited 4 Sept 2022.</p>
<p>[2] “Joseph Born Kadane.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Joseph_Born_Kadane">https://en.wikipedia.org/wiki/Joseph_Born_Kadane</a> Visited 4 Sept 2022.</p>
<p>[3] Singhal, R. “Kadane’s Algorithm — (Dynamic Programming) — How and Why does it Work?” Medium. 31 Dec 2018. <a href="https://medium.com/@rsinghal757/kadanes-algorithm-dynamic-programming-how-and-why-does-it-work-3fd8849ed73d">https://medium.com/@rsinghal757/kadanes-algorithm-dynamic-programming-how-and-why-does-it-work-3fd8849ed73d</a> Visited 4 Sept 2022.</p>
<h1 id="bit-manipulation-post"><a href="http://sassafras13.github.io/BitManipulation">Bit Manipulation</a> (2022-08-20)</h1>
<p>I’m still working on my coding problems, and in this post I want to talk about <strong>bit manipulation</strong> in Python. I’ll start by reviewing binary representations in computers and then dive into talking about bitwise operators and bitmasks, two key tools for manipulating bits on a computer. Let’s get started!</p>
<h2 id="why-do-computers-speak-in-binary">Why do Computers Speak in Binary?</h2>
<p>Some of this is probably review for a lot of folks, but I just wanted to set the context of this post by talking about why computers use binary to encode information. I think it’s interesting to consider what drove the decision to use binary and how it impacted everything else about how computers work. Binary is a form of <strong>positional notation</strong>, because the position of each digit helps to convey information about the value of that digit [1]. For example, in base-10, 1.0 and 10.0 are different values because the position of the 1 has changed. Binary is base-2, that is, each position to the left represents the next power of 2, so moving a digit one position to the left doubles its value. For example, 1 and 10 in binary represent 1 and 2 in our base-10 system.</p>
<p>The reason why computers use base-2 instead of base-10 is that base-2 is more robust to hardware issues. The base-2 system only has 2 digits, 1 and 0, and this correlates nicely with having an electrical signal that is either on or off. If we used base-10, then computers would need to be able to detect 10 different voltage levels in the electrical signals, and it can be difficult to build hardware that can do that reliably and robustly, especially in the presence of noise. But base-2 is much easier to implement in hardware because we just need to know if the electrical signal is on or off [1].</p>
<p>At the most fundamental levels in a computer, all information is encoded in the binary format. For example, the text that you are reading right now is comprised of characters that each have a unique binary encoding [1]. So given how essential binary is to how computers work, you can understand why being able to manipulate bits (each bit is one binary digit) is important in computer science.</p>
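<p>You can see this character encoding directly in Python, which reports the code point of a character and its binary form:</p>
<pre><code class="language-python3">ord('a')       # 97, the code point of the character 'a'
bin(ord('a'))  # '0b1100001', the bit pattern the machine stores
</code></pre>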
<h2 id="some-binary-basics">Some Binary Basics</h2>
<p>Before we move on to talk about bitwise operators, I want to cover some important aspects of how binary encoding works, specifically when implemented in Python. First of all, generally speaking we represent data in <strong>bytes</strong>, where 1 byte consists of 8 bits. Recall that each bit is one digit in the base-2 numbering system, so 1 byte can represent $2^8 = 256$ distinct values, the largest of which is $2^8 - 1 = 255$.</p>
<p>It is also possible to distinguish between positive and negative numbers in binary. Typically, this is done using an extra bit at the far left of the byte which is 0 for positive values and 1 for negative values. This is shown in Figure 1. Note that Python does not follow this convention, however, and so that affects some of the bitwise operators we will see in the next section. Specifically, Python allows for using an unlimited number of bits to represent an integer, so this forces it to use a different method for representing the corresponding sign of the integer [1].</p>
<p><img src="/images/2022-08-13-BitManipulation-fig0.png" alt="Fig 1" title="Figure 1" /> <br />
Figure 1 - Source [1]</p>
<p>Note that Python allows you to see how integers are represented in binary and other schemes. For example [1]:</p>
<pre><code class="language-python3">bin(42)  # '0b101010', the 0b prefix indicates the binary representation
hex(42)  # '0x2a', 0x is the prefix for the hexadecimal system
oct(42)  # '0o52', 0o is the prefix for the octal system
</code></pre>
<h2 id="twos-complement">Two’s Complement</h2>
<p>For languages other than Python, integers are typically stored in a form known as <strong>two’s complement</strong>. In two’s complement, positive integers are stored as expected in binary, and negative integers are stored as the two’s complement of their absolute values. Two’s complement uses a sign bit (i.e. the left-most bit) to indicate if the number is positive (0) or negative (1) [2].</p>
<p>Let’s understand what the two’s complement is. The definition of the two’s complement is that, for an N-bit number, it is the complement of that number with respect to $2^N$. For example, let’s try to represent $-3$ in two’s complement notation using 4 bits: one bit is for the sign, and the other 3 bits encode the value. The complement of 3 with respect to $2^3 = 8$ is $8 - 3 = 5$, which in binary is 101. Prepending the sign bit 1 (to indicate the number is negative) gives the full two’s complement notation for -3: 1101 [2].</p>
<p>Put more generally, if we want to find the $N$-bit two’s complement representation of a negative integer $-K$, then we need to perform [2]:</p>
<pre><code class="language-python3"># pseudocode: the sign bit 1 followed by the (N-1)-bit complement of K
concat(1, 2**(N-1) - K)
</code></pre>
<p>Another way to find the two’s complement representation of a negative integer $-K$ is to [2]:</p>
<ul>
<li>Write $K$ in binary using $N-1$ bits</li>
<li>Flip all the bits</li>
<li>Add 1</li>
<li>Prepend the sign bit 1 to indicate the value is negative</li>
</ul>
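<p>As a sanity check, we can compute the two’s complement bit pattern in Python. The helper name and the default width of 4 bits are my own choices for illustration:</p>
<pre><code class="language-python3">def twos_complement_pattern(k, n_bits=4):
    # bit pattern of -k in an n_bits-wide two's complement encoding
    assert 0 < k <= 2 ** (n_bits - 1)
    return format((1 << n_bits) - k, '0{}b'.format(n_bits))

twos_complement_pattern(3)  # '1101', matching the worked example for -3
</code></pre>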
<h2 id="bitwise-operators">Bitwise Operators</h2>
<p>We can operate on individual bits using a number of different logical operators in Python. A summary of some of the common operators is shown in Figure 2. Most of the operators (with the exception of the NOT operator) are <strong>binary</strong>, which means that they take two <strong>operands</strong> and combine the left operand with the right operand - for example, they determine whether both <em>a</em> and <em>b</em> are true. The bitwise NOT operator is a <strong>unary</strong> operator since it only takes in one operand [1].</p>
<p><img src="/images/2022-08-13-BitManipulation-fig1.png" alt="Fig 2" title="Figure 2" /> <br />
Figure 2 - Source [1]</p>
<p>All of the binary bitwise operators have an accompanying <strong>compound operator</strong>, as shown in Figure 3 [1]. The compound operator performs the bitwise operation, and then assigns the result of that operation to the left operand. I find that the compound operator reminds me of incrementing a variable: <code class="language-plaintext highlighter-rouge">i += 1</code> because I am taking the variable <code class="language-plaintext highlighter-rouge">i</code>, adding 1, and assigning the new value to <code class="language-plaintext highlighter-rouge">i</code>.</p>
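<p>For example, the compound form of the bitwise AND works like this:</p>
<pre><code class="language-python3">a = 0b1100
a &= 0b1010   # shorthand for a = a & 0b1010
bin(a)        # '0b1000'
</code></pre>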
<p><img src="/images/2022-08-13-BitManipulation-fig2.png" alt="Fig 3" title="Figure 3" /> <br />
Figure 3 - Source [1]</p>
<p>Let’s explore some of these operators in more detail. I would also recommend heading over to [1] to see some very helpful GIFs that make it easy to visualize what these operators are doing.</p>
<h3 id="and">AND</h3>
<p>The bitwise AND operator compares two operands and returns a 1 any time the bit in the same position in both operands is on. Otherwise, it returns 0 [1]. An example of this is shown in Figure 4.</p>
<p><img src="/images/2022-08-13-BitManipulation-fig3.png" alt="Fig 4" title="Figure 4" /> <br />
Figure 4 - Source [1]</p>
<h3 id="or">OR</h3>
<p>The bitwise OR operator returns a 1 any time at least one of the operands has a 1 in a given position. The OR operator only returns 0 if both bits (from each operand) are zero [1]. This is shown in Figure 5.</p>
<p><img src="/images/2022-08-13-BitManipulation-fig4.png" alt="Fig 5" title="Figure 5" /> <br />
Figure 5 - Source [1]</p>
<h3 id="xor">XOR</h3>
<p>XOR stands for “exclusive or.” Interestingly, Python has no <em>logical</em> XOR operator (there is no keyword to complement <code class="language-plaintext highlighter-rouge">and</code> and <code class="language-plaintext highlighter-rouge">or</code>), but it does provide the bitwise XOR operator <code class="language-plaintext highlighter-rouge">^</code>. The XOR operator returns a 1 if the two operands have different values at the same bit position. It basically tells us when two operands represent two mutually exclusive cases. For example, XOR would return a 1 if one operand had a value of 0 and the other had a value of 1. If both operands have the same value, then XOR returns 0 [1]. This is shown in Figure 6.</p>
<p><img src="/images/2022-08-13-BitManipulation-fig5.png" alt="Fig 6" title="Figure 6" /> <br />
Figure 6 - Source [1]</p>
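<p>The three binary operators are easy to compare side by side on a pair of small operands:</p>
<pre><code class="language-python3">a, b = 0b1100, 0b1010
bin(a & b)  # '0b1000': 1 only where both bits are on
bin(a | b)  # '0b1110': 1 where at least one bit is on
bin(a ^ b)  # '0b110':  1 only where the bits differ
</code></pre>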
<h3 id="not">NOT</h3>
<p>Finally, the NOT operator simply reverses the value of every bit in the operand. However, it is worth noting that the NOT operator does not always work as expected in Python, because Python’s integers are not stored with a fixed number of bits; as a result, <code class="language-plaintext highlighter-rouge">~x</code> evaluates to <code class="language-plaintext highlighter-rouge">-x - 1</code> rather than the flipped fixed-width bit pattern [1]. For now, admire this image depicting how the NOT operator works.</p>
<p><img src="/images/2022-08-13-BitManipulation-fig6.png" alt="Fig 7" title="Figure 7" /> <br />
Figure 7 - Source [1]</p>
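<p>A quick demonstration of this behavior: in Python, <code class="language-plaintext highlighter-rouge">~x</code> equals <code class="language-plaintext highlighter-rouge">-x - 1</code>, and masking the result recovers the flipped fixed-width pattern:</p>
<pre><code class="language-python3">~42          # -43, not the flipped bit pattern you might expect
~42 & 0xff   # 213, i.e. 0b11010101, the 8-bit flip of 0b00101010
</code></pre>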
<h3 id="shift-operators">Shift Operators</h3>
<p>There is another category of bitwise operations that shift binary numbers to the left and right. This is a very efficient way of manipulating numbers and can be used to make bitmasks, which we’ll discuss later. The left shift operator, <code class="language-plaintext highlighter-rouge"><<</code>, moves the bits of the first operand as many places as specified in the second operand [1]. For example:</p>
<pre><code class="language-python3">bin(0b1100101 << 2)  # '0b110010100'
</code></pre>
<p>Every time we shift the operand one place to the left, we double its value, as shown in the table below. Bit shifting in this way used to be a popular way to quickly compute products with powers of 2, but Python’s arithmetic is now efficient enough that doing this bit manipulation manually is unnecessary. Notice also that Python automatically increases the storage size of an integer when a shift grows it beyond a byte (8 bits), so the result is always represented accurately. Other languages with fixed-width integers may not do this, and so you might find that you shift a value to the left and some of the bits are chopped off, returning a smaller value than you wanted [1].</p>
<p><img src="/images/2022-08-13-BitManipulation-fig7.png" alt="Fig 8" title="Figure 8" /> <br />
Figure 8 - Source [1]</p>
<p>Shifting values to the right instead of the left is the same operation, but now every time you shift the bits one place, the value of the number is halved instead of doubled, as shown in Figure 9. Notice also that when shifting (i.e. dividing) an odd number, the value is rounded down to the nearest integer (i.e. we do floor division) [1].</p>
<p><img src="/images/2022-08-13-BitManipulation-fig8.png" alt="Fig 9" title="Figure 9" /> <br />
Figure 9 - Source [1]</p>
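<p>For example, right-shifting an odd number gives the floor of the halved value:</p>
<pre><code class="language-python3">39 >> 1  # 19, the same as 39 // 2 (floor division)
39 >> 2  # 9,  the same as 39 // 4
</code></pre>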
<p>One thing to notice here is that the right shift operator can affect the sign of a number. Recall that the left-most bit is often used to convey the sign of the number the byte represents, and a right shift that always fills the left-most bit with a 0 can therefore turn a negative number into a positive one [1].</p>
<p>So far, the left and right shifts we have discussed are called <strong>logical shift operators</strong>, which means that they just move the bits without taking into consideration things like the sign of the integers that are being manipulated. There is another category of shift operators - <strong>arithmetic shift operators</strong> - that performs the right shift while taking into account the sign of the integer. The arithmetic right shift, therefore, maintains the sign bit when it shifts the rest of the bits to the right [1].</p>
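<p>Python’s own <code class="language-plaintext highlighter-rouge">>></code> behaves like an arithmetic shift, since its integers have no fixed width. If you want logical-shift behavior, one common trick is to mask the value to a fixed width first; the function below is my own sketch of that idea, with the 8-bit width as an assumed default:</p>
<pre><code class="language-python3">def logical_rshift(value, n, width=8):
    # mask to `width` bits first so the sign bit is not carried along
    return (value & ((1 << width) - 1)) >> n

-42 >> 1                # -21: Python's arithmetic shift keeps the sign
logical_rshift(-42, 1)  # 107: the 8-bit pattern 0b11010110 shifted right
</code></pre>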
<h2 id="bitmasks">Bitmasks</h2>
<p>Bitmasks can be used to manipulate specific bits in a value. For example, you can get Python to reveal the fixed-width two’s complement bit pattern of a negative integer by manipulating the integer with a <strong>bitmask</strong>. If you perform the AND operation on an integer and a corresponding bitmask, you get the following [1]:</p>
<pre><code class="language-python3">mask = 0b11111111  # equivalent to 0xff
bin(-42 & mask)    # '0b11010110'
</code></pre>
<p>Notice here how the bitmask is overlaid on top of the value in question to obtain some information about its binary representation. There are some other useful operations that we can perform with bitmasks as described below.</p>
<h3 id="getting-a-bit">Getting a Bit</h3>
<p>If you want to read the bit at a specific position in a number, you can use a bitmask to obtain it as shown below [1]:</p>
<pre><code class="language-python3">def get_bit(value, bit_index):
    return value & (1 << bit_index)

get_bit(0b10000000, bit_index=5)  # 0
get_bit(0b10100000, bit_index=5)  # 32
</code></pre>
<p>Here, the bitmask is <code class="language-plaintext highlighter-rouge">1</code> shifted over by <code class="language-plaintext highlighter-rouge">bit_index</code> number of places. The function then returns the value (in base-10) of the bit at that index. If you wanted to get the normalized value (i.e. just a 1 or 0), then you could use this function instead [1]:</p>
<pre><code class="language-python3">def get_normalized_bit(value, bit_index):
    return (value >> bit_index) & 1
</code></pre>
<p>Here we right-shift <code class="language-plaintext highlighter-rouge">value</code> until the bit in question is at the first position, and then we retrieve it using a bitmask of <code class="language-plaintext highlighter-rouge">1</code>.</p>
<h3 id="setting-and-unsetting-a-bit">Setting and Unsetting a Bit</h3>
<p>You may also want to set the value of a specific bit, which you can do using a similar function to the one shown above, but using the logical OR operator instead of AND [1]:</p>
<pre><code class="language-python3">def set_bit(value, bit_index):
    return value | (1 << bit_index)

set_bit(0b10000000, bit_index=5)  # 160
bin(160)  # '0b10100000'
</code></pre>
<p>Since we are using the OR operator with 1 as the right operand, we are setting the bit at <code class="language-plaintext highlighter-rouge">bit_index</code> to have a value of 1 regardless of its current value. If you wanted to unset this value, you could do the inverse (i.e. use the NOT operator on the bitmask) [1]:</p>
<pre><code class="language-python3">def clear_bit(value, bit_index):
    return value & ~(1 << bit_index)

clear_bit(0b11111111, bit_index=5)  # 223
bin(223)  # '0b11011111'
</code></pre>
<p>The NOT operator in Python always returns a negative number when given a positive number as input, but using the NOT operator in concert with the AND operator changes the way Python represents the bitmask so that it works as expected in this situation. (For more detail on why this is the case, see the section on “Binary Number Representations” in [1].)</p>
<h3 id="toggling-a-bit">Toggling a Bit</h3>
<p>Finally, you can toggle a bit to have the opposite value to its current state using the XOR operator and a bitmask [1]:</p>
<pre><code class="language-python3">def toggle_bit(value, bit_index):
    return value ^ (1 << bit_index)

toggle_bit(0b10100000, bit_index=5)  # 128, i.e. 0b10000000
</code></pre>
<h2 id="conclusion">Conclusion</h2>
<p>I hope this was a helpful overview of how data is represented in a computer and how to manipulate individual bits. I expect that as I tackle some problems related to bitwise manipulation that I will encounter some more concepts that I need to learn, so stay tuned for further blog posts in this space.</p>
<h2 id="references">References</h2>
<p>[1] Zaczynski, B. “Bitwise Operators in Python.” Real Python. <a href="https://realpython.com/python-bitwise-operators/">https://realpython.com/python-bitwise-operators/</a> Visited 13 Aug 2022.</p>
<p>[2] Laakmann McDowell, G. Cracking the Coding Interview, 6th edition. 2016. CareerCup, LLC.</p>
<h1 id="dynamic-programming-post"><a href="http://sassafras13.github.io/DynamicProgramming">Dynamic Programming</a> (2022-08-02)</h1>
<p>I’m getting back into the groove of studying for a coding interview, and I recently came across a question that requires knowledge of <strong>dynamic programming</strong> to solve it. I’ve touched on dynamic programming before, both in the context of <a href="https://sassafras13.github.io/Silver3/">optimal control</a> and when discussing <a href="https://sassafras13.github.io/MessagePassingAlgos/">message passing algorithms over graphs</a>. This post will review that material and also give a more computer science-focused introduction to the topic.</p>
<h2 id="what-is-dynamic-programming">What is Dynamic Programming?</h2>
<p>The term dynamic programming*1 is used to describe an approach that breaks problems up into overlapping sub-problems before solving them. In order for this approach to be valid, we must be sure that the <strong>Principle of Optimality</strong> applies. The Principle of Optimality says that if the solution to every sub-problem is optimal, then the overall solution built on those components is <em>also</em> optimal [1, 4]. Basically, this Principle is the thing that ensures breaking the problem into pieces is an acceptable approach.</p>
<p>One example of a problem that can be solved using dynamic programming is computing entries in the Fibonacci sequence. As shown in Figure 1, if we want to compute the fourth entry in the sequence, we can do this by breaking up the problem into computing the earlier entries following a tree structure.</p>
<p><img src="/images/2022-08-02-DynamicProgramming-fig1.png" alt="Fig 1" title="Figure 1" /> <br />
Figure 1 - Source [4]</p>
<p>Dynamic programming can refer to a range of different solutions, and we can broadly categorize them into two types: <strong>top-down</strong> and <strong>bottom-up</strong> solutions [2]. A top-down solution is essentially the same as <a href="https://sassafras13.github.io/recursion/">recursion</a> - we break the problem into sub-problems recursively and then solve those sub-problems. Conversely, the bottom-up approach solves the smallest sub-problem first and builds those solutions up into a complete solution [2]. Typically, recursion is not as time- or memory-efficient as bottom-up approaches [3]. We will explore these two approaches in more detail in the next sections.</p>
<h2 id="top-down-dynamic-programming">Top-Down Dynamic Programming</h2>
<p>The top-down, recursive approach to solving a problem like the Fibonacci problem is shown in Figure 1. Here, we continue to call the <code class="language-plaintext highlighter-rouge">Fib( )</code> function until we reach the base case, then we use the base case solutions to solve the other intermediate function calls. We can see this in the code below [4]:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">callFib</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">n</span> <span class="o"><</span> <span class="mi">2</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">n</span>
    <span class="k">return</span> <span class="n">callFib</span><span class="p">(</span><span class="n">n</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">callFib</span><span class="p">(</span><span class="n">n</span><span class="o">-</span><span class="mi">2</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
    <span class="n">callFib</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
</code></pre></div></div>
<p>As we said above, vanilla top-down dynamic programming is not very efficient. If you look at the recursion tree in Figure 1, we make the same call to <code class="language-plaintext highlighter-rouge">Fib(1)</code> three times, for example. To save time, we can use a technique called <strong>memoization</strong> *2 to store the value of <code class="language-plaintext highlighter-rouge">Fib(1)</code> so that we don’t waste time computing it multiple times [4].</p>
<h2 id="bottom-up-dynamic-programming">Bottom-Up Dynamic Programming</h2>
<p>In bottom-up dynamic programming, we identify the smallest sub-problem and use it as a starting point to solve progressively larger sub-problems until we’ve arrived at the solution that we needed. In the case of solving for the n-th entry in the Fibonacci sequence, we start by solving <code class="language-plaintext highlighter-rouge">Fib(0)</code> and work up to solving <code class="language-plaintext highlighter-rouge">Fib(n)</code>. We store every entry in a table along the way, so we often say that we are using <strong>tabulation</strong> to store the solutions to the sub-problem [4].</p>
<p>Notice that having the table then makes solving for other entries in the sequence much easier, because we can now just look up the solutions in the table rather than re-computing the solution as we would do with recursion [4].</p>
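<p>A bottom-up Fibonacci solver using tabulation might look like the following sketch (the function name is my own choice):</p>
<pre><code class="language-python3">def tabulateFib(n):
    # fill the table from the smallest sub-problems upwards
    if n < 2:
        return n
    table = [0] * (n + 1)
    table[1] = 1
    for i in range(2, n + 1):
        table[i] = table[i - 1] + table[i - 2]
    return table[n]

tabulateFib(5)  # 5
</code></pre>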
<p>In general, memoization is easier to code than tabulation, because memoization essentially uses a wrapper function to save intermediate values, as follows [4,5]:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">memoizeCallFib</span><span class="p">(</span><span class="n">n</span><span class="p">):</span>
    <span class="n">memo</span> <span class="o">=</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="o">+</span><span class="mi">1</span><span class="p">)]</span>
    <span class="k">return</span> <span class="n">callFib</span><span class="p">(</span><span class="n">memo</span><span class="p">,</span> <span class="n">n</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">callFib</span><span class="p">(</span><span class="n">memo</span><span class="p">,</span> <span class="n">n</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">n</span> <span class="o"><</span> <span class="mi">2</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">n</span>
    <span class="k">if</span> <span class="n">memo</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">>=</span> <span class="mi">0</span><span class="p">:</span>
        <span class="k">return</span> <span class="n">memo</span><span class="p">[</span><span class="n">n</span><span class="p">]</span>
    <span class="n">memo</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">callFib</span><span class="p">(</span><span class="n">memo</span><span class="p">,</span> <span class="n">n</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">callFib</span><span class="p">(</span><span class="n">memo</span><span class="p">,</span> <span class="n">n</span><span class="o">-</span><span class="mi">2</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">memo</span><span class="p">[</span><span class="n">n</span><span class="p">]</span>

<span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
    <span class="n">memoizeCallFib</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
</code></pre></div></div>
<p>However, note that the memoization approach saves a new entry in the memo on every distinct call. If we are using a top-down approach to work through many, many layers of sub-problems, the recursive calls can become a significant slow-down, and the memo (along with the call stack) grows in proportion to the number of sub-problems we solve. At that point, using a tabulation approach may be better because the table saves each sub-problem solution once, efficiently. In some cases it will be more complicated to code these table-based methods, because we have to define the table structure and the order in which we fill in its entries, but the hard work may pay off in the long run [5].</p>
<h2 id="footnotes">Footnotes</h2>
<p>*1 According to David Silver, the term dynamic programming comes from the following: the word “dynamic” implies that the problem we are solving has some sequential or temporal nature to it, and the term “programming” here is meant in the mathematical sense. That is, we are writing a “program” or a “policy” for solving a problem, not a computer program [1].</p>
<p>*2 The value of <code class="language-plaintext highlighter-rouge">Fib(1)</code> is the “memo” that we store in “memoization”.</p>
<h2 id="references">References</h2>
<p>[1] Silver, D. “RL Course by David Silver - Lecture 3: Planning by Dynamic Programming.” YouTube. 13 May 2015. <a href="https://www.youtube.com/watch?v=Nd1-UUMVfz4&list=PLqYmG7hTraZDM-OYHWgPebj2MfCFzFObQ&index=3">https://www.youtube.com/watch?v=Nd1-UUMVfz4&list=PLqYmG7hTraZDM-OYHWgPebj2MfCFzFObQ&index=3</a> Visited 08 Aug 2020.</p>
<p>[2] “Dynamic programming.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Dynamic_programming">https://en.wikipedia.org/wiki/Dynamic_programming</a> Visited 13 July 2022.</p>
<p>[3] “Improving efficiency of recursive functions.” Khan Academy. <a href="https://www.khanacademy.org/computing/computer-science/algorithms/recursive-algorithms/a/improving-efficiency-of-recursive-functions">https://www.khanacademy.org/computing/computer-science/algorithms/recursive-algorithms/a/improving-efficiency-of-recursive-functions</a> Visited 11 Jan 2022.</p>
<p>[4] “Grokking Dynamic Programming Patterns for Coding Interviews.” Educative.io. <a href="https://www.educative.io/courses/grokking-dynamic-programming-patterns-for-coding-interviews/m2G1pAq0OO0">https://www.educative.io/courses/grokking-dynamic-programming-patterns-for-coding-interviews/m2G1pAq0OO0</a> Visited 3 Aug 2022.</p>
<p>[5] Bee. “What is Dynamic Programming with Python Examples.” Skerritt.blog. 31 Dec 2019. <a href="https://skerritt.blog/dynamic-programming/">https://skerritt.blog/dynamic-programming/</a> Visited 3 Aug 2022.</p>I’m getting back into the groove of studying for a coding interview, and I recently came across a question that requires knowledge of dynamic programming to solve it. I’ve touched on dynamic programming before, both in the context of optimal control and when discussing message passing algorithms over graphs. This post will review that material and also give a more computer science-focused introduction to the topic.Message Passing Algorithms2022-07-14T00:00:00+00:002022-07-14T00:00:00+00:00http://sassafras13.github.io/MessagePassingAlgos<p>I have been learning about how to perform inference - i.e. how to make predictions - using probabilistic graphical models. Recently, I have written about <a href="https://sassafras13.github.io/InferencePGMs/">inference in general</a>, as well as <a href="https://sassafras13.github.io/Sampling/">sampling methods for approximate inference</a>. In this post, I want to explore another set of algorithms for performing exact inference using <strong>message passing</strong>.</p>
<h2 id="variable-elimination-revisited">Variable Elimination Revisited</h2>
<p>Previously, we saw that we can use <a href="https://sassafras13.github.io/GNN/">variable elimination to compute marginal probabilities</a>. The downside to this approach is that if we changed the marginal probability of interest (i.e. instead of computing \(P(X_1)\) we wanted \(P(X_2)\)) then we would have to completely restart the algorithm and waste a lot of our computational resources. Specifically, we are wasting the fact that we computed a lot of intermediate probabilities when calculating \(P(X_1)\) that are still useful when we need to find \(P(X_2)\). We can call these <strong>intermediate factors</strong>, \(\tau\). If we can cache these values and develop an algorithm that can reuse them, we will come up with a more efficient approach to performing inference than classic variable elimination. We will call this new approach the <strong>junction tree algorithm</strong>, for reasons that will become obvious in a moment [1].</p>
<hr />
<h3 id="side-note-on-dynamic-programming">Side Note on Dynamic Programming</h3>
<p>Variable elimination - the process of marginalizing out variables in a specific order to eliminate all but the desired variable - is a form of <strong>dynamic programming</strong>. I’ve written about dynamic programming before in the context of <a href="https://sassafras13.github.io/Silver3/">reinforcement learning</a> but let me briefly define it again here. Dynamic programming, particularly in the context of computer science, is an approach that breaks problems up into overlapping subproblems, where it is possible to find the optimal solution to each of these subproblems [2, 3]. Within dynamic programming, we can solve a problem either using a <strong>top-down approach</strong> or a <strong>bottom-up approach</strong> [2].</p>
<p>The top-down approach is essentially the result of using recursion to solve a problem - we break the problem into subproblems (recursively) and then solve each subproblem. If the solution to the subproblem already exists (i.e. we cached it) then we just pull it directly from memory, saving computation time. Conversely, the bottom-up approach solves the smallest subproblems first and builds up to solutions for progressively bigger subproblems [2].</p>
<p>I mention this because we can relate both variable elimination and the junction tree algorithm to top-down vs. bottom-up dynamic programming approaches. Variable elimination computes the marginal probability of just one variable in a top-down fashion, which is why it is computationally wasteful if we want to be able to compute the marginal probabilities of many different random variables in our model. On the other hand, the junction tree algorithm is more of a bottom-up method because we will calculate and store all of the answers to the subproblems on the way to computing the marginal probability for a particular random variable [1].</p>
<hr />
<p>Okay, now that we’ve covered that, let’s return to this new junction tree algorithm. The junction tree (JT) algorithm will make 2 passes through the graphical model using the variable elimination approach and store all of the intermediate factors, \(\tau\), in a table. We can use this table to obtain any marginal probability in constant time, that is in \(\mathcal{O}(1)\) [1]. Within the class of JT algorithms there are two forms: a belief propagation (BP) method and the full JT method [1].</p>
<h2 id="belief-propagation">Belief Propagation</h2>
<p>Let’s start by recapping how variable elimination is equivalent to message passing in a simple example. Let’s imagine we have a small tree graph*1 and we want to compute the marginal probability \(p(x_i)\) [1]. To do this, we can set the root of the tree to be the node \(x_i\) and process the nodes in order from the leaves of the tree to the root. Since we are working with a tree structure (where there is only one path from node to node), the maximum clique size that can be formed at any point during this process is of size 2. So for example, if we are eliminating the intermediate node \(x_j\) we write [1]:</p>
\[\tau_k(x_k) = \sum_{x_j} \phi(x_k, x_j) \tau_j(x_j)\]
<p>Where \(x_k\) is the parent of \(x_j\) [1]. These factors \(\tau\) are <strong>messages</strong> that are sent from child to parent (i.e. from \(x_j\) to \(x_k\)) which contain all relevant information about the subtree that has \(x_j\) as the root [1]. A visualization of this process is shown in Figure 1.</p>
<p><img src="/images/2022-07-14-MessagePassingAlgos-fig1.png" alt="Fig 1" title="Figure 1" /> <br />
Figure 1 - Source [1]</p>
<p>Once the root has received all the messages from its own children, then we marginalize them out to get the final marginal probability for \(x_i\). If we now wanted to repeat this process for a different node, i.e. to compute \(p(x_k)\), we would rotate the tree so that \(x_k\) was the root, and then run the algorithm again. Notice that the messages that \(x_j\) sent to \(x_k\) are the same as before, so if we had stored that value from the first time we ran the VE algorithm, we would be able to reuse it now. Notice that each node sends a message only when it has finished receiving messages from all of its children [1].</p>
<p>Now that we have the basic idea, we can formally define the belief propagation algorithm for completing two different tasks [1]:</p>
<ul>
<li><strong>Sum-product message passing</strong>: for marginal inference, \(p(x_i)\)</li>
<li><strong>Max-product message passing</strong>: for MAP inference, \(\max_{x_1, …, x_n} p(x_1, …, x_n)\)</li>
</ul>
<h3 id="sum-product-message-passing">Sum-Product Message Passing</h3>
<p>For sum-product message passing, we can compute the marginal probability for node \(x_i\) as [1]:</p>
\[p(x_i) \propto \phi(x_i) \prod_{l \in N(i)} m_{l \rightarrow i} (x_i)\]
<p>Where \(m_{l \rightarrow i} (x_i)\) is the message from node \(l\) to node \(i\); a message from node \(i\) to node \(j\) can be written as [1]:</p>
\[m_{i \rightarrow j}(x_j) = \sum_{x_i} \phi(x_i) \phi(x_i, x_j) \prod_{l \in N(i) \backslash j} m_{l \rightarrow i}(x_i)\]
<p>The product in the expression above is computed for all nodes that are neighbors of \(i\) except for node \(j\). This is exactly the same as the messages that we would pass in our conceptual discussion earlier.</p>
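<p>The messages above can be implemented quite directly. Below is a toy sum-product implementation for a three-node chain of binary variables; the graph, potentials, and helper names are all made-up illustrations rather than anything from [1]:</p>

```python
# Toy chain 0 - 1 - 2 of binary variables with made-up unary and pairwise
# potentials (these numbers are illustrative assumptions, not from the post).
unary = {0: [1.0, 2.0], 1: [1.0, 1.0], 2: [3.0, 1.0]}
pairwise = {(0, 1): [[2.0, 1.0], [1.0, 2.0]],
            (1, 2): [[1.0, 3.0], [3.0, 1.0]]}
neighbors = {0: [1], 1: [0, 2], 2: [1]}

def phi_pair(i, j, xi, xj):
    # Look up the pairwise potential regardless of edge orientation.
    if (i, j) in pairwise:
        return pairwise[(i, j)][xi][xj]
    return pairwise[(j, i)][xj][xi]

def message(i, j, memo):
    # m_{i->j}(x_j) = sum_{x_i} phi(x_i) phi(x_i, x_j)
    #                 prod_{l in N(i) \ j} m_{l->i}(x_i)
    if (i, j) not in memo:
        m = [0.0, 0.0]
        for xj in (0, 1):
            for xi in (0, 1):
                incoming = 1.0
                for l in neighbors[i]:
                    if l != j:
                        incoming *= message(l, i, memo)[xi]
                m[xj] += unary[i][xi] * phi_pair(i, j, xi, xj) * incoming
        memo[(i, j)] = m
    return memo[(i, j)]

def marginal(i, memo=None):
    # p(x_i) is proportional to phi(x_i) times the product of incoming
    # messages; normalize at the end.
    memo = {} if memo is None else memo
    unnorm = []
    for xi in (0, 1):
        v = unary[i][xi]
        for l in neighbors[i]:
            v *= message(l, i, memo)[xi]
        unnorm.append(v)
    z = sum(unnorm)
    return [v / z for v in unnorm]
```

<p>Because the messages are cached in <code class="language-plaintext highlighter-rouge">memo</code>, asking for a second marginal with the same cache reuses the work already done for the first - which is exactly the saving that motivates the junction tree algorithm.</p>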
<h3 id="max-product-message-passing">Max-Product Message Passing</h3>
<p>The logic that applied to sum-product message passing still applies here because we can distribute max operators over products just as we can distribute sums over products. To see the analogy, first consider computing the partition function of a chain Markov Random Field (MRF) model, where we push the sums inside the product [1]:</p>
\[\mathcal{Z} = \sum_{x_1} … \sum_{x_n} \phi(x_1) \prod_{i=2}^n \phi(x_i, x_{i-1})\]
\[\mathcal{Z} = \sum_{x_n} \sum_{x_{n-1}} \phi( x_n, x_{n-1}) \sum_{x_{n-2}} \phi ( x_{n-1}, x_{n-2}) … \sum_{x_1} \phi(x_2, x_1) \phi(x_1)\]
<p>And if you want the maximum value of the joint, unnormalized probability distribution (called \(\tilde{p}^*\)) then we can write:</p>
\[\tilde{p}^* = \max_{x_1} … \max_{x_n} \phi(x_1) \prod_{i=2}^n \phi(x_i, x_{i-1})\]
\[\tilde{p}^* = \max_{x_n} \max_{x_{n-1}} \phi( x_n, x_{n-1}) \max_{x_{n-2}} \phi(x_{n-1}, x_{n-2}) … \max_{x_1} \phi(x_2, x_1) \phi(x_1)\]
<p>This math works for factor trees as well as chain graphs. Notice that if you wanted to know the argmax operator of the probability distribution (i.e. the most likely assignments to the variables \(x\)), then you could keep track of the best assignments of each variable \(x_i\) during this algorithm and refer back to them at the end [1].</p>
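<p>Here is a hedged sketch of max-product on a small chain, with back-pointers to recover the argmax assignment as described above (the potentials are toy numbers of my own choosing):</p>

```python
# Chain of three binary variables: phi(x_1) plus pairwise potentials
# phi(x_2, x_1) and phi(x_3, x_2), indexed as pair[k][x_i][x_{i-1}].
unary1 = [1.0, 2.0]
pair = [[[2.0, 1.0], [1.0, 3.0]],
        [[1.0, 2.0], [4.0, 1.0]]]

def max_product():
    # Forward pass: m[x_i] is the best achievable product over x_1..x_i
    # ending in state x_i; back[k] remembers the argmax choices.
    m = list(unary1)
    back = []
    for P in pair:
        new_m, choices = [], []
        for xi in (0, 1):
            scores = [P[xi][xp] * m[xp] for xp in (0, 1)]
            best = 0 if scores[0] >= scores[1] else 1
            new_m.append(scores[best])
            choices.append(best)
        m = new_m
        back.append(choices)
    # Backward pass: follow the stored choices to recover the assignment.
    x = 0 if m[0] >= m[1] else 1
    best_value = m[x]
    assignment = [x]
    for choices in reversed(back):
        x = choices[x]
        assignment.append(x)
    assignment.reverse()
    return best_value, assignment
```

<p>The forward pass computes \(\tilde{p}^*\) by pushing the max operators inward; the stored back-pointers are the bookkeeping needed to read off the most likely assignment at the end.</p>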
<h2 id="junction-tree-algorithm">Junction Tree Algorithm</h2>
<p>So far we have seen how the belief propagation algorithm works on tree graphs. However, if we do not have a tree graph structure, we need to make one! Having a tree-like structure makes the message passing algorithm feasible - without it, the computations could become intractable. The full JT algorithm that we are going to present here will take an arbitrary graph and cluster the variables together, so that the <em>clustered</em> version of the graph has a tree structure [1]. Then we can apply the same belief propagation method as before, assuming that the local clusters can be solved exactly [1].</p>
<h3 id="an-example">An Example</h3>
<p>Let’s consider a simple example where we want to perform marginal inference over a small MRF graph which can be written as [1]:</p>
\[p(x_1, …, x_n) = \frac{1}{Z} \prod_{c \in \mathcal{C}} \phi_c (x_c)\]
<p>The cliques in the MRF must satisfy the <strong>running intersection property</strong> (RIP). Specifically, the cliques must have some “path structure” where there is a sequential ordering to them. If we have cliques \(x_c^{(1)}, …, x_c^{(n)}\), then if a variable is in both the \(j\)-th and \(k\)-th cliques, i.e. \(x_i \in x_c^{(j)}\) and \(x_i \in x_c^{(k)}\) then that variable must also be in any other cliques along the path, such as the \(l\)-th clique, i.e. \(x_i \in x_c^{(l)}\) if \(x_c^{(l)}\) is along the path between \(x_c^{(j)}\) and \(x_c^{(k)}\) [1].</p>
<p><img src="/images/2022-07-14-MessagePassingAlgos-fig2.png" alt="Fig 2" title="Figure 2" /> <br />
Figure 2 - Source [1]. Notice that the round nodes are cliques, containing the variables within their scope. The rectangular nodes indicate the <strong>sepsets</strong>, the groups of variables at the intersection of the scopes of two neighboring cliques [1].</p>
<p>If our MRF graph looks like the one shown in Figure 2, then we can start to compute the marginal probability of \(x_1\) as follows [1]:</p>
\[\tilde{p}(x_1) = \phi(x_1) \sum_{x_2} \phi(x_1, x_2) \sum_{x_3} \phi(x_1, x_2, x_3) \sum_{x_5} \phi(x_2, x_3, x_5) \sum_{x_6} \phi( x_2, x_5, x_6)\]
<p>Notice that we were allowed to push the sums into the product because the RIP assumption tells us that \(x_6\), for example, can only exist in that final cluster and not any of the earlier ones. Each intermediate factor \(\tau\) marginalizes out the variables that are not included in the scope of the next cluster. The marginalization, \(\tau(x_2, x_3, x_5) = \phi(x_2, x_3, x_5) \sum_{x_6} \phi(x_2, x_5, x_6)\), is a message being shared from one cluster to another [1].</p>
<p>The full JT algorithm simply adds a step to convert any graph into a tree of clusters before performing the same message passing approach that we saw earlier. Note that this “tree of clusters” is called a <strong>junction tree</strong>, hence the name of the algorithm. Let’s get some mathematical clarity on what a junction tree is. First, if we had some undirected graphical model*2 \(G = (X, E_G)\), then the junction tree \(T = (C, E_T)\) is a tree over \(G\). The nodes of \(T\), \(c \in C\), are associated with subsets \(x_c\) of the nodes in the graph \(G\), that is \(x_c \subseteq X\) [1].</p>
<p>The junction tree has two specific properties [1, 5]:</p>
<ul>
<li>
<p><strong>Family preservation</strong>: For each factor in the tree, \(\phi\), there is a cluster \(c\) such that the scope of the cluster is contained within the subset \(x_c\), that is \(\text{Scope}[\phi] \subseteq x_c\).</p>
</li>
<li>
<p><strong>Running intersection</strong>: As described above, for every pair of clusters \(c^{(i)}\) and \(c^{(j)}\), the clusters on the path between them must contain \(x_c^{(i)} \cap x_c^{(j)}\).</p>
</li>
</ul>
<p>The optimal trees have small, modular clusters, but it is NP-hard to find the globally optimal tree, unless the original graph itself, \(G\), is already a tree [1].</p>
<h3 id="the-full-jt-algorithm">The Full JT Algorithm</h3>
<p>The full JT algorithm assumes that we begin with a junction tree. We have clusters in the tree, and each cluster has a potential, \(\phi_c(x_c)\). The potentials are a product of all the factors in the graph \(G\) that have been assigned to cluster \(c\). This gives us a probability distribution over the entire graph that is the normalized product of all the clusters (this is the same as our usual UGM form) [1].</p>
<p>For each step in the algorithm, we choose two adjacent clusters, \(c^{(i)}\) and \(c^{(j)}\), in the tree and calculate a message to pass between them. Note that the scope of the message is the sepset \(S_{ij}\) between the two clusters (i.e. the scope is the set of nodes that they have in common) [1]:</p>
\[m_{i \rightarrow j}(S_{ij}) = \sum_{x_c \backslash S_{ij}} \phi_c (x_c) \prod_{l \in N(i) \backslash j} m_{l \rightarrow i}(S_{li})\]
<p>Notice that \(c^{(i)}\) cannot send this message until it has received all the messages from its neighbors except cluster \(c^{(j)}\). After this algorithm finishes computing all the messages, we can define the <strong>belief</strong> of each cluster based on all the messages that it received [1]:</p>
\[\beta_c (x_c) = \phi_c (x_c) \prod_{l \in N(i)} m_{l \rightarrow i}(S_{li})\]
<p>This belief is proportional to the marginal probability over the scope of this particular cluster. So if we wanted the unnormalized probability \(\tilde{p}(x)\) for some variable \(x \in x_c\), we could marginalize out the other variables thus [1]:</p>
\[\tilde{p}(x) = \sum_{x_c \backslash x} \beta_c (x_c)\]
<p>To normalize this probability, we can compute the partition function \(\mathcal{Z}\) as the sum over all the beliefs in a cluster [1]:</p>
\[\mathcal{Z} = \sum_{x_c} \beta_c (x_c)\]
<p>This algorithm runs best with small clusters, because the running time scales exponentially in the size of the largest cluster.</p>
<h2 id="shafer-shenoy-and-hugin-algorithms">Shafer-Shenoy and HUGIN Algorithms</h2>
<p>Now that we have seen the full junction tree algorithm, I wanted to note that there are also two variations on the implementation of it. Specifically, there is the <strong>Shafer-Shenoy algorithm</strong>, which is exactly the full JT algorithm described above, and the <strong>HUGIN algorithm</strong>, which is a variant of what we’ve seen. The HUGIN algorithm recognizes that whenever messages are sent or beliefs are calculated, the same messages are multiplied together several times. Therefore, the HUGIN algorithm will cache the intermediate products as it runs to save more computation time. Since the HUGIN algorithm is caching multiplication products, it will also have to divide back out some messages at certain points [6].</p>
<h2 id="footnotes">Footnotes</h2>
<p>*1 A tree structure specifically means that any pair of nodes in the graph are connected by exactly one <strong>path</strong>. This means that a tree structure is acyclic. A tree is also an <strong>undirected</strong> graph [4].</p>
<p>*2 If you were starting with a directed graphical model you could moralize it first to obtain an undirected model, that’s fine [1]. Moralizing means that you follow a set of rules to convert the directed edges to undirected edges. It follows this kind of outdated metaphor of marrying parents that point to the same child node by adding edges between them, hence the weird name.</p>
<h2 id="references">References</h2>
<p>[1] Kuleshov, V. and Ermon, S. “Junction Tree Algorithm.” CS228 Probabilistic Graphical Models, 2022. Stanford University. <a href="https://ermongroup.github.io/cs228-notes/inference/jt/">https://ermongroup.github.io/cs228-notes/inference/jt/</a> Visited 14 July 2022.</p>
<p>[2] “Dynamic programming.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Dynamic_programming">https://en.wikipedia.org/wiki/Dynamic_programming</a> Visited 13 July 2022.</p>
<p>[3] “What is Dynamic Programming?” Educative.io. <a href="https://www.educative.io/courses/grokking-dynamic-programming-patterns-for-coding-interviews/m2G1pAq0OO0">https://www.educative.io/courses/grokking-dynamic-programming-patterns-for-coding-interviews/m2G1pAq0OO0</a> Visited 13 July 2022.</p>
<p>[4] “Tree (graph theory).” Wikipedia. <a href="https://en.wikipedia.org/wiki/Tree_(graph_theory)">https://en.wikipedia.org/wiki/Tree_(graph_theory)</a> Visited 13 July 2022.</p>
<p>[5] Urtasun, R., Hazan, T. “Probabilistic Graphical Models.” TTI Chicago, 25 Apr 2011. <a href="https://www.cs.toronto.edu/~urtasun/courses/GraphicalModels/lecture11.pdf">https://www.cs.toronto.edu/~urtasun/courses/GraphicalModels/lecture11.pdf</a> Visited 13 July 2022.</p>
<p>[6] Paskin, M. “A Short Course on Graphical Models: 3. The Junction Tree Algorithms.” <a href="https://web.archive.org/web/20150319085443/https://ai.stanford.edu/~paskin/gm-short-course/lec3.pdf">https://web.archive.org/web/20150319085443/https://ai.stanford.edu/~paskin/gm-short-course/lec3.pdf</a> Visited 14 July 2022.</p>I have been learning about how to perform inference - i.e. how to make predictions - using probabilistic graphical models. Recently, I have written about inference in general, as well as sampling methods for approximate inference. In this post, I want to explore another set of algorithms for performing exact inference using message passing.Sampling Methods for Approximate Inference2022-07-01T00:00:00+00:002022-07-01T00:00:00+00:00http://sassafras13.github.io/Sampling<p>In this post we are going to talk about <strong>sampling methods</strong> for use in performing inference over probabilistic graphical models. <a href="https://sassafras13.github.io/InferencePGMs/">Last time</a>, we talked about how sampling methods produce approximate results for inference and are best used on really complex models when it is intractable to find an exact solution in polynomial time [1]. Sampling methods essentially do what they say on the tin: they indirectly draw samples from a distribution and this assortment of samples can be used to solve many different inference problems. For example, marginal probabilities, MAP quantities, and expectations of random variables drawn from a given model can all be computed using sampling. Variational inference is often superior in accuracy, but sampling methods have been around longer and can often succeed where variational methods fail [1]. Let’s dive in to learn more about sampling.</p>
<h2 id="forward-sampling">Forward Sampling</h2>
<p>Probably the simplest sampling approach is to draw samples from a multinomial distribution, which has \(k\) possible outcomes and each one has a probability \(\theta_k\). We then draw samples from a uniform distribution [0,1] and depending on the value, we assign the outcome to the event \(k\) with the corresponding probability. Specifically, we divide the interval between 0 and 1 into \(k\) sections where each section has size \(\theta_k\), as shown in Figure 1 [1].</p>
<p><img src="/images/2022-07-01-Sampling-fig1.png" alt="Fig 1" title="Figure 1" /> <br />
Figure 1 - Source [1]</p>
<p>We can extend this idea to a simple Bayes net containing multinomial variables, where we use <strong>forward sampling</strong> to compute the joint or marginal distributions by following the topology of the graph. Specifically, we start by drawing samples from the root nodes of the graph (those with no parents) and then we draw samples from their children by conditioning the children’s conditional probability distributions on the samples drawn from their parents. We continue this process until we have sampled from all the nodes in the graph, and this can be done in polynomial time [1]. There is a variation on this approach for undirected models as well provided they can be represented as clique trees [1].</p>
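<p>As a minimal sketch of both ideas - dividing [0, 1] into \(k\) intervals to sample a multinomial, and forward sampling through a tiny two-node Bayes net - here is some illustrative Python (the CPT numbers and function names are my own toy choices):</p>

```python
import random

def sample_multinomial(theta, u=None):
    # Divide [0, 1) into k intervals of size theta_k and return the index
    # of the interval that the uniform draw u falls into.
    if u is None:
        u = random.random()
    cumulative = 0.0
    for k, p in enumerate(theta):
        cumulative += p
        if u < cumulative:
            return k
    return len(theta) - 1

def forward_sample():
    # Tiny Bayes net X -> Y with made-up CPTs: sample the root first,
    # then condition the child's distribution on the parent's sample.
    p_x = [0.6, 0.4]
    p_y_given_x = {0: [0.9, 0.1], 1: [0.2, 0.8]}
    x = sample_multinomial(p_x)
    y = sample_multinomial(p_y_given_x[x])
    return x, y
```

<p>Following the topology of a larger graph works the same way: every node is sampled only after all of its parents have been sampled.</p>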
<h2 id="monte-carlo-estimation">Monte Carlo Estimation</h2>
<p>Frequently with PGMs, we will want to compute expectations such as:</p>
\[\mathbb{E}_{x \sim p}[ f(x) ] = \sum_x f(x) p(x)\]
<p>If the function \(f(x)\) doesn’t have a structure that matches the structure of \(p(x)\), we will need to use sampling to approximate the integral because we cannot compute it analytically*1. We can compute an approximation by drawing many samples from \(p(x)\) - in general, this approach is referred to as a Monte Carlo method [1]. (Monte Carlo methods are actually a broad class of strategies - you can use them to do everything from compute the area of a circle to performing inference!) The approximation of the expectation above using Monte Carlo can be written as [1]:</p>
\[\mathbb{E}_{x \sim p}[f(x)] \approx I_T = \frac{1}{T} \sum_{t=1}^T f(x^t)\]
<p>Where T is the total number of samples drawn from \(p(x)\). The estimator is unbiased, i.e. \(\mathbb{E}_{x^1, …, x^T \sim p}[I_T] = \mathbb{E}_{x \sim p}[f(x)]\), and if we take enough samples its variance shrinks toward zero, so the approximation becomes arbitrarily accurate [1].</p>
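<p>For instance, a Monte Carlo estimate of \(\mathbb{E}[x^2]\) under a small discrete distribution might look like this (the distribution, seed, and sample count are arbitrary choices for illustration):</p>

```python
import random

def monte_carlo_expectation(f, values, probs, T, seed=0):
    # Approximate E_{x~p}[f(x)] with (1/T) * sum_t f(x^t) over T samples.
    rng = random.Random(seed)
    samples = rng.choices(values, weights=probs, k=T)
    return sum(f(x) for x in samples) / T

# Exact answer for comparison: E[x^2] = 0*0.2 + 1*0.5 + 4*0.3 = 1.7
estimate = monte_carlo_expectation(lambda x: x * x, [0, 1, 2], [0.2, 0.5, 0.3], 100000)
```

<p>With 100,000 samples the estimate should land within a few hundredths of the exact value 1.7, and the error shrinks as \(1/\sqrt{T}\).</p>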
<h2 id="markov-chain-monte-carlo">Markov Chain Monte Carlo</h2>
<p>Building on the Monte Carlo idea, we can perform marginal and MAP inference using a sampling approach [1]. This approach is called <strong>Markov chain Monte Carlo</strong> (MCMC). A Markov chain is a time series of random variables \(X_0, X_1, X_2,...\) where each random variable can take \(d\) possible values. Together the \(X_i \in \{1, 2, ... , d\}\) represent the state of a system changing over time. The probability distribution of the first random variable is \(P(X_0)\), and the probability of every state after that time is dependent only on the previous state, that is [1]:</p>
\[P(X_i | X_{i-1})\]
<p>This is very important - each state in a Markov chain only depends on what happened in the immediate prior time step, and the rest of the system’s history does not matter; this is the <strong>Markov assumption</strong> [1]. We will also assume that the <strong>transition probability</strong> distribution is the same at every time step. We can represent this distribution as a matrix [1]:</p>
\[T_{ij} = P(X_{new} = i | X_{prev} = j)\]
<p>So the probability of arriving at each state after \(t\) time steps is [1]:</p>
\[p_t = T^t p_0\]
<p>Where \(p_0\) was the initial probability distribution at the first time step. Over many time steps, the Markov chain will often arrive at a <strong>stationary distribution</strong> defined as \(\pi = \lim_{t \rightarrow \infty} p_t\) (assuming such a limit exists) [1].</p>
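<p>A quick sketch of \(p_t = T^t p_0\) for a two-state chain (the transition probabilities here are made up; note that each <em>column</em> of \(T\) sums to 1 under the convention \(T_{ij} = P(X_{new} = i | X_{prev} = j)\)):</p>

```python
# Two-state chain; the transition probabilities are illustrative assumptions.
T = [[0.9, 0.5],
     [0.1, 0.5]]

def step(T, p):
    # One application of p_t = T p_{t-1} for a 2x2 matrix.
    return [T[0][0] * p[0] + T[0][1] * p[1],
            T[1][0] * p[0] + T[1][1] * p[1]]

def stationary(T, p0, steps=1000):
    # Iterate the chain until the distribution stops changing.
    p = p0
    for _ in range(steps):
        p = step(T, p)
    return p
```

<p>For this particular \(T\), iterating converges to \(\pi = [5/6, 1/6]\) regardless of the starting distribution, since \(\pi\) is the fixed point of \(\pi = T\pi\).</p>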
<p>The reason that the MCMC algorithm is useful for sampling is that if we run it over many time steps, the transition probability will approach a stationary distribution that is equal to the probability distribution from which we want to sample. At a high level, here’s how the MCMC algorithm works [1]:</p>
<ol>
<li>
<p>Inputs are a transition operator T and an initial state, \(X_0\). The transition operator describes a Markov chain with stationary distribution \(p\). The initial state \(X_0\) is a first set of values for the variables in \(p\).</p>
</li>
<li>
<p>Run the Markov chain from \(X_0\) for \(B\) time steps. This is called the <strong>burn-in</strong> time.</p>
</li>
<li>
<p>Continue to run the Markov chain for \(N\) sampling steps and collect all of the states that it visits - this is the stationary distribution which we are sampling from.</p>
</li>
</ol>
<p>One useful thing that we can do using MCMC is to take the sample with the highest probability (i.e. the one that appears most often) and use it to estimate the mode of the probability distribution. This is essentially the MAP estimate [1].</p>
<h2 id="metropolis-hastings-algorithm">Metropolis-Hastings Algorithm</h2>
<p>The Metropolis-Hastings (MH) algorithm is one way to implement MCMC. The MH approach builds a transition operator, defined as [1]:</p>
\[T(x’ | x) = Q(x’ | x) A(x’ | x)\]
<p>The transition operator is obtained using two pieces [1]:</p>
<ol>
<li>A transition kernel chosen by the user (often we use something simple like a Gaussian):</li>
</ol>
\[Q(x’ | x)\]
<ol>
<li>An acceptance probability for transitions chosen by \(Q\), and written as:</li>
</ol>
\[A(x’ | x) = \min \bigg( 1, \frac{P(x’)Q(x | x’)}{P(x) Q(x’ | x)} \bigg)\]
<p>MH starts with these components and builds a Markov chain by choosing a new point \(x’\) according to \(Q\). Then we accept the new state with some probability \(\alpha\) given by \(A\), or keep the current state with probability \(1 - \alpha\). The probability distribution \(P\) is our target stationary distribution [1].</p>
<p>The mechanism of the acceptance probability pushes our sampling towards regions of high probability within the distribution. Consider the expression for \(A\) above (and imagine that \(Q\) is just a uniform distribution). The value of \(A\) will be 1 when \(\frac{P(x’)}{P(x)}\) is greater than 1, which happens when \(P(x’) > P(x)\) - that is, we always accept a new sample whose probability is greater than the old one. When the new sample has lower probability, \(P(x’) < P(x)\), we still accept it occasionally, but only a fraction of the time given by the ratio \(\frac{P(x’)}{P(x)}\) (adjusted by \(Q\) when the proposal is not symmetric) [1].</p>
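<p>Here is a hedged sketch of random-walk Metropolis-Hastings with a symmetric Gaussian proposal \(Q\) - in which case \(Q(x \mid x’) = Q(x’ \mid x)\) cancels out of \(A\) - targeting a standard normal distribution (all of the tuning constants are arbitrary choices of mine):</p>

```python
import math
import random

def metropolis_hastings(log_p, x0, n_samples, burn_in=1000, step=1.0, seed=0):
    # Random-walk MH: Q is a symmetric Gaussian, so the acceptance
    # probability reduces to min(1, P(x') / P(x)), computed in log space.
    rng = random.Random(seed)
    x = x0
    samples = []
    for t in range(burn_in + n_samples):
        x_new = x + rng.gauss(0.0, step)
        # Accept with probability min(1, P(x')/P(x)).
        if math.log(rng.random()) < log_p(x_new) - log_p(x):
            x = x_new
        if t >= burn_in:
            samples.append(x)
    return samples

# Target: standard normal, log P(x) = -x^2 / 2 up to an additive constant -
# note that MH never needs the normalizing constant of P.
samples = metropolis_hastings(lambda x: -x * x / 2.0, x0=0.0, n_samples=20000)
```

<p>The fact that only the <em>ratio</em> \(P(x’)/P(x)\) appears is what makes MH so useful for PGMs: the intractable partition function cancels out.</p>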
<p>Regardless of the choice of \(Q\), the MH algorithm always pushes the stationary distribution of the Markov chain towards \(P\). To see this, let us first introduce a concept known as the <strong>detailed balance condition</strong>. The detailed balance condition is a sufficient condition for a stationary distribution, and it states that [1]:</p>
\[\pi(x’) T(x | x’) = \pi(x) T(x’ | x) \text{ } \forall x\]
<p>We can use this to show that the stationary distribution \(\pi(x)\) is equal to \(P(x)\) as follows. When \(A(x’ \mid x) < 1\), then:</p>
\[A(x’ | x) = \frac{P(x’)Q(x | x’)}{P(x) Q(x’ | x)}\]
<p>Note that if \(A(x’ \mid x) < 1\), then the inverse expression is greater than 1:</p>
\[\frac{P(x)Q(x’ | x)}{P(x’)Q(x|x’)} > 1\]
<p>And so \(A(x \mid x’) = 1\) because that is the acceptance probability for this inverse expression. This allows us to incorporate this expression \(A(x \mid x’)\) below:</p>
\[P(x’)Q(x | x’)A(x | x’) = P(x)Q(x’ | x)A(x’ | x)\]
<p>And we can say that the transition probability is the product of \(Q\) and \(A\) thus:</p>
\[P(x’) T(x | x’) = P(x) T(x’ | x)\]
<p>This equation exactly matches the detailed balance condition and therefore proves that the MH algorithm sets \(\pi(x) = P(x)\) [1].</p>
<p>MCMC methods are not perfect. As we mentioned in our previous post, there is no way to know for sure when we reach a globally optimal solution. Now that we have seen the MH algorithm, we can say more specifically that it is difficult to know when the burn-in time has drawn enough samples to cause us to reach the stationary distribution. It may be that for complex distributions, we will be stuck sampling within regions of high probability and it will take a very long time to sample from other regions as well to get an accurate representation of the overall distribution [1].</p>
<h2 id="footnotes">Footnotes</h2>
<p>*1 This statement confused me a little bit, so I did a little more digging. Apparently, this equation is also known as the <strong>Law of the Unconscious Statistician</strong> [2]. The moniker apparently comes from the fact that many people use this equation while thinking that it is a simple fact, and not a law that was rigorously proved [2]. In other words, statisticians use it without being aware of the rigor of the statement, I guess? Anyway, the point is that when I looked through some examples of how to use and prove this rule as presented in [3], I think that you would not be able to solve the integral unless the functions \(f(x)\) and \(p(x)\) had a nice mathematical relationship that you could exploit in analytically solving the integral.</p>
<h2 id="references">References</h2>
<p>[1] Kuleshov, V. and Ermon, S. “Sampling methods.” CS228 Probabilistic Graphical Models, 2022. Stanford University. <a href="https://ermongroup.github.io/cs228-notes/inference/sampling/">https://ermongroup.github.io/cs228-notes/inference/sampling/</a> Visited 1 Jul 2022.</p>
<p>[2] “Law of the unconscious statistician.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Law_of_the_unconscious_statistician">https://en.wikipedia.org/wiki/Law_of_the_unconscious_statistician</a> Visited 1 Jul 2022.</p>
<p>[3] Rumbos, A. “Probability: Lecture Notes.” 23 Apr 2008. <a href="https://pages.pomona.edu/~ajr04747/Spring2008/Math151/Math151NotesSpring08.pdf">https://pages.pomona.edu/~ajr04747/Spring2008/Math151/Math151NotesSpring08.pdf</a> Visited 1 Jul 2022.</p>In this post we are going to talk about sampling methods for use in performing inference over probabilistic graphical models. Last time, we talked about how sampling methods produce approximate results for inference and are best used on really complex models when it is intractable to find an exact solution in polynomial time [1]. Sampling methods essentially do what they say on the tin: they indirectly draw samples from a distribution and this assortment of samples can be used to solve many different inference problems. For example, marginal probabilities, MAP quantities, and expectations of random variables drawn from a given model can all be computed using sampling. Variational inference is often superior in accuracy, but sampling methods have been around longer and can often succeed where variational methods fail [1]. Let’s dive in to learn more about sampling.Inference with Probabilistic Graphical Models2022-06-30T00:00:00+00:002022-06-30T00:00:00+00:00http://sassafras13.github.io/InferencePGMs<p>In a <a href="https://sassafras13.github.io/GNN/">previous blog post</a>, I wrote about message passing with probabilistic graphical models (PGMs), which is used to perform inference (i.e. calculating a probability). Today, I wanted to explore some other approaches to inference, with a particular focus on algorithms for <strong>approximate inference</strong>. In exact inference, you need to be able to calculate the probabilities for every possible case in your model, and this can quickly become intractable, which is why we use approximate inference methods that avoid this problem. 
In this post, I will again revisit some of the basic concepts around inference that motivate the need for approximate inference, then I will discuss some specific approximate inference algorithms.</p>
<h2 id="inference-with-pgms">Inference with PGMs</h2>
<p>Inference means that we use a known probabilistic graphical model to answer questions about the system that we have modeled [1]. In PGM theory, we can usually categorize inference questions into two types [1]:</p>
<ol>
<li><strong>Marginal inference:</strong> We compute the probability of one random variable for all values of other random variables in the model. We do this by summing over the other variables in the model except for the variable of interest:</li>
</ol>
\[p(x_1) = \sum_{x_2} … \sum_{x_n} p(x_1, x_2, …, x_n)\]
<ol>
<li><strong>Maximum a posteriori (MAP) inference:</strong> Here we want to know the most likely values of all of the random variables in the PGM. Instead of summing over the random variables, we compute the argmax:</li>
</ol>
\[\text{argmax}_{x_1, …, x_n} p(x_1, …, x_n, y = 1)\]
<p>In this example, \(y\) is some evidence that we have which helps to fix the values of some of the random variables [1]. This blog post is going to be centered around the MAP inference problem, but before we dive into this, it is worth clarifying that we can answer these inference questions with <strong>exact inference</strong> or <strong>approximate inference</strong>. Exact inference is really only possible for smaller, less complex models because we need to be able to calculate the grand partition function, which is a sum over all the possible states of the model. If there are too many states to enumerate explicitly, then we turn to approximate inference instead.</p>
<h2 id="map-vs-mle-revisited">MAP vs. MLE, Revisited</h2>
<p>I have a <a href="https://sassafras13.github.io/MLEvsMAP/">previous blog post</a> where I do a deep dive into how MAP and maximum likelihood estimation (MLE) compare, but I will quickly recap the two ideas here because I want to make clear how MAP uses Bayesian inference while MLE does not*1. Both MAP and MLE are methods for solving the same problem: what probability distribution (described by the parameters of the PGM, in our case) is most likely to give rise to the observed data? In other words, what is the most likely explanation for the data that we are seeing?</p>
<p>The primary difference between MLE and MAP is that MLE uses <em>only</em> the observed data to answer this question, while MAP allows us to incorporate prior knowledge of the system. We may prefer to use the MAP method when we have a small dataset where it would be difficult to find accurate parameters using the data alone.</p>
<p>Let’s look at this first by writing down the MLE in math. We use MLE to find the parameters for the probability distribution (the PGM) that maximize some likelihood function - in other words, these parameters for the probability distribution will maximize the likelihood that the distribution returns the observed data [2]. The maximum likelihood estimate is this set of parameters, which we can write as \(\theta = [\theta_1, …, \theta_k]\). The distribution parameterized by \(\theta\) is expressed as a parametric family [2]:</p>
\[\{ f(\cdot ; \theta) | \theta \in \Theta \}\]
<p>Where \(\Theta\) is the parameter space that we are searching through. The observed data is \(\mathbf{y} = (y_1, …, y_n)\), and the likelihood function is \(\mathcal{L}_n(\theta ; \mathbf{y}) = f_n(\mathbf{y} ; \theta)\) [2]. If the random variables are i.i.d., then we can write \(f_n(\mathbf{y} ; \theta)\) as a product of density functions, one for each random variable, that is [2]:</p>
\[f_n(\mathbf{y} ; \theta) = \prod_{k=1}^n f_k(y_k ; \theta)\]
<p>Our goal is to find the set of parameters, \(\theta\), that maximize the likelihood function, i.e. [2]:</p>
\[\theta^* = \text{argmax}_{\theta \in \Theta} \mathcal{L}_n (\theta ; \mathbf{y})\]
<p>(Note that it can be useful, mathematically, to use the log-likelihood version of this definition, \(\ell(\theta ; \mathbf{y}) = \ln \mathcal{L}_n (\theta ; \mathbf{y})\) [2].)</p>
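<p>As a toy illustration of these definitions (the observed data and the grid resolution below are invented for the example), here is a minimal sketch of MLE for i.i.d. Bernoulli observations, maximizing the log-likelihood over a discretized parameter space \(\Theta = (0, 1)\):</p>

```python
import math

def log_likelihood(theta, ys):
    # log L_n(theta; y) for i.i.d. Bernoulli(theta) observations
    return sum(y * math.log(theta) + (1 - y) * math.log(1 - theta) for y in ys)

ys = [1, 0, 1, 1, 0, 1, 1, 1]              # toy observed data
grid = [i / 1000 for i in range(1, 1000)]  # discretized parameter space Theta
theta_mle = max(grid, key=lambda t: log_likelihood(t, ys))
# for the Bernoulli family, the closed-form MLE is the sample mean, 6/8 = 0.75
```

<p>Because the Bernoulli log-likelihood is concave, the grid search lands exactly on the analytical answer, the sample mean.</p>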
<p>Now that we have the MLE in hand, let’s see how the MAP compares. Remember, the MAP is doing the same thing as the MLE but it also includes a prior distribution. Specifically, the MAP assumes that the parameters \(\theta\) have some prior distribution, \(g\) [3]. Then we can write the posterior distribution of \(\theta\) as [3]:</p>
\[f_n(\theta | \mathbf{y}) = \frac{f_n( \mathbf{y} | \theta) g(\theta)}{\int_{\Theta} f_n( \mathbf{y} | \vartheta) g(\vartheta) d \vartheta}\]
<p>The MAP method assumes that \(\theta\) can be computed as the argmax (i.e. the mode) of the posterior distribution [3]:</p>
\[\theta^* = \text{argmax}_{\theta} f_n(\theta | \mathbf{y})\]
\[\theta^* = \text{argmax}_{\theta} \frac{f_n( \mathbf{y} | \theta) g(\theta)}{\int_{\Theta} f_n( \mathbf{y} | \vartheta) g(\vartheta) d \vartheta}\]
<p>Here, we can disregard the denominator (the <strong>marginal likelihood</strong>) in the Bayesian expression for the posterior distribution. This is because it is always positive, and does not depend on \(\theta\) and so will not affect the value of \(\theta^*\). This allows us to write the MAP estimate more simply as [3]:</p>
\[\theta^* = \text{argmax}_{\theta} f_n( \mathbf{y} | \theta) g(\theta)\]
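<p>As a toy numerical check of this expression (the data and the prior hyperparameters are invented for the example), here is a sketch of MAP estimation for i.i.d. Bernoulli observations with a Beta prior \(g(\theta)\), maximizing \(f_n(\mathbf{y} | \theta) g(\theta)\) directly:</p>

```python
import math

def log_posterior(theta, ys, a, b):
    # log of f_n(y | theta) g(theta), with g a Beta(a, b) prior, up to a constant
    log_lik = sum(y * math.log(theta) + (1 - y) * math.log(1 - theta) for y in ys)
    log_prior = (a - 1) * math.log(theta) + (b - 1) * math.log(1 - theta)
    return log_lik + log_prior

ys = [1, 0, 1, 1, 0, 1, 1, 1]              # toy observed data
a, b = 2.0, 2.0    # prior pseudo-counts; a = b = 1 recovers the uniform prior (and the MLE)
grid = [i / 1000 for i in range(1, 1000)]
theta_map = max(grid, key=lambda t: log_posterior(t, ys, a, b))
# Beta-Bernoulli posterior mode in closed form: (k + a - 1) / (n + a + b - 2) = 7/10
```

<p>With \(k = 6\) successes in \(n = 8\) trials, the prior pulls the estimate from the maximum-likelihood value of 0.75 down to 0.7, which is exactly the "small dataset" effect described above.</p>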
<p>In this form, it is easier to see that the MAP and MLE methods return the same value for \(\theta^*\) when \(g(\theta)\) is a uniform distribution. Let me also write the MAP in terms of an undirected graphical model (UGM) just to connect this back to our previous discussion about PGMs [4]. First of all, we can write the joint probability represented by a UGM as [4]:</p>
\[p(x_1, …, x_n) = \frac{1}{\mathcal{Z}} \prod_{c \in \mathcal{C}} \phi_c (x_c)\]
<p>And the MAP is looking for the values of \(x\) that maximize the log-likelihood, i.e. [4]:</p>
\[\max_x \log p(x) = \max_x \sum_c \log \phi_c(x_c) - \log(\mathcal{Z})\]
<p>Okay, so now we have a good understanding of how we can use the MLE and MAP methods for inference, and how the MAP method relies on having a prior distribution over \(\theta\), which we can obtain from our PGM [3]. Now that we understand MAP inference, we can examine a few different methods for computing the estimate when we must do approximate inference.</p>
<h2 id="approaches-to-solving-the-approximate-inference-problem">Approaches to Solving the Approximate Inference Problem</h2>
<p>The methods available for solving the approximate inference problem can be broadly divided into two categories: <strong>sampling</strong> and <strong>variational inference</strong>. I will write detailed explanations for some methods within these categories in future posts, but for now let me give an introduction to each of them.</p>
<p>Sampling methods - which are an older class of methods than variational ones - can be used to do both marginal and MAP inference, as well as computing expectations [5]. The advantage to using sampling algorithms like the Metropolis-Hastings algorithm is that we can compute a probability \(p(x)\) if we know a function \(f(x)\) that is <em>proportional</em> to the probability density \(p\) - since \(f(x)\) only has to be proportional to \(p(x)\), we do not need to compute the grand partition function, \(\mathcal{Z}\), which is intractable [6]. We sample from \(f(x)\) many times and accept highly probable samples, thereby indicating what the expectation and the MAP for this probability \(p(x)\) should be [6].</p>
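<p>To make the "proportional is enough" point concrete, here is a minimal sketch of a Metropolis-Hastings sampler (the proposal scale, sample count, and target density are arbitrary choices for illustration), targeting a distribution known only up to its normalizing constant:</p>

```python
import math
import random

def metropolis_hastings(f, x0, n_samples, step=1.0, seed=0):
    """Draw samples from p(x) proportional to f(x), using a symmetric Gaussian proposal."""
    rng = random.Random(seed)
    x, samples = x0, []
    for _ in range(n_samples):
        x_new = x + rng.gauss(0.0, step)
        # accept with probability min(1, f(x_new)/f(x)); the intractable Z cancels out
        if rng.random() < min(1.0, f(x_new) / f(x)):
            x = x_new
        samples.append(x)
    return samples

# unnormalized standard normal density: we never compute Z = sqrt(2 pi)
f = lambda x: math.exp(-0.5 * x * x)
samples = metropolis_hastings(f, x0=0.0, n_samples=20000)
mean = sum(samples) / len(samples)       # should land near the true mean of 0
```

<p>Note that the acceptance ratio \(f(x_{\text{new}})/f(x)\) only ever uses \(f\), which is why we never need the partition function.</p>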
<p>Sampling methods have been in use for a long time, but they have some significant downsides which variational methods can overcome. For example, although sampling methods are guaranteed to eventually find a globally optimal solution, it is hard, in practice, to know when they have approached a good solution. Moreover, choosing the sampling technique for some algorithms can have a large effect on the process and is therefore something of an art form [7].</p>
<p>Variational inference*2, by contrast, presents the inference problem as an optimization problem. Specifically, if our goal is to compute the MAP for an intractable probability distribution, \(p\), then our goal is to find a tractable distribution \(q \in \mathcal{Q}\) that is as similar as possible to \(p\). Then we use \(q\) to perform inference tasks instead of \(p\). The benefit of this approach is that we have a deep body of knowledge supporting optimization problem solving methods. This brings its own pros and cons to using variational methods. For instance, variational methods are not guaranteed to find the globally optimal solution. However, it is possible to know if they have converged (unlike sampling methods) and they tend to scale better with modern hardware like GPUs [7].</p>
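<p>A toy sketch of this "inference as optimization" idea (the target distribution and the one-parameter family \(\mathcal{Q}\) below are invented for illustration): we search a restricted family for the member \(q\) that minimizes the KL divergence to a fixed target \(p\).</p>

```python
import math

def kl(q, p):
    # KL(q || p), the objective variational methods typically minimize
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

p = [0.7, 0.2, 0.1]   # target distribution over 3 states (tiny here, so we can inspect it)
# restricted family Q: q(t) = [t, (1-t)/2, (1-t)/2], a single free parameter
grid = [i / 100 for i in range(1, 100)]
best_t = min(grid, key=lambda t: kl([t, (1 - t) / 2, (1 - t) / 2], p))
q_star = [best_t, (1 - best_t) / 2, (1 - best_t) / 2]
```

<p>Because this \(\mathcal{Q}\) cannot represent the asymmetry between the last two states of \(p\), the best \(q\) is only an approximation; that trade-off between tractability and fidelity is the essence of variational inference.</p>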
<p>I will end this post here, but in my next couple of posts, I will write about sampling and variational inference methods in more detail to better understand how they work.</p>
<h2 id="footnotes">Footnotes</h2>
<p>*1 I should point out that this statement is not strictly true. Since MLE can be seen as a special case of MAP, you could say that MLE is a form of Bayesian inference where we assume the priors of the parameters are uniformly distributed [2].</p>
<p>*2 Variational inference is called that because it is derived from the calculus of variations, which optimizes functionals (functions of functions) [7].</p>
<h2 id="references">References</h2>
<p>[1] Kuleshov, V. and Ermon, S. “Introduction.” CS228 Probabilistic Graphical Models, 2022. Stanford University. <a href="https://ermongroup.github.io/cs228-notes/preliminaries/introduction/">https://ermongroup.github.io/cs228-notes/preliminaries/introduction/</a> Visited 30 Jun 2022.</p>
<p>[2] “Maximum likelihood estimation.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Maximum_likelihood_estimation">https://en.wikipedia.org/wiki/Maximum_likelihood_estimation</a> Visited 30 Jun 2022.</p>
<p>[3] “Maximum a posteriori estimation.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Maximum_a_posteriori_estimation">https://en.wikipedia.org/wiki/Maximum_a_posteriori_estimation</a> Visited 30 Jun 2022.</p>
<p>[4] Kuleshov, V. and Ermon, S. “MAP inference.” CS228 Probabilistic Graphical Models, 2022. Stanford University. <a href="https://ermongroup.github.io/cs228-notes/inference/map/">https://ermongroup.github.io/cs228-notes/inference/map/</a></p>
<p>[5] Kuleshov, V. and Ermon, S. “Sampling methods.” CS228 Probabilistic Graphical Models, 2022. Stanford University. <a href="https://ermongroup.github.io/cs228-notes/inference/sampling/">https://ermongroup.github.io/cs228-notes/inference/sampling/</a> Visited 30 Jun 2022.</p>
<p>[6] “Metropolis-Hastings algorithm.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm">https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm</a> Visited 30 Jun 2022.</p>
<p>[7] Kuleshov, V. and Ermon, S. “Variational inference.” CS228 Probabilistic Graphical Models, 2022. Stanford University. <a href="https://ermongroup.github.io/cs228-notes/inference/variational/">https://ermongroup.github.io/cs228-notes/inference/variational/</a> Visited 30 Jun 2022.</p>
Understanding Message Passing in GNNs (2022-06-19) http://sassafras13.github.io/GNN
<p>I have written about graph neural networks (GNNs) a lot, but one thing I am still trying to understand is how message passing works when you add neural networks to the mix. <strong>Message passing</strong>, put simply, refers to sharing information between nodes in a graph along the edges that connect them. I first encountered it in discussions about <strong>variable elimination</strong> for performing <strong>inference</strong> on a probabilistic graphical model (PGM) - we will demystify that sentence below - but I still wanted to make the connection for myself between doing this analytically on a small graph to doing it computationally with a much larger graph and using neural networks. Let’s dive in!</p>
<h2 id="graphical-models-can-only-do-2-things">Graphical Models Can Only Do 2 Things</h2>
<p>Well, okay, I mean, maybe they can do lots of things but we can broadly categorize the tasks that you can perform with graphical models into two tasks: learning and inference. <strong>Learning</strong> refers to finding the graphical model, \(M\), given some data, \(\mathcal{D}\). Conversely, <strong>inference</strong> refers to calculating a probability \(P_M(X | Y)\) using the model \(M\) [2]. Message passing, the star of today’s blog post, emerges as a technique for performing inference.</p>
<p>Let’s pause and review the basics of probability and what kinds of probabilistic inference we can perform. There are two forms that we commonly try to solve with PGMs: <strong>marginal inference</strong> and <strong>maximum a posteriori (MAP) inference</strong> [3]. Marginal probability asks what the probability of observing a single random variable is, given any outcome for the other random variables. In other words, we sum over all of the other random variables in the model except for the ones we’re interested in.*1 We can write this as [3]:</p>
\[P(y = 1) = \sum_{x_1} \sum_{x_2} … \sum_{x_n} p(y = 1, x_1, x_2, …, x_n)\]
<p>Alternatively, the MAP inference asks what is the most likely value for all the random variables in our model [3]:</p>
\[\max_{x_1, …, x_n} p(y = 1, x_1, … , x_n)\]
<p>So inference involves calculating some probability given the model that we have. As we’ll see below, this is very straightforward if the model is small - in this case, we can compute the <strong>exact inference</strong> value for our model. However, if the model is much larger and more complex, then we may find that calculating either of these forms of inference becomes intractable. In that case, we must turn to using various algorithms for <strong>approximate inference</strong> since the exact value is NP-hard [3]. In this post, we’ll confine ourselves to discussing exact inference because that is all we need to understand to get to message passing, but I wanted to mention this here so that we were aware there are other approaches for when graphs get larger.</p>
<h2 id="variable-elimination-as-inference">Variable Elimination as Inference</h2>
<p>Here we will see how we can perform exact inference by computing the marginal probability for a small directed chain graph as shown in Figure 1. Our graph represents the joint probability over all the random variables in all of the nodes of the graph. We will see that we can take advantage of the graphical structure of this model to make computing the marginal probability very efficient.</p>
<p><img src="/images/2022-06-19-GNN-fig1.png" alt="Fig 1" title="Figure 1" />
Figure 1</p>
<p>Note that the meaning of these factors varies slightly depending on which type of graphical model you are working with. For directed graphs, the factors are conditional probabilities with respect to the node’s parents, i.e. \(\{ P(X_i \mid X_{PA_i}) \}\). For undirected graphs, the factors are called <strong>clique potentials</strong> and are written as \(\phi_c (X_c)\). Clique potentials become probabilities after they are normalized [1]. If we wanted to compute the marginal probability of \(P(X_5)\), we would sum over all of the other variables, like so [2]:</p>
\[P(X_5) = \sum_{X_1} \sum_{X_2} \sum_{X_3} \sum_{X_4} P(X_1, X_2, X_3, X_4, X_5)\]
<p>The computational cost of doing this naively (i.e. just by summing over all of the other random variables in the graph) is \(\mathcal{O}(k^n)\), where \(k\) is the number of possible values of each variable, and \(n\) is the number of random variables [2,3]. The reason why variable elimination is so powerful is that it greatly speeds up the computation of this marginal probability. Graphical models are computationally efficient because they allow us to write all of the possible probabilities more compactly than in a tabular format. However, there is a limitation to this: when we need to sum up (or marginalize) over certain random variables in the joint distribution, that can still be prohibitively expensive even though we are working with a graphical model that is supposed to be efficient. Variable elimination helps us to speed up the process of marginalizing over random variables in our model by side-stepping the need to compute the sum as written above [1]. We are going to see how this works by leveraging the model that we have (Figure 1) to write the joint probability distribution for this graph as [2]:</p>
\[P(X_1, X_2, X_3, X_4, X_5) = P(X_1) P(X_2 | X_1) P(X_3 | X_2) P(X_4 | X_3) P(X_5 | X_4)\]
<p>And now we can use this to rewrite the marginal probability for \(P(X_5)\) as [2]:</p>
\[P(X_5) = \sum_{X_1} \sum_{X_2} \sum_{X_3} \sum_{X_4} P(X_1) P(X_2 | X_1) P(X_3 | X_2) P(X_4 | X_3) P(X_5 | X_4)\]
<p>Moreover, we can start to move the summations further into the expression, like so [2]:</p>
\[P(X_5) = \sum_{X_4} P(X_5 | X_4) \sum_{X_3} P(X_4 | X_3) \sum_{X_2} P(X_3 | X_2) \sum_{X_1} P(X_1) P(X_2 | X_1)\]
<p>Observe that the last term, \(\sum_{X_1} P(X_1) P(X_2 \mid X_1)\), can be simplified to \(P(X_2)\), to give us [2]:</p>
\[P(X_5) = \sum_{X_4} P(X_5 | X_4) \sum_{X_3} P(X_4 | X_3) \sum_{X_2} P(X_3 | X_2) P(X_2)\]
<p>We can continue this operation to marginalize out the other probabilities - we also call this <strong>eliminating</strong> variables, which gives the algorithm its name [2,3]:</p>
\[P(X_5) = \sum_{X_4} P(X_5 | X_4) \sum_{X_3} P(X_4 | X_3) P(X_3)\]
\[P(X_5) = \sum_{X_4} P(X_5 | X_4) P(X_4)\]
<p>Each time we eliminate a variable, we must perform a computation that is of \(\mathcal{O}(k^2)\) time and then we do this for \(n\) variables, giving us a total computational complexity of just \(\mathcal{O}(nk^2)\), which is clearly less than the naive solution of \(\mathcal{O}(k^n)\) [2,3]. Also, if you are familiar with computer science, you may have noticed that variable elimination is a form of dynamic programming [2,3].</p>
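<p>The chain example above is easy to put into code. Below is a sketch (the CPT values are invented for illustration) that performs each elimination as an \(\mathcal{O}(k^2)\) step and checks the result against the naive \(\mathcal{O}(k^n)\) enumeration:</p>

```python
from itertools import product

def eliminate_chain(p1, transitions):
    """Marginal of the last node in a chain PGM via variable elimination.

    p1[k] = P(X_1 = k); transitions[i][a][b] = P(X_{i+2} = b | X_{i+1} = a).
    Each elimination is an O(k^2) matrix-vector product, so the chain costs O(n k^2).
    """
    msg = p1
    for T in transitions:
        # message passed forward: P(X_{i+1} = b) = sum_a P(X_i = a) P(X_{i+1} = b | X_i = a)
        msg = [sum(msg[a] * T[a][b] for a in range(len(msg))) for b in range(len(T[0]))]
    return msg

p1 = [0.6, 0.4]
T = [[0.9, 0.1], [0.2, 0.8]]            # same CPT at every step, for brevity
p5 = eliminate_chain(p1, [T, T, T, T])

# naive marginalization: sum the joint over all k^n assignments
naive = [0.0, 0.0]
for x1, x2, x3, x4, x5 in product(range(2), repeat=5):
    naive[x5] += p1[x1] * T[x1][x2] * T[x2][x3] * T[x3][x4] * T[x4][x5]
```

<p>Both routes give the same marginal; the elimination route just never materializes the full joint.</p>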
<p>Now that we have the basic idea of variable elimination, let’s see where the idea of message passing comes into play.</p>
<h2 id="message-passing">Message Passing</h2>
<p>The key insight here is that each variable elimination is a message that is passed from one node to another. Put another way, each time we perform the operation \(P(X_2) = \sum_{X_1} P(X_1) P(X_2 \mid X_1)\), this is a message that is passed from node \(X_1\) to \(X_2\) [3]. Note that the literature often refers to message-passing as <strong>belief propagation</strong> [2,5]. This insight is valuable because it means that we can reuse messages to perform many different queries very cheaply. For example, if we wanted to compute the marginal probability for a different random variable in the model above, we could rewrite the marginalization process and reuse many of the messages that we had already computed to obtain \(P(X_5)\). This saves us a lot of computation time that we would otherwise have to spend recomputing the same messages [2,5].</p>
<p>This idea unlocks a wide range of algorithms that focus on efficiently organizing the calculations required to marginalize the probability, given that we can reuse messages and that each variable elimination “collapses” part of the graph. I will not go into detail here because these algorithms are also generally for exact inference on small graphs and I’m more concerned here with big graphs and neural networks. But let me just say that generally speaking, all message passing algorithms have to follow the <strong>message passing protocol</strong> which states that a node can only send a message to its neighbors when it has finished getting all the messages from its other neighbors [2]. This idea of order is critical to message passing algorithms.</p>
<p>Battaglia et al. wrote a beautiful position piece on the value of graphical models in deep learning, and they describe message passing as a means of performing reasoning, much like humans do [6]. They point out that it allows for local information sharing (i.e. between nodes that are connected in the graph) and that it can be, to some extent, parallelized, which makes computations more efficient (this goes back to our mentions of various algorithms that draw on ideas from dynamic programming to perform efficient inference on a graph) [6].</p>
<p>This paper by Battaglia et al. also lays out a general framework for graph net (GN) blocks that ensures basic functions are performed which are common to all graphical frameworks, but also allows for flexibility depending on the application. A GN block includes edges, \(E\) with attributes \(\mathbf{e}\), nodes (or vertices), \(V\) with attributes \(\mathbf{v}\), and global properties, \(\mathbf{u}\). For a directed graph, we also have sender (s) and receiver (r) nodes at the source and sink of a directed edge. Generally, a GN block must perform three update functions (denoted as \(\phi\)) and three aggregation functions, denoted as \(\rho\). These can be written as [6]:</p>
\[\mathbf{e}_k' = \phi^e(\mathbf{e}_k, \mathbf{v}_{r_k}, \mathbf{v}_{s_k}, \mathbf{u})\]
\[\mathbf{v}_i' = \phi^v (\mathbf{\bar{e}}_i', \mathbf{v}_i, \mathbf{u})\]
\[\mathbf{u}' = \phi^u (\mathbf{\bar{e}}', \mathbf{\bar{v}}', \mathbf{u})\]
\[\mathbf{\bar{e}}_i' = \rho^{e \rightarrow v} (E_i')\]
\[\mathbf{\bar{e}}' = \rho^{e \rightarrow u} (E')\]
\[\mathbf{\bar{v}}' = \rho^{v \rightarrow u}(V')\]
<p>And we can also visualize these a couple of different ways as shown in Figure 2 below.</p>
<p><img src="/images/2022-06-19-GNN-fig2.png" alt="Fig 2" title="Figure 2" />
Figure 2 - Source [6]</p>
<p>Now, according to Battaglia et al., authors have implemented forms of GNs that use neural networks to serve as specific functions. For example, Sanchez-Gonzalez et al. [7] used neural networks as the update functions, \(\phi\), and elementwise summation operations for the aggregation functions, \(\rho\) [6]. Battaglia et al. also note that MLPs are well suited to vector-based attributes (\(\mathbf{v}\) and \(\mathbf{e}\)) while CNNs may be a better choice if the attributes are images, highlighting the flexibility of the GN block itself [6].</p>
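<p>As a rough sketch of how the three update functions \(\phi\) and sum aggregations \(\rho\) fit together (the 1-D attributes and the additive toy \(\phi\) functions below are invented for illustration; in practice, as noted above, the \(\phi\)'s would be neural networks):</p>

```python
def gn_block(nodes, edges, u, phi_e, phi_v, phi_u):
    """One pass of a simplified GN block. nodes: list of attribute vectors;
    edges: list of (sender, receiver, attribute) tuples; u: global attribute vector."""
    def vec_sum(vs, dim):
        out = [0.0] * dim
        for v in vs:
            out = [a + b for a, b in zip(out, v)]
        return out

    # 1. update every edge: e_k' = phi_e(e_k, v_rk, v_sk, u)
    new_edges = [(s, r, phi_e(e, nodes[r], nodes[s], u)) for s, r, e in edges]
    dim_e = len(new_edges[0][2])
    # 2. aggregate incoming edges per node, then update: v_i' = phi_v(e_bar_i', v_i, u)
    new_nodes = [phi_v(vec_sum([e for s, r, e in new_edges if r == i], dim_e), v, u)
                 for i, v in enumerate(nodes)]
    # 3. aggregate all edges and nodes, then update the global attribute u'
    e_bar = vec_sum([e for _, _, e in new_edges], dim_e)
    v_bar = vec_sum(new_nodes, len(new_nodes[0]))
    return new_nodes, new_edges, phi_u(e_bar, v_bar, u)

# toy additive update functions, purely for illustration
phi_e = lambda e, vr, vs, u: [e[0] + vs[0]]      # edge absorbs its sender's attribute
phi_v = lambda e_bar, v, u: [v[0] + e_bar[0]]    # node absorbs its aggregated messages
phi_u = lambda e_bar, v_bar, u: [u[0] + v_bar[0]]
nodes, edges, u = gn_block([[1.0], [2.0]], [(0, 1, [0.5])], [0.0], phi_e, phi_v, phi_u)
```

<p>Swapping the lambdas for small MLPs (and keeping the sum aggregations) recovers the kind of learnable GN block described in the text.</p>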
<h2 id="why-can-i-use-a-neural-network-for-that">Why Can I Use a Neural Network for That?</h2>
<p>My question, however, is <em>why</em> can we use neural networks to perform message passing, specifically the update functions? To get a better handle on this, I turned to some posts by Daniele Grattarola which were extremely helpful [8,9]. We’ll dive into his explanation for how we do message passing with graphs in a minute, but first let me introduce some new (to me) terminology. Grattarola uses the term <strong>reference operator</strong>, \(R\) to refer to any matrix that describes the graph and has the same sparsity pattern as the adjacency matrix, \(A\). So, for example, the normalized adjacency matrix and the Laplacian matrix both have the same sparsity pattern (ignoring the diagonal). That means that all three of these matrices have non-zero entries where there are edges in the graph, and zeroes everywhere else (with the exception that some will include ones along the diagonal for self-loops and others omit them) [8].</p>
<p>Reference operators are critical to understanding why neural networks can represent a message passing operation. First, they are, by definition, operators, which means that they transform the <strong>graph signal</strong> (often this is the node attribute matrix, \(X\) - we saw this terminology in our discussion of <a href="">graph convolution</a>) and output a new graph signal. Second, when we multiply \(R\) with a graph signal, we are essentially computing the weighted sum of each node’s neighboring nodes - which <em>passes messages</em> between connected nodes [8]. Consider this simple example, where we have a node \(X_1\) connected to nodes \(X_2, X_3, X_4\) (there might be other nodes in the graph but if \(X_1\) is not connected to them then the reference operator would show 0’s everywhere that there could be a connection between \(X_1\) and, say, \(X_5\)). The reference operator is multiplied with the graph signal like so [8]:</p>
\[(\mathbf{RX})_1 = \mathbf{r}_{12} \cdot \mathbf{x}_2 + \mathbf{r}_{13} \cdot \mathbf{x}_3 + \mathbf{r}_{14} \cdot \mathbf{x}_4\]
<p>This expression is nothing more than a weighted sum of the neighbors of \(X_1\), which could serve to <em>approximate</em> the marginal probability of \(X_1\) based on the random variables that it is conditioned on, namely \(X_2, X_3\) and \(X_4\). This, I believe, is the crux of why it is valid to use a simple matrix multiplication operation (and by extension, an MLP) to perform updates of all the nodes in a graph. By the very nature of the matrix multiplication with this reference operator, only connected nodes will pass their information to their neighbors, otherwise the weight for their contributions is simply 0.</p>
<p>Now, in addition to reference operators, we may want to also transform our node attributes. This can be useful if we want to extract useful features from our graphs that we think may help us achieve some learning task (think back to MolGAN which used graph convolution to look for features of small molecules that made them well suited to be water soluble, for example). To extract features from our graph, we can use <strong>filters</strong>, \(\Theta\), just as we would in a classic convolutional neural network (you could also think of these as simply transforming the graph into some latent representation that helps the GNN learn to perform the desired function efficiently). We can write the transformations \(R\) and \(\Theta\) operating on a graph signal as follows [8]:</p>
\[\mathbf{X}' = \mathbf{R} \mathbf{X} \mathbf{\Theta}\]
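<p>This whole layer fits in a few lines. Here is a sketch on a tiny star graph (the graph, the row-normalized choice of reference operator, and the identity filter are all invented for illustration):</p>

```python
def matmul(A, B):
    # plain nested-list matrix multiply, to keep the example dependency-free
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def gnn_layer(R, X, Theta):
    # one message-passing step: X' = R X Theta (no nonlinearity, for clarity)
    return matmul(matmul(R, X), Theta)

# star graph: node 0 is connected to nodes 1, 2, 3; R is the row-normalized adjacency
R = [[0, 1/3, 1/3, 1/3],
     [1, 0, 0, 0],
     [1, 0, 0, 0],
     [1, 0, 0, 0]]
X = [[1.0], [2.0], [4.0], [6.0]]        # one scalar attribute per node
Theta = [[1.0]]                          # identity filter: pure aggregation
X_new = gnn_layer(R, X, Theta)
# node 0 ends up with the average of its neighbors' attributes, (2 + 4 + 6) / 3 = 4
```

<p>The zeros in \(\mathbf{R}\) are exactly what keeps unconnected nodes from exchanging messages, as described above.</p>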
<p>Grattarola credits the 2017 paper by Gilmer et al. [10] for introducing the specific message-passing neural network, and we’ll see in a moment that the framework dovetails well with the basic concept of message passing as a form of variable elimination that we explored above [9]. We can summarize message passing in GNNs as requiring 3 steps [9, 10]:</p>
<ol>
<li>
<p>Each node computes a <strong>message</strong> for its neighbors. The message function can be written as \(m(\text{node}, \text{neighbor}, \text{edge})\).</p>
</li>
<li>
<p>The messages are sent and each node <strong>aggregates</strong> all of its messages. Typically the aggregation is a simple sum or average.</p>
</li>
<li>
<p>The nodes <strong>update</strong> their attributes given the current values of those attributes and the aggregated information from the messages.</p>
</li>
</ol>
<p>Note that these steps happen simultaneously for all nodes in the graph so in one message passing step, all the nodes of the graph are updated. So what does this look like in terms of our simple GNN written above? Let’s break it down [9]:</p>
<ol>
<li>
<p><strong>Message:</strong> Each node, \(i\), receives a message from each neighbor, where each message is the transformed graph signal, \(\mathbf{\Theta}^T \mathbf{x}_j\), where each neighbor is \(j \in \mathcal{N}(i)\).</p>
</li>
<li>
<p><strong>Aggregate:</strong> The messages are aggregated via the weighted sum, which is simply performed as matrix multiplication with the reference operator, \(\mathbf{R}\).</p>
</li>
<li>
<p><strong>Update:</strong> Now each node replaces its attributes with the new ones it computed in the previous step. If the diagonal of the reference operator is non-zero, then we will combine the new messages with a set of self-loop messages.</p>
</li>
</ol>
<p>This simple approach is the equivalent of the message passing that we saw previously in the discussion about exact inference. The nice thing is that as long as each step (message, aggregate and update) is differentiable, then our model can learn optimal values for these functions to accurately complete the task at hand [9].</p>
<h2 id="conclusion">Conclusion</h2>
<p>I hope you found this exploration a little bit helpful in understanding what GNNs are doing and why it works. This post was a drill down to answer a particular question I had, but if any of the topics I mentioned here were interesting to you, I would strongly encourage you to look at my references or other sources to learn more - I love this topic!</p>
<h2 id="footnotes">Footnotes</h2>
<p>*1 I always wonder why this kind of probability is called a “marginal probability” and according to Wikipedia it is because if you were doing this calculation for, say, a tabulated set of probabilities, to calculate the marginal probability you would have to calculate the sums over the other variables in the <strong>margins</strong> of the table [4]. Like okay, fine, I guess, if everything we did was still pencil and paper? I’m sure there’s a catchier name out there somewhere, that’s all.</p>
<h2 id="references">References</h2>
<p>[1] Koller, D. and Friedman, N. Probabilistic Graphical Models: Principles and Techniques. MIT Press, 2009.</p>
<p>[2] Xing, E. “Lecture 4: Exact Inference.” 10-708 Probabilistic Graphical Models, Spring 2017. Carnegie Mellon University. <a href="https://www.cs.cmu.edu/~epxing/Class/10708-17/notes-17/10708-scribe-lecture4.pdf">https://www.cs.cmu.edu/~epxing/Class/10708-17/notes-17/10708-scribe-lecture4.pdf</a> Visited 18 Jun 2022.</p>
<p>[3] Kuleshov, V. and Ermon, S. “Variable Elimination.” CS228 Probabilistic Graphical Models, 2022. Stanford University. <a href="https://ermongroup.github.io/cs228-notes/inference/ve/">https://ermongroup.github.io/cs228-notes/inference/ve/</a> Visited 18 Jun 2022.</p>
<p>[4] “Marginal distribution.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Marginal_distribution">https://en.wikipedia.org/wiki/Marginal_distribution</a> Visited 18 Jun 2022.</p>
<p>[5] Kuleshov, V. and Ermon, S. “Junction Tree Algorithm.” CS228 Probabilistic Graphical Models, 2022. Stanford University. <a href="https://ermongroup.github.io/cs228-notes/inference/jt/">https://ermongroup.github.io/cs228-notes/inference/jt/</a> Visited 18 Jun 2022.</p>
<p>[6] Battaglia, P. W., Hamrick, J. B., Bapst, V., Sanchez-Gonzalez, A., Zambaldi, V., Malinowski, M., Tacchetti, A., Raposo, D., Santoro, A., Faulkner, R., Gulcehre, C., Song, F., Ballard, A., Gilmer, J., Dahl, G., Vaswani, A., Allen, K., Nash, C., Langston, V., … Pascanu, R. (2018). Relational inductive biases, deep learning, and graph networks. ArXiv Preprint ArXiv: 1806.01261. http://arxiv.org/abs/1806.01261</p>
<p>[7] Sanchez-Gonzalez, A., Heess, N., Springenberg, J. T., Merel, J., Riedmiller, M., Hadsell, R., & Battaglia, P. (2018). Graph Networks as Learnable Physics Engines for Inference and Control. Proceedings of the 35th International Conference on Machine Learning, 80, 4470–4479. http://proceedings.mlr.press/v80/sanchez-gonzalez18a.html</p>
<p>[8] Grattarola, D. “A practical introduction to GNNs - Part 1.” 3 Mar 2021. <a href="https://danielegrattarola.github.io/posts/2021-03-03/gnn-lecture-part-1.html">https://danielegrattarola.github.io/posts/2021-03-03/gnn-lecture-part-1.html</a> Visited 18 Jun 2022.</p>
<p>[9] Grattarola, D. “A practical introduction to GNNs - Part 2.” 12 Mar 2021. <a href="https://danielegrattarola.github.io/posts/2021-03-12/gnn-lecture-part-2.html">https://danielegrattarola.github.io/posts/2021-03-12/gnn-lecture-part-2.html</a> Visited 18 Jun 2022.</p>
<h1 id="bayesian-optimization-for-beginners">Bayesian Optimization for Beginners</h1>
<p>2022-06-04. <a href="http://sassafras13.github.io/BayesianOptimization">http://sassafras13.github.io/BayesianOptimization</a></p>
<p>I am interested in using an optimization technique called <strong>Bayesian optimization</strong> in a current research project, so I wanted to take this opportunity to write a couple of blog posts to describe how this algorithm works. Generally speaking, Bayesian optimization is an appropriate tool to use when you are trying to optimize a function that you do not have an analytical expression for (i.e. that is a black-box function) and when it is expensive to evaluate that function [1]. In this post, I will describe in more detail when Bayesian optimization is useful, and how it works from a mathematical standpoint. I may write an accompanying blog post that dives into some of the relevant mathematical tools more deeply, as well. Let’s get started!</p>
<h2 id="introduction-to-bayesian-optimization">Introduction to Bayesian Optimization</h2>
<p>As I mentioned in the introduction, Bayesian optimization is a good tool to use when you want to optimize a function \(f(x)\) (which we call the <strong>objective function</strong>) and it has the following two properties [1]:</p>
<ul>
<li>
<p>The analytical expression for \(f(x)\) is unknown (it is a black box). In this situation, we cannot use other techniques like finding the optimal solution analytically because we don’t have the expression to work with.</p>
</li>
<li>
<p>The cost of finding \(f(x)\) for a given value of \(x\) is very high. This is often true in experimental settings, where \(f(x)\) refers to a physical process that we are studying, or when we are trying to optimize the hyperparameters of a neural network, and each time we choose a set of hyperparameters we have to completely re-train the network. If this was not true, we could use simpler approaches like grid search to find the optimal value of \(x\), because each evaluation of \(f(x)\) would be cheap.</p>
</li>
</ul>
<p>Let’s assume that we find ourselves in a situation where \(f(x)\) meets both of these conditions. In this case, we can use Bayesian optimization to find \(x^{*}\), the globally optimal value of \(x\), while minimizing the number of times that we query \(f(x)\). In other words, our objective is [2]:</p>
\[\max_{x} f(x)\]
<p>The Bayesian part of Bayesian optimization comes from the fact that we will incorporate samples drawn from the objective function into our prediction for the optimal value of \(x\). As we will see in a moment, we will define a <strong>prior</strong> (i.e. a guess) for \(x^*\) and, as we draw samples from the objective function, we will update the prior to define a <strong>posterior</strong> that incorporates those samples and gives a more accurate guess of what \(x^*\) is. Note that, even though sampling the objective function is expensive, we will still have to draw some samples during this process; our goal is to minimize the number of samples needed before we can identify \(x^*\) [1].</p>
<p>I will give an intuitive explanation of how Bayesian optimization works and then present the mathematics in the next section. Bayesian optimization uses two functions to guide the optimization process: a <strong>surrogate function</strong> and an <strong>acquisition function</strong>. The surrogate function, \(g(\cdot)\), addresses the problem described above where we have no analytical expression for \(f(x)\) - the surrogate function acts as our best guess of the form of \(f(x)\) given the current information. The acquisition function, \(u(\cdot)\), guides our choice of the next value of \(x\) to sample from the objective function. The acquisition function must balance exploration and exploitation to find the globally optimal value of \(x\) as quickly as possible. Note that the acquisition function depends on the surrogate function, as we will see in the next section [1].</p>
<p>Once we have defined the surrogate and acquisition functions, we can use them in an iterative process to optimize the objective function as follows [1]:</p>
<ol>
<li>
<p>Iterate for \(t = 1, 2, …, T\) steps, accumulating each sampled point, \((\mathbf{x}_t, y_t)\), into the set \(\mathcal{D}_{1:t}\).</p>
</li>
<li>
<p>Select the next point to sample, \(\mathbf{x}_t\), by finding the argmax of the acquisition function, i.e.:</p>
</li>
</ol>
\[\mathbf{x}_t = \text{argmax}_{\mathbf{x}} u(\mathbf{x} | \mathcal{D}_{1:t-1})\]
<ol start="3">
<li>
<p>Sample the objective function at this point: \(y_t = f(\mathbf{x}_t)\). Add this sample to the set, \(\mathcal{D}_{1:t} = \{\mathcal{D}_{1:t-1}, (\mathbf{x}_t, y_t)\}\).</p>
</li>
<li>
<p>Update the surrogate function, \(g(\cdot)\) with the newly sampled point \((\mathbf{x}_t, y_t)\).</p>
</li>
</ol>
<p>Now let’s dive into the mathematics of this approach in more detail.</p>
<h2 id="the-math-behind-the-surrogate-and-acquisition-functions">The Math Behind the Surrogate and Acquisition Functions</h2>
<p>The surrogate function’s purpose is to represent our best guess of the analytic form of the objective function, \(f(x)\). One commonly used form for the surrogate function is a <strong>Gaussian process</strong> [1,3]. I will write a separate blog post to describe the mathematics behind Gaussian processes (GPs) because they are really cool and deserve more attention. For the purposes of this discussion, let me just briefly say that GPs represent a collection of functions, and each function has some probability that it is the best fit to the observed data [4]. A GP will represent the process as a joint probability over a set of random variables, where each random variable corresponds to one data point, and the set of points follows a multivariate normal distribution [5]. The benefit of using a Gaussian (or normal) representation is that it has many useful properties that make it easier to update the probability distribution using Bayes’ rule. As I said, GPs are really interesting tools so please look for another post coming soon that dives into their mechanics in much more detail*1.</p>
<p>So our best guess at the objective function is represented by a GP, but what about the acquisition function? There are multiple approaches to writing the acquisition function, including probability of improvement, expected improvement, and Thompson sampling [3]. However, I will just focus on expected improvement (EI) for this post, and we can assume that other approaches will fit into the Bayesian optimization algorithm in a similar fashion.</p>
<p>The EI approach balances exploitation and exploration as it recommends the next <strong>query point</strong>, \(\mathbf{x}_t\). Specifically, the EI approach will select a query point either because it is bigger than the best value we have seen so far (exploitation) or because that point is in a region of high uncertainty (exploration). The expression for EI can be written as [3]:</p>
\[\mathbf{x}_{t+1} = \text{argmax}_{\mathbf{x}} \mathbb{E} \left( \max \{ g_{t+1}(\mathbf{x}) - f(\mathbf{x}^+), 0 \} | \mathcal{D}_{1:t} \right)\]
<p>Here, we are choosing the next query point, \(\mathbf{x}_{t+1}\), as the value that maximizes the surrogate function, \(g_{t+1}(\mathbf{x})\), compared to the best value that we’ve seen so far, represented by \(f(\mathbf{x}^+)\). This choice is conditioned on all of the data we have to date, \(\mathcal{D}_{1:t}\) [3].</p>
<p>Since we know that the specific form of the surrogate function, \(g(\cdot)\), is a GP, we can write an analytical expression for the equation above as follows [3]:</p>
\[EI(\mathbf{x}) = \begin{cases} (\mu_t(\mathbf{x}) - f(\mathbf{x}^+) - \zeta)\Phi(Z) + \sigma_t(\mathbf{x})\phi(Z) & \text{if $\sigma_t(\mathbf{x}) > 0$} \\ 0 & \text{if $\sigma_t(\mathbf{x}) = 0$} \end{cases}\]
<p>Where \(\mu_t\) and \(\sigma_t\) are the mean and standard deviation of the GP posterior at \(\mathbf{x}\), and \(\Phi(Z)\) and \(\phi(Z)\) are the CDF and PDF of the standard normal distribution, respectively. The first term represents exploitation (because it drives the selection of the query point that maximizes the surrogate function) while the second term represents exploration (because it favors selecting the query point that is in the area of greatest uncertainty, as represented by the standard deviation, \(\sigma_t\)). The parameter \(\zeta\) controls the trade-off between these two terms representing exploration and exploitation [1,3]. Finally, \(Z\) is [1,3]:</p>
\[Z = \begin{cases} \frac{\mu_t(\mathbf{x}) - f(\mathbf{x}^+) - \zeta}{\sigma_t(\mathbf{x})} & \text{if $\sigma_t(\mathbf{x}) > 0$} \\ 0 & \text{if $\sigma_t(\mathbf{x}) = 0$} \end{cases}\]
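<p>To tie the pieces together, here is a minimal sketch of the full loop in Python: a zero-mean GP with an RBF kernel as the surrogate and the EI and \(Z\) expressions above as the acquisition function. This is a toy 1D implementation over a fixed candidate grid, not a production optimizer; the function names and the quadratic test objective are my own choices, not from the references.</p>

```python
import math
import numpy as np

def rbf_kernel(a, b, length=1.0):
    # Squared-exponential kernel between two 1D point sets
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / length**2)

def gp_posterior(X, y, Xq, noise=1e-4):
    # Posterior mean and std of a zero-mean GP surrogate at query points Xq
    K = rbf_kernel(X, X) + noise * np.eye(len(X))
    Ks = rbf_kernel(X, Xq)
    mu = Ks.T @ np.linalg.solve(K, y)
    var = 1.0 - np.sum(Ks * np.linalg.solve(K, Ks), axis=0)  # k(x, x) = 1 for RBF
    return mu, np.sqrt(np.maximum(var, 1e-12))

def expected_improvement(mu, sigma, f_best, zeta=0.01):
    # EI(x) = (mu - f_best - zeta) * Phi(Z) + sigma * phi(Z)
    Z = (mu - f_best - zeta) / sigma
    Phi = np.array([0.5 * (1.0 + math.erf(z / math.sqrt(2))) for z in Z])
    phi = np.exp(-0.5 * Z**2) / math.sqrt(2 * math.pi)
    return (mu - f_best - zeta) * Phi + sigma * phi

def bayes_opt(f, bounds, n_init=3, T=10, seed=0):
    rng = np.random.default_rng(seed)
    X = rng.uniform(bounds[0], bounds[1], n_init)   # initial random samples
    y = np.array([f(x) for x in X])
    Xq = np.linspace(bounds[0], bounds[1], 200)     # candidate grid
    for _ in range(T):
        mu, sigma = gp_posterior(X, y, Xq)
        # Step 2: next query point is the argmax of the acquisition function
        x_next = Xq[np.argmax(expected_improvement(mu, sigma, y.max()))]
        X = np.append(X, x_next)                    # Step 3: sample the objective
        y = np.append(y, f(x_next))
        # Step 4: the surrogate is refit from (X, y) on the next iteration
    return X[np.argmax(y)], y.max()

# Toy black-box objective with a known maximum at x = 2
x_best, y_best = bayes_opt(lambda x: -(x - 2.0) ** 2, bounds=(0.0, 4.0))
```

<p>Each pass through the loop refits the surrogate to every sample collected so far, then queries the objective at the EI argmax, mirroring steps 2 through 4 of the algorithm above.</p>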
<p>This math is everything that is required to implement Bayesian optimization. Krasser presents a beautiful Colab notebook that implements the algorithm from scratch and plots the results in a highly intuitive manner. I recreate part of the key figure below but I would recommend exploring his <a href="http://krasserm.github.io/2018/03/21/bayesian-optimization/">blog post</a> and accompanying notebook to get a better understanding of how the Bayesian optimization algorithm can unfold on an example system [1].</p>
<p><img src="/images/2022-06-04-BayesianOptimization-fig1.png" alt="Fig 1" title="Figure 1" />
Figure [1] - Source [1]</p>
<h2 id="conclusion">Conclusion</h2>
<p>In this post we discussed how Bayesian optimization works and when one would want to use it. In a follow-up post, I will dig more deeply into one of the mathematical tools we used in this algorithm, Gaussian processes. Thanks for reading and stay tuned!</p>
<h2 id="footnotes">Footnotes</h2>
<p>*1 Or just read <a href="https://distill.pub/2019/visual-exploration-gaussian-processes/">this</a> amazing Distill article by Görtler, et al.</p>
<h2 id="references">References</h2>
<p>[1] Krasser, M. “Bayesian optimization.” 21 Mar 2018. <a href="http://krasserm.github.io/2018/03/21/bayesian-optimization/">http://krasserm.github.io/2018/03/21/bayesian-optimization/</a> Visited 4 Jun 2022.</p>
<p>[2] “Bayesian optimization.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Bayesian_optimization">https://en.wikipedia.org/wiki/Bayesian_optimization</a> Visited 4 Jun 2022.</p>
<p>[3] Agnihotri & Batra, “Exploring Bayesian Optimization”, Distill, 2020. <a href="https://distill.pub/2020/bayesian-optimization/">https://distill.pub/2020/bayesian-optimization/</a> Visited 4 Jun 2022.</p>
<p>[4] Görtler, et al., “A Visual Exploration of Gaussian Processes”, Distill, 2019. <a href="https://distill.pub/2019/visual-exploration-gaussian-processes/">https://distill.pub/2019/visual-exploration-gaussian-processes/</a> Visited 4 Jun 2022.</p>
<p>[5] “Gaussian process.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Gaussian_process">https://en.wikipedia.org/wiki/Gaussian_process</a> Visited 4 Jun 2022.</p>
<h1 id="a-brief-survey-of-machine-learning-regression-models">A Brief Survey of Machine Learning Regression Models</h1>
<p>2022-06-04. <a href="http://sassafras13.github.io/RegressionModels">http://sassafras13.github.io/RegressionModels</a></p>
<p>I want to use this post to survey some common machine learning regression models to better understand the range of tools available and when we would choose certain models. Regression refers to finding a function that can predict a continuous output for a given input, for example predicting the price of a house given knowledge of its square footage and number of bedrooms [1]. There are a range of machine learning models that can be adapted to perform either classification or regression, and they are suitable for different settings. I’m going to explore some categories of models here, with a focus on models that are well suited to small datasets. Let’s start exploring!</p>
<h2 id="linear-models">Linear Models</h2>
<p>We’ll start by considering the classic <strong>linear regression</strong> model because this will help ground us in thinking about regression and it will serve as a good point of comparison with more advanced models in a minute. In general, a regression algorithm is trying to learn a function that maps \(X \rightarrow y\), i.e. \(y = f(X)\). In linear regression, that function is simply a line, written as \(y = w_{1:n}X + w_0\) (we assume that \(X\) may be a matrix of data points with multiple features, and that \(w_{0:n}\) is a vector of coefficients). To find the coefficients, we need to solve the following minimization problem [2]:</p>
\[\min_{w} \|Xw - y\|_2^2\]
<p>We can make this model more robust by adding different forms of <a href="https://sassafras13.github.io/Regularization/">regularization</a>, including Ridge regression (the L2 norm) and LASSO regression (the L1 norm). The Ridge regression term (that we would add to the minimization expression above) can be written as \(\alpha \|w\|_2^2\) where larger values of \(\alpha\) will increase the regularization penalty and force the coefficients to be smaller. With Ridge regression, no term is ever driven completely to zero, in contrast with LASSO, which can be used to create more sparse linear models that eliminate some features completely. This is possible because LASSO regression (Least Absolute Shrinkage and Selection Operator) uses the L1 norm which can drive coefficients to zero, i.e. \(\alpha \|w\|_1\) [2].</p>
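<p>Both ordinary least squares and Ridge regression have closed-form solutions, which makes the effect of the penalty easy to see; here is a quick numpy sketch (the function and variable names are mine):</p>

```python
import numpy as np

def ols_fit(X, y):
    # Ordinary least squares: minimize ||Xw - y||^2
    return np.linalg.lstsq(X, y, rcond=None)[0]

def ridge_fit(X, y, alpha):
    # Ridge: minimize ||Xw - y||^2 + alpha * ||w||^2,
    # closed form w = (X^T X + alpha * I)^{-1} X^T y
    n_features = X.shape[1]
    return np.linalg.solve(X.T @ X + alpha * np.eye(n_features), X.T @ y)

rng = np.random.default_rng(1)
X = rng.normal(size=(50, 3))
w_true = np.array([2.0, -1.0, 0.5])
y = X @ w_true + 0.01 * rng.normal(size=50)

w_ols = ols_fit(X, y)
w_ridge = ridge_fit(X, y, alpha=10.0)
# Larger alpha shrinks the Ridge coefficients toward zero (but never exactly to zero)
```

<p>Running this, the Ridge weight vector has a strictly smaller norm than the least-squares solution, which is the shrinkage behavior described above; LASSO would instead require an iterative solver, since the L1 penalty has no closed form.</p>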
<h2 id="support-vector-machines">Support Vector Machines</h2>
<p>Support vector machines (SVMs) look for a hyperplane that separates data points while maximizing the gap (or margin) between any data point and the hyperplane [3]. For example, given a dataset \(X\) and a set of binary labels, \(y\), we want to find a hyperplane that separates the points \(X\) labeled \(y = 1\) from those labeled \(y=0\) by the maximum available margin. We can enforce this separation as a <strong>hard margin</strong>, where no points in \(X\) may fall on the wrong side of the hyperplane, or as a <strong>soft margin</strong>, where points are allowed to lie on the wrong side of the hyperplane, subject to some penalty [3].</p>
<p>In a linear example, the hyperplane would be defined by [3]:</p>
\[w^T X - b = 0\]
<p>Where \(w\) is the normal vector to the hyperplane and \(b\) is the offset. The normal vector learned by the SVM can be used to make predictions, thereby performing regression, as follows [3]:</p>
\[\min \frac{1}{2} \|w\|^2\]

\[\text{subject to } |y_i - \langle w, x_i \rangle - b| \leq \epsilon\]
<p>Where \(\epsilon\) is a hyperparameter that defines the margin of error - all predictions, \(\langle w, x_i \rangle + b\) must be within \(\epsilon\) of the true value \(y_i\) [3].</p>
<p>Note that SVMs can also generalize well to nonlinear situations through use of the <strong>kernel trick</strong>. This technique uses some kernel to transform the dot product \(w^T X\) with a nonlinear function that transforms the feature space to some higher-dimensional representation. In this high-dimensional space, it may be easier to find a hyperplane that separates the two categories than in lower-dimensional space [4].</p>
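<p>For instance, the widely used RBF (Gaussian) kernel replaces the dot product with a similarity score that corresponds to an implicit, very high-dimensional feature map; a minimal sketch (function name mine):</p>

```python
import math

def rbf(x, z, gamma=1.0):
    # RBF kernel k(x, z) = exp(-gamma * ||x - z||^2), for vectors given as lists
    sq_dist = sum((xi - zi) ** 2 for xi, zi in zip(x, z))
    return math.exp(-gamma * sq_dist)

# Identical points have similarity 1; similarity decays with distance
same = rbf([1.0, 2.0], [1.0, 2.0])
near = rbf([1.0, 2.0], [1.0, 2.5])
far = rbf([1.0, 2.0], [4.0, 6.0])
```

<p>Swapping this function in for the plain dot product is all the "trick" amounts to computationally: the hyperplane is found in the implicit feature space without ever constructing it explicitly.</p>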
<p>Some advantages to using SVMs are [5]:</p>
<ul>
<li>Work well for high-dimensional data, even when the number of dimensions exceeds the number of data points available.</li>
<li>Only a subset of the data points are used to find the hyperplane (i.e., only those closest to the hyperplane) so this is a memory efficient technique.</li>
<li>The fact that we can swap in different kernels makes this a versatile model.</li>
</ul>
<p>And some disadvantages [5]:</p>
<ul>
<li>If the number of features is significantly larger than the size of the dataset, be careful of overfitting. This affects the choice of kernel and regularization term.</li>
<li>SVMs do not compute probabilities directly, if that’s something you need from your model.</li>
</ul>
<h2 id="nearest-neighbors">Nearest Neighbors</h2>
<p>The basic idea behind the nearest neighbors algorithm is to find some predetermined number of training samples (i.e. from \(X\)) that are “close” to the query point, and predict the label for the new point from these neighbors. There are multiple choices for the number of neighbors, and the metric for how “close” they are to the query point, yielding many variants on this approach. Commonly, we define a constant number of neighbors, \(k\), which gives us “k-nearest neighbor learning”, and we can use measures like the Euclidean distance (L2 norm) to compute how close two points are to each other. Since this algorithm applies labels simply by comparing new points to existing points, rather than fitting an analytical expression with parameters, we call this a <strong>non-parametric</strong> model. In particular, this model is great for situations where the decision boundary (the line that divides classes) is very irregular [6].</p>
<p>When we consider the regression form of nearest neighbors, it is essentially the same as described above, with the specific clarification that the predicted “label” for the new query point is simply \(y\), some continuous value that represents the output of the function \(f(X)\). In some cases, it may be helpful to add weights to the algorithm such that certain neighboring points are given more weight than others, for example if they are closer to the query point [6].</p>
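<p>A distance-weighted k-nearest-neighbor regressor fits in a few lines of plain Python; this sketch uses inverse-distance weights, one common choice (the function and variable names are mine):</p>

```python
import math

def knn_regress(train_X, train_y, query, k=3):
    # Predict y at `query` as the inverse-distance-weighted mean of the
    # k nearest training points (Euclidean / L2 distance)
    nearest = sorted(
        (math.dist(x, query), y) for x, y in zip(train_X, train_y)
    )[:k]
    # If the query coincides with a training point, return its value directly
    if nearest[0][0] == 0.0:
        return nearest[0][1]
    weights = [1.0 / d for d, _ in nearest]
    return sum(w * y for w, (_, y) in zip(weights, nearest)) / sum(weights)

train_X = [(0.0,), (1.0,), (2.0,), (3.0,)]
train_y = [0.0, 1.0, 2.0, 3.0]
pred = knn_regress(train_X, train_y, (1.5,), k=2)  # neighbors are x=1.0 and x=2.0
```

<p>Note that nothing is “fit” here; prediction is just a lookup over the stored training data, which is exactly what makes the model non-parametric.</p>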
<h2 id="gaussian-processes">Gaussian Processes</h2>
<p>We recently encountered Gaussian Processes (GPs) in our discussion of <a href="https://sassafras13.github.io/BayesianOptimization/">Bayesian optimization</a> where I provide a brief overview of the technique. I will also try to write a post dedicated to them soon because they are cool applications of the normal distribution. For the purposes of this post, let’s just say that GPs are a collection of normal distributions that each correspond to a single data point, and the covariance matrix for all of these distributions conveys information about the probability that each function is the true function. The covariance can be computed using various kernel methods.</p>
<p>GPs can be used for regression by fitting a GP model to a training dataset, \(X\). Specifically, the parameters of the covariance function (i.e. the kernel) are optimized to fit the data. We can then make predictions from the model by drawing samples from the joint probability distribution represented by the fitted GP [7].</p>
<p>The advantages to using GPs are [7]:</p>
<ul>
<li>Since we are using a continuous probability distribution, we can easily interpolate between data points.</li>
<li>The predictions generated by a GP are probabilities so we can compute confidence intervals to accompany our predictions.</li>
<li>We can use different kernels to define the covariance of the GP which gives us versatility in this type of model.</li>
</ul>
<p>Some disadvantages include [7]:</p>
<ul>
<li>This kind of model is not sparse, it uses every data point available and every feature to make a prediction.</li>
<li>These models do not work well in high-dimensional spaces.</li>
</ul>
<h2 id="decision-trees">Decision Trees</h2>
<p>Decision trees are a popular form of ML model because they are very easy to interpret. This is because they use a series of yes/no decisions in a tree structure to take features of a data point and use them to predict the target value (i.e. they use features of \(X\) to predict \(y\)). Decision trees can be used to perform either classification or regression - in the case of regression, the leaves of the tree will take on continuous values corresponding to \(y\) [8].</p>
<p>The algorithm for building decision trees uses some measure of the benefit of splitting the data into categories by applying a threshold to individual features. For example, if we want to predict whether or not a sporting event will occur based on the weather (outlook, temperature, humidity and wind), we can examine how splitting the data based on whether the temperature was hot or cold will generate sub-groups of the data, and how much information we gain by doing so. Different forms of the decision tree model will use different metrics, for example, Gini impurity or information entropy, to compute which split (based on outlook, temperature, humidity or wind) will yield the greatest gain in information at each step in the tree [8].</p>
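<p>The two impurity measures mentioned above are easy to compute directly; here is a small sketch for a list of class labels (function names mine):</p>

```python
import math
from collections import Counter

def entropy(labels):
    # Shannon entropy (in bits) of a list of class labels
    n = len(labels)
    return -sum(c / n * math.log2(c / n) for c in Counter(labels).values())

def gini(labels):
    # Gini impurity of a list of class labels
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def information_gain(parent, left, right):
    # Entropy reduction from splitting `parent` into `left` + `right`
    n = len(parent)
    child = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - child

# A perfectly separating split recovers the full parent entropy (1 bit here)
gain = information_gain(["yes", "yes", "no", "no"], ["yes", "yes"], ["no", "no"])
```

<p>At each node, the tree-building algorithm evaluates this gain for every candidate feature/threshold split and greedily keeps the best one.</p>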
<p>Some of the advantages of decision trees include [8,9]:</p>
<ul>
<li>They are very easy to interpret because we can visualize all of the decisions that the model learns to make.</li>
<li>Minimal data preparation is needed to apply this model, although it will not work well if there are missing values in the dataset.</li>
<li>Works well with large datasets.</li>
<li>Decision trees can work with both numeric and categorical data natively.</li>
</ul>
<p>Some disadvantages include [8,9]:</p>
<ul>
<li>They can overfit so techniques like pruning or limiting the depth of the tree may be necessary to regularize the model.</li>
<li>A single decision tree can be unstable because small variations in the values in the dataset could lead to a completely different tree being generated. This can be addressed using ensemble methods as we will see in the next section.</li>
<li>No model is great at extrapolation, but decision trees are particularly bad because they give piecewise-constant estimates of the function \(y = f(X)\).</li>
<li>There is no guarantee that the decision tree found during training will be globally optimal - another good reason to use an ensemble method.</li>
<li>Decision trees do not work well with unbalanced datasets (i.e. where most of the points fall into one category for a particular feature).</li>
</ul>
<h2 id="ensemble-methods-random-forest">Ensemble Methods (Random Forest)</h2>
<p>As we saw in the previous section, a single decision tree can be an unstable, unreliable model for a given dataset. It can help, in this case, to use an ensemble method that joins the predictions of several variations of a model (or models) together and returns some final prediction based on all the individual models’ predictions. I don’t want to cover all the approaches to ensemble learning here, instead I will focus on how we commonly build ensemble models out of decision trees - this ensemble version is called a random forest.</p>
<p>In a random forest, we build multiple decision trees, and we use a couple of techniques to introduce variability in how the trees are formed. We then take an average over all the predictions from all of the trees and use that averaged prediction as the final value for the entire random forest. The variability is introduced in two places: first by drawing a subset of training samples from the full dataset \(X\) (with replacement - this is called <strong>bootstrapping</strong>) and using this subset to train the tree. Second, we can choose to use a random selection of features in the dataset to train a particular tree, thereby generating some variation in how the tree can split branches [10].</p>
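<p>The two sources of randomness described above can be sketched as follows; this is just the sampling machinery, not the full forest-training code (function names mine):</p>

```python
import random

def bootstrap_sample(data, rng):
    # Draw len(data) samples *with replacement* from the training set
    return [rng.choice(data) for _ in range(len(data))]

def feature_subset(features, n_keep, rng):
    # Pick a random subset of features for one tree to consider when splitting
    return rng.sample(features, n_keep)

rng = random.Random(42)
data = list(range(10))
features = ["outlook", "temperature", "humidity", "wind"]

boot = bootstrap_sample(data, rng)       # same size as data, duplicates allowed
subset = feature_subset(features, 2, rng)
```

<p>Each tree in the forest would be trained on its own <code class="language-plaintext highlighter-rouge">boot</code> and <code class="language-plaintext highlighter-rouge">subset</code>, and the forest’s prediction is the average of the trees’ predictions.</p>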
<h2 id="conclusion">Conclusion</h2>
<p>Thank you for reading along with me today. I realize that I did not dig into any one model in great detail; instead, my goal was to remind myself of how each regression model works and where each might be most useful. Thanks for reading!</p>
<h2 id="references">References</h2>
<p>[1] Brownlee, J. “Difference Between Classification and Regression in Machine Learning.” 11 Dec 2017. <a href="https://machinelearningmastery.com/classification-versus-regression-in-machine-learning/">https://machinelearningmastery.com/classification-versus-regression-in-machine-learning/</a> Visited 4 Jun 2022.</p>
<p>[2] “Linear Models.” Scikit-Learn 1.1.1. <a href="https://scikit-learn.org/stable/modules/linear_model.html">https://scikit-learn.org/stable/modules/linear_model.html</a> Visited 4 Jun 2022.</p>
<p>[3] “Support-vector machine.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Support-vector_machine">https://en.wikipedia.org/wiki/Support-vector_machine</a> Visited 4 Jun 2022.</p>
<p>[4] “Kernel method.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Kernel_method#Mathematics:_the_kernel_trick">https://en.wikipedia.org/wiki/Kernel_method#Mathematics:_the_kernel_trick</a> Visited 4 Jun 2022.</p>
<p>[5] “Support Vector Machines.” Scikit-Learn 1.1.1. <a href="https://scikit-learn.org/stable/modules/svm.html">https://scikit-learn.org/stable/modules/svm.html</a> Visited 4 Jun 2022.</p>
<p>[6] “Nearest Neighbors.” Scikit-Learn 1.1.1. <a href="https://scikit-learn.org/stable/modules/neighbors.html">https://scikit-learn.org/stable/modules/neighbors.html</a> Visited 4 Jun 2022.</p>
<p>[7] “Gaussian Processes.” Scikit-Learn 1.1.1. <a href="https://scikit-learn.org/stable/modules/gaussian_process.html">https://scikit-learn.org/stable/modules/gaussian_process.html</a> Visited 4 Jun 2022.</p>
<p>[8] “Decision tree learning.” Wikipedia. <a href="https://en.wikipedia.org/wiki/Decision_tree_learning">https://en.wikipedia.org/wiki/Decision_tree_learning</a> Visited 4 Jun 2022.</p>
<p>[9] “Decision Trees.” Scikit-Learn 1.1.1. <a href="https://scikit-learn.org/stable/modules/tree.html">https://scikit-learn.org/stable/modules/tree.html</a> Visited 4 Jun 2022.</p>
<p>[10] “Ensemble methods.” Scikit-Learn 1.1.1. <a href="https://scikit-learn.org/stable/modules/ensemble.html#forests-of-randomized-trees">https://scikit-learn.org/stable/modules/ensemble.html#forests-of-randomized-trees</a> Visited 4 Jun 2022.</p>
<h1 id="more-cooperativity-math">More Cooperativity Math!</h1>
<p>2022-05-30. <a href="http://sassafras13.github.io/CooperativityMath">http://sassafras13.github.io/CooperativityMath</a></p>
<p>We have already examined cooperativity and several models that capture this phenomenon in a <a href="https://sassafras13.github.io/Cooperativity/">previous post</a>, but today I would like to dive into the mathematics of these models in a little more detail. I think that the way these models are developed is a useful inspiration for representing multivalent systems. In this post, we’ll first introduce another handy probability distribution, the Gibbs distribution, and then we’ll use it to derive a simple model of ligand binding with no cooperativity, and compare this to the MWC, Pauling and Adair models that incorporate cooperativity. Let’s get to it!</p>
<h2 id="the-gibbs-distribution">The Gibbs Distribution</h2>
<p>Previously, we derived an expression for the <a href="https://sassafras13.github.io/BoltzmannDistribution/">Boltzmann distribution</a>, which conveyed the probability of observing a system in a particular state. The Boltzmann distribution can be written as [1]:</p>
\[p(E_i) = \frac{1}{Z} e^{\frac{-E_i}{k_B T}}\]
<p>The Gibbs distribution is different from the Boltzmann distribution because it includes information about how the system is coupled to a larger reservoir that contains particles and energy, whereas the Boltzmann distribution only assumed that the system was in contact with an energy reservoir. Adding a particle reservoir to the paradigm allows us to study ligand binding more realistically because it represents how the number of ligands available in the environment (both the system and the reservoir) affects the reaction [1].</p>
<p>Let me just state that the Gibbs distribution can be written as follows, so you can see the differences between this equation and the Boltzmann distribution [1]:</p>
\[p(E_s^{(i)}, N_s^{(i)}) = \frac{e^{-\beta(E_s^{(i)} - \mu N_s^{(i)})}}{\mathcal{Z}}\]
<p>We will define everything in this expression in a moment, but for now let’s just observe that this expression includes additional information about the number of particles \(N_s\) and their chemical potential, \(\mu\).</p>
<p>Now that we’ve seen how the Gibbs distribution is different from the Boltzmann distribution, let’s step through the derivation. Imagine that we are modeling a system as an open container of air in a warm room. The container is our system, and the room is our reservoir. The total number of particles in this environment is the sum of the particles in the system (s) and the reservoir (r), i.e. \(N_{tot} = N_s + N_r\). Similarly, the total energy is \(E_{tot} = E_s + E_r\) [1].</p>
<p>The key idea here is that we want to compute the probability of finding a given state of the system, which is expressed in terms of energy and number of particles. This probability is proportional to the number of states that the reservoir can take given that the system state is fixed [1]:</p>
\[p(E_s^{(1)}, N_s^{(1)}) \propto W_r(E_{tot} - E_s^{(1)}, N_{tot} - N_s^{(1)})\]
<p>This might seem a little odd, but remember that if we’ve fixed the state of the system, then the only thing that can vary is the state of the reservoir. I can use the definition of entropy, \(S = k_B \ln W\) or \(W = e^{S/k_B}\), to rewrite this expression as [1]:</p>
\[W_r(E_{tot} - E_s^{(1)}, N_{tot} - N_s^{(1)}) \propto e^{S_r(E_{tot} - E_s^{(1)}, N_{tot} - N_s^{(1)})/k_B}\]
<p>I can use the Taylor expansion to rewrite this entropy as [1]:</p>
\[S_r(E_{tot} - E_s, N_{tot} - N_s) \approx S_r(E_{tot}, N_{tot}) - \frac{\partial S_r}{\partial E} E_s - \frac{\partial S_r}{\partial N}N_s\]
<p>We can rewrite the derivatives in terms of thermodynamic identities, i.e. \((\partial S / \partial E)_{V,N} = 1/T\) and \((\partial S / \partial N)_{E,V} = - \mu/T\), then we can recast the probability of a given system state as [1]:</p>
\[p(E_s^{(1)}, N_s^{(1)}) \propto e^{-(E_s^{(1)} - \mu N_s^{(1)}) / k_B T}\]
<p>From here we can write the equality that is the Gibbs distribution by including the <strong>grand partition function</strong>, \(\mathcal{Z}\) [1]:</p>
\[p(E_s^{(i)}, N_s^{(i)}) = \frac{e^{-\beta(E_s^{(i)} - \mu N_s^{(i)})}}{\mathcal{Z}}\]
<p>Where \(\beta = \frac{1}{k_B T}\) and the grand partition function is the sum over all the states [1]:</p>
\[\mathcal{Z} = \sum_i e^{- \beta(E_s^{(i)} - N_s^{(i)} \mu)}\]
<p>As we saw <a href="https://sassafras13.github.io/BoltzmannDistribution/">previously</a>, it is possible to compute the average quantity (in this case the average number of particles) using a number of clever mathematical tricks and the grand partition function [1]:</p>
\[\bar{N} = \frac{1}{\beta} \frac{\partial}{\partial \mu} \ln \mathcal{Z}\]
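<p>As a quick sanity check on this identity, here is a small numerical sketch that compares the direct weighted average of \(N\) against the derivative of \(\ln \mathcal{Z}\). The list of (energy, particle number) states and all parameter values are made up purely for illustration:</p>

```python
import math

# Toy grand-canonical system: a handful of made-up (E_s, N_s) states,
# just to check the identity N_bar = (1/beta) * d(ln Z)/d(mu).
beta = 1.0
states = [(0.0, 0), (-1.0, 1), (-1.5, 1), (-2.2, 2)]  # arbitrary values

def grand_Z(mu):
    # Grand partition function: sum of Gibbs weights over all states
    return sum(math.exp(-beta * (E - mu * N)) for E, N in states)

def N_avg_direct(mu):
    # Average particle number computed directly from the probabilities
    Z = grand_Z(mu)
    return sum(N * math.exp(-beta * (E - mu * N)) for E, N in states) / Z

def N_avg_derivative(mu, h=1e-6):
    # (1/beta) * d(ln Z)/d(mu), via central finite difference
    return (math.log(grand_Z(mu + h)) - math.log(grand_Z(mu - h))) / (2 * beta * h)

mu = -0.5
assert abs(N_avg_direct(mu) - N_avg_derivative(mu)) < 1e-6
```

<p>The two routes to \(\bar{N}\) agree, which is the "clever mathematical trick" in action: differentiating \(\ln \mathcal{Z}\) with respect to \(\mu\) pulls down a factor of \(\beta N\) from each Gibbs weight.</p>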
<p>Now that we have the Gibbs distribution in hand, let’s use it to solve a simple problem of one ligand binding to one receptor, before moving on to more complex models that take cooperativity into account.</p>
<h2 id="simple-ligand-binding-problem-with-gibbs-distribution">Simple Ligand Binding Problem with Gibbs Distribution</h2>
<p>First, let’s introduce an important parameter that we’ll use repeatedly throughout this post: \(\sigma\). We’ll use \(\sigma\) as a binary variable indicating whether a binding event has occurred (\(\sigma = 1\)) or not (\(\sigma = 0\)). We’ll also define the binding energy as \(\epsilon_b < 0\) (negative to indicate that energy is released during the reaction) [1].</p>
<p>In this case, we can write the grand partition function as the sum of the two states of the system, the ligand bound to the receptor, and unbound [1]:</p>
\[\mathcal{Z} = \sum_{\sigma = 0}^1 e^{-\beta (\epsilon_b \sigma - \mu \sigma)} = 1 + e^{-\beta(\epsilon_b - \mu)}\]
<p>The average number of ligands bound can be solved for using the fact given above, \(\bar{N} = \frac{1}{\beta} \frac{\partial}{\partial \mu} \ln \mathcal{Z}\), or simple probability theory [1]:</p>
\[\bar{N} = \frac{ e^{- \beta(\epsilon_b - \mu)}}{1 + e^{- \beta(\epsilon_b - \mu)}}\]
<p>I can also replace \(\mu\) with the full expression for chemical potential, \(\mu = \mu_0 + k_B T \ln(c/c_0)\) and obtain [1]:</p>
\[\bar{N} = \frac{(c/c_0) e^{-\beta \Delta \epsilon}}{1 + (c/c_0) e^{-\beta \Delta \epsilon}}\]
<p>Where \(\Delta \epsilon = \epsilon_b - \mu_0\), which represents the energy difference between the ligand in solution and the bound ligand [1]. Now that we have worked through this simple example, we have all the ideas we need to develop some more detailed models, starting with a simplified dimeric binding model that takes cooperativity into account.</p>
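<p>Before moving on, here is a short numerical sketch of this single-site binding curve, \(\bar{N} = x/(1+x)\) with \(x = (c/c_0) e^{-\beta \Delta \epsilon}\). The value of \(\Delta \epsilon\) is an arbitrary choice (in units of \(k_B T\)):</p>

```python
import math

# Single-site occupancy N_bar = x / (1 + x), x = (c/c0) * exp(-beta * d_eps).
beta = 1.0
d_eps = -2.0  # assumed energy difference (bound vs. in solution), in kT units

def occupancy(c_over_c0):
    x = c_over_c0 * math.exp(-beta * d_eps)
    return x / (1 + x)

# Occupancy climbs from near 0 toward saturation at 1 as concentration grows
for c in (0.01, 0.1, 1.0, 10.0):
    print(f"c/c0 = {c:5.2f} -> N_bar = {occupancy(c):.3f}")
```

<p>This is the familiar Langmuir-style binding curve: no cooperativity yet, so the rise with concentration is gradual rather than sigmoidal.</p>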
<h2 id="dimeric-binding-model-including-cooperativity">Dimeric Binding Model Including Cooperativity</h2>
<p>Let’s assume that we have a receptor with two binding sites that can each, separately, bind or unbind with a ligand. The energy of this system is [1]:</p>
\[E = \epsilon(\sigma_1 + \sigma_2) + J\sigma_1 \sigma_2\]
<p>Where \(\epsilon\) is the energy of one binding event, \(\sigma_i\) describes the binding state of each of the two sites, and \(J\) is a measure of the cooperativity and indicates that when both sites are bound to ligands, the energy released is <em>different</em> from the sum of the energy of the two individual binding sites [1].</p>
<p>We can compute the grand partition function from summing over the four possible states of the system (both unoccupied, one of the two sites occupied, both sites occupied) as follows [1]:</p>
\[\mathcal{Z} = 1 + e^{-\beta(\epsilon - \mu)} + e^{-\beta(\epsilon - \mu)} + e^{-\beta(2\epsilon + J - 2\mu)}\]
<p>And as before, we can write the average number of binding sites occupied (i.e. average occupancy) as [1]:</p>
\[\bar{N} = \frac{2 e^{-\beta(\epsilon - \mu)} + 2 e^{-\beta(2 \epsilon + J - 2\mu)}}{1 + e^{-\beta(\epsilon - \mu)} + e^{-\beta(\epsilon - \mu)} + e^{-\beta(2\epsilon + J - 2\mu)}}\]
<p>This relatively simple example captures the cooperativity in the extra term \(J \sigma_1 \sigma_2\), and shows sigmoidal behavior as the partial pressure of the ligand increases. But there are, as we’ve seen, more complex representations of this cooperativity. One such example is the Monod-Wyman-Changeux (MWC) model that we have seen before, and which we will revisit now with our awesome mathematical tools.</p>
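<p>As a sanity check, the dimeric occupancy above can be recovered by brute-force enumeration of the four \((\sigma_1, \sigma_2)\) states. All parameter values here are illustrative:</p>

```python
import math
from itertools import product

# Dimeric model: E = eps*(s1 + s2) + J*s1*s2, with Gibbs weight
# exp(-beta * (E - mu * N)). Parameter values are made up for illustration.
beta, eps, J, mu = 1.0, -1.0, -0.5, -0.2

def weight(s1, s2):
    E = eps * (s1 + s2) + J * s1 * s2
    N = s1 + s2
    return math.exp(-beta * (E - mu * N))

states = list(product((0, 1), repeat=2))  # the four (s1, s2) combinations
Z = sum(weight(s1, s2) for s1, s2 in states)
N_bar = sum((s1 + s2) * weight(s1, s2) for s1, s2 in states) / Z

# Compare against the closed form from the text
closed = (2 * math.exp(-beta * (eps - mu)) + 2 * math.exp(-beta * (2*eps + J - 2*mu))) / \
         (1 + 2 * math.exp(-beta * (eps - mu)) + math.exp(-beta * (2*eps + J - 2*mu)))
assert abs(N_bar - closed) < 1e-12
```

<p>Enumerating states like this scales badly (the state count doubles per site), which is exactly why the closed-form expressions in this post are worth deriving.</p>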
<h2 id="the-mwc-model-of-cooperativity-with-awesome-math">The MWC Model of Cooperativity with Awesome Math</h2>
<p>As we discussed <a href="https://sassafras13.github.io/Cooperativity/">previously</a>, the MWC model introduces the idea that the receptor can be in either a tense (T) or relaxed (R) state, where the T state is favored over the R state when no ligands are bound to the receptor. There is a cost, \(\epsilon\), associated with switching to the R state from the T state. Similarly, we have binding energies associated with each state, \(\epsilon_T\) and \(\epsilon_R\). Each state also has \(\sigma_i\) to represent whether each of the two binding sites is occupied, and another variable \(\sigma_m\) to represent whether the receptor is in the T (\(\sigma_m = 0\)) or R (\(\sigma_m = 1\)) state. Given all of this, the energy of the system can be written as [1]:</p>
\[E = (1 - \sigma_m) \epsilon_T \sum_{i=1}^2 \sigma_i + \sigma_m \left( \epsilon + \epsilon_R \sum_{i=1}^2 \sigma_i \right)\]
<p>To compute the grand partition function, we need to sum up over 8 unique cases (4 for each state, T or R) [1]:</p>
\[\mathcal{Z} = 1 + 2 e^{-\beta(\epsilon_T - \mu)} + e^{-\beta(2 \epsilon_T - 2 \mu)} + e^{-\beta \epsilon} ( 1 + 2 e^{-\beta(\epsilon_R - \mu)} + e^{-\beta(2 \epsilon_R - 2\mu)})\]
<p>From this, the average occupancy is [1]:</p>
\[\bar{N} = \frac{2}{\mathcal{Z}} [x + x^2 + e^{-\beta \epsilon} ( y + y^2)]\]
<p>Where \(x = (c/c_0)e^{-\beta(\epsilon_T - \mu_0)}\) and \(y = (c/c_0) e^{-\beta(\epsilon_R - \mu_0)}\). As we discussed before, although this model does not explicitly capture cooperativity, the emergent behavior of the system transitioning between T and R states will result in cooperative-like behavior. In the next section, we consider a model for a receptor with four binding sites but no cooperativity, and then add the cooperativity to build the Pauling model.</p>
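<p>The MWC result can also be checked by summing over all 8 states directly. In the sketch below I absorb the concentration dependence into \(\mu\), so that \(x = e^{-\beta(\epsilon_T - \mu)}\) and \(y = e^{-\beta(\epsilon_R - \mu)}\); the parameter values are made up for illustration:</p>

```python
import math
from itertools import product

# Two-site MWC model: sigma_m selects T (0) or R (1), and
# E = (1 - sm)*eps_T*(s1 + s2) + sm*(eps + eps_R*(s1 + s2)).
# Parameter values are illustrative.
beta, eps_T, eps_R, eps, mu = 1.0, -0.5, -1.5, 1.0, -0.3

def weight(sm, s1, s2):
    n = s1 + s2
    E = (1 - sm) * eps_T * n + sm * (eps + eps_R * n)
    return math.exp(-beta * (E - mu * n))

states = list(product((0, 1), repeat=3))  # (sm, s1, s2): 8 states
Z = sum(weight(*s) for s in states)
N_bar = sum((s1 + s2) * weight(sm, s1, s2) for sm, s1, s2 in states) / Z

# Closed form from the text, with the concentration folded into mu
x = math.exp(-beta * (eps_T - mu))
y = math.exp(-beta * (eps_R - mu))
closed = 2 * (x + x**2 + math.exp(-beta * eps) * (y + y**2)) / Z
assert abs(N_bar - closed) < 1e-12
```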
<h2 id="four-binding-sites-without-cooperativity-then-the-pauling-model">Four Binding Sites Without Cooperativity, then the Pauling Model</h2>
<p>In this section we establish the mathematics for the occupancy of a receptor with 4 binding sites. First, we can write the energy as [1]:</p>
\[E = \epsilon \sum_{\alpha = 1}^4 \sigma_{\alpha}\]
<p>And since there is no cooperativity, each binding event is completely independent of the others, and we can sum up all 16 possible states as [1]:</p>
\[\mathcal{Z} = \sum_{\sigma_1 = 0}^1 \sum_{\sigma_2 = 0}^1 \sum_{\sigma_3 = 0}^1 \sum_{\sigma_4 = 0}^1 e^{-\beta(\epsilon - \mu) \sum_{\alpha=1}^4 \sigma_{\alpha}}\]
<p>These terms can be separated out and then combined (since they’re all based on the same binding energy) into the simpler expression [1]:</p>
\[\mathcal{Z} = (1 + e^{-\beta(\epsilon - \mu)})^4\]
<p>And the occupancy is [1]:</p>
\[\bar{N} = \frac{4 e^{-\beta(\epsilon - \mu)}}{1 + e^{-\beta(\epsilon - \mu)}}\]
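<p>The factorization step above is easy to verify numerically: summing all 16 states really does collapse to \((1+z)^4\), where \(z = e^{-\beta(\epsilon - \mu)}\). Parameter values below are arbitrary:</p>

```python
import math
from itertools import product

# Four independent binding sites: each state's Gibbs weight is z^(number
# of occupied sites), with z = exp(-beta * (eps - mu)). Values are made up.
beta, eps, mu = 1.0, -1.0, -0.4
z = math.exp(-beta * (eps - mu))

weights = {s: z ** sum(s) for s in product((0, 1), repeat=4)}  # 16 states
Z = sum(weights.values())
N_bar = sum(sum(s) * w for s, w in weights.items()) / Z

# Independence means the sum factorizes, and N_bar is 4x the single-site result
assert abs(Z - (1 + z) ** 4) < 1e-12
assert abs(N_bar - 4 * z / (1 + z)) < 1e-12
```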
<p>Now let’s modify this model to include cooperativity, thereby developing the Pauling model. In this case, we are going to add a term that adds some contribution \(J\) if two binding sites are both occupied. Specifically [1]:</p>
\[E = \epsilon \sum_{\alpha = 1}^4 \sigma_{\alpha} + \frac{J}{2} \sum_{\alpha, \gamma}' \sigma_{\alpha} \sigma_{\gamma}\]
<p>Here, \(\alpha\) and \(\gamma\) are indices of two binding sites, and we add energy \(J\) if they are both occupied. Notice that the \(\sum'\) indicates that we should not include terms where \(\alpha = \gamma\), and we divide \(J\) in half because each pair is counted twice (once as \(\sigma_1 \sigma_2\) and once as \(\sigma_2 \sigma_1\)) [1].</p>
<p>Again, this energy can be used to compute the grand partition function [1]:</p>
\[\mathcal{Z} = \sum_{\sigma_1 = 0}^1 \sum_{\sigma_2 = 0}^1 \sum_{\sigma_3 = 0}^1 \sum_{\sigma_4 = 0}^1 e^{-\beta(\epsilon - \mu) \sum_{\alpha = 1}^4 \sigma_{\alpha} - \beta(J/2) \sum_{\alpha, \gamma}' \sigma_{\alpha} \sigma_{\gamma}}\]
<p>And this can be evaluated as a sum of 5 terms, representing 0 through 4 occupied binding sites (16 states in total) [1]:</p>
\[\mathcal{Z} = 1 + 4 e^{-\beta (\epsilon - \mu)} + 6 e^{-2\beta (\epsilon - \mu)-\beta J} + 4 e^{-3\beta (\epsilon - \mu)-3\beta J} + e^{-4\beta (\epsilon - \mu)-6\beta J}\]
<p>And then the occupancy becomes [1]:</p>
\[\bar{N} = \frac{4x + 12 x^2 j + 12x^3 j^3 + 4x^4 j^6}{1 + 4x + 6x^2 j + 4x^3 j^3 + x^4 j^6}\]
<p>Where \(x = (c/c_0) e^{-\beta(\epsilon - \mu_0)}\) and \(j = e^{-\beta J}\) [1]. This model requires fitting only two free parameters, but the cost is that it assumes all interactions are paired, and does not take into account interactions among 3 or more binding sites. The final model we consider, the Adair model, will include these as well.</p>
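<p>As with the dimeric case, the Pauling result can be checked by enumerating all 16 states and counting occupied pairs. In the sketch below the concentration dependence is absorbed into \(\mu\), so \(x = e^{-\beta(\epsilon - \mu)}\); parameter values are illustrative:</p>

```python
import math
from itertools import combinations, product

# Pauling model: E = eps * n + J * (number of occupied pairs). Note that
# J * pairs equals (J/2) * sum' over ordered pairs, matching the text.
# Parameter values are made up for illustration.
beta, eps, J, mu = 1.0, -1.0, -0.6, -0.3

def weight(s):
    n = sum(s)
    pairs = sum(s[a] * s[b] for a, b in combinations(range(4), 2))
    E = eps * n + J * pairs
    return math.exp(-beta * (E - mu * n))

states = list(product((0, 1), repeat=4))  # 16 states
Z = sum(weight(s) for s in states)
N_bar = sum(sum(s) * weight(s) for s in states) / Z

# Closed form from the text, with x = exp(-beta*(eps - mu)), j = exp(-beta*J)
x, j = math.exp(-beta * (eps - mu)), math.exp(-beta * J)
closed = (4*x + 12*x**2*j + 12*x**3*j**3 + 4*x**4*j**6) / \
         (1 + 4*x + 6*x**2*j + 4*x**3*j**3 + x**4*j**6)
assert abs(N_bar - closed) < 1e-12
```

<p>The powers of \(j\) in the closed form are just the pair counts \(\binom{n}{2}\): one pair for two occupied sites, three for three, six for four.</p>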
<h2 id="the-adair-model-for-cooperativity">The Adair Model for Cooperativity</h2>
<p>The Adair model uses the same overarching framework as the Pauling model, but includes interactions among 3 binding sites (captured by parameter \(K\)) and among all 4 sites (\(L\)). The math will get a little more involved to account for them, but the patterns are the same. Let’s start with the energy for the model [1]:</p>
\[E = \epsilon \sum_{\alpha = 1}^4 \sigma_{\alpha} + \frac{J}{2} \sum_{\alpha, \gamma}' \sigma_{\alpha} \sigma_{\gamma} + \frac{K}{3!} \sum_{\alpha, \beta, \gamma}' \sigma_{\alpha} \sigma_{\beta} \sigma_{\gamma} + \frac{L}{4!} \sum_{\alpha, \beta, \gamma, \delta}' \sigma_{\alpha} \sigma_{\beta} \sigma_{\gamma} \sigma_{\delta}\]
<p>The grand partition function this time is a little intimidating but we can simplify it, as shown below in two steps [1]:</p>
\[\mathcal{Z} = \sum_{\sigma_1 = 0}^1 \sum_{\sigma_2 = 0}^1 \sum_{\sigma_3 = 0}^1 \sum_{\sigma_4 = 0}^1 \exp \left[-\beta(\epsilon - \mu) \sum_{\alpha = 1}^4 \sigma_{\alpha} - \beta \frac{J}{2} \sum_{\alpha, \gamma}' \sigma_{\alpha} \sigma_{\gamma} - \beta \frac{K}{3!} \sum_{\alpha, \beta, \gamma}' \sigma_{\alpha} \sigma_{\beta} \sigma_{\gamma} - \beta \frac{L}{4!} \sum_{\alpha, \beta, \gamma, \delta}' \sigma_{\alpha} \sigma_{\beta} \sigma_{\gamma} \sigma_{\delta} \right]\]
\[\mathcal{Z} = 1 + 4e^{-\beta(\epsilon - \mu)} + 6e^{-2\beta(\epsilon - \mu) - \beta J} + 4 e^{-3\beta(\epsilon - \mu) - 3 \beta J - \beta K} + e^{-4 \beta(\epsilon - \mu) - 6 \beta J - 4 \beta K - \beta L}\]
<p>And finally, the occupancy is [1]:</p>
\[\bar{N} = \frac{4x + 12x^2 j + 12 x^3 j^3 k + 4 x^4 j^6 k^4 l}{1 + 4x + 6x^2j + 4x^3 j^3 k + x^4 j^6 k^4 l}\]
<p>Where \(k = e^{-\beta K}\) and \(l = e^{-\beta L}\).</p>
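<p>The Adair occupancy can be verified the same way: enumerate the 16 states, counting occupied pairs, triples, and the single quadruple. Again the concentration is folded into \(\mu\) and the parameter values are illustrative:</p>

```python
import math
from itertools import combinations, product

# Adair model: pairs contribute J, triples contribute K, and the fully
# occupied quadruple contributes L. Parameter values are made up.
beta, eps, J, K, L, mu = 1.0, -1.0, -0.4, -0.2, -0.1, -0.3

def weight(s):
    n = sum(s)
    pairs = sum(1 for c in combinations(range(4), 2) if all(s[i] for i in c))
    triples = sum(1 for c in combinations(range(4), 3) if all(s[i] for i in c))
    quads = 1 if n == 4 else 0
    E = eps * n + J * pairs + K * triples + L * quads
    return math.exp(-beta * (E - mu * n))

states = list(product((0, 1), repeat=4))  # 16 states
Z = sum(weight(s) for s in states)
N_bar = sum(sum(s) * weight(s) for s in states) / Z

# Closed form from the text: j, k, l are the Boltzmann factors of J, K, L
x, j = math.exp(-beta * (eps - mu)), math.exp(-beta * J)
k, l = math.exp(-beta * K), math.exp(-beta * L)
closed = (4*x + 12*x**2*j + 12*x**3*j**3*k + 4*x**4*j**6*k**4*l) / \
         (1 + 4*x + 6*x**2*j + 4*x**3*j**3*k + x**4*j**6*k**4*l)
assert abs(N_bar - closed) < 1e-12
```

<p>Note how the exponents track the combinatorics: a fully occupied receptor carries \(\binom{4}{2} = 6\) pairs, \(\binom{4}{3} = 4\) triples, and 1 quadruple, giving the \(j^6 k^4 l\) factor.</p>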
<h2 id="conclusion">Conclusion</h2>
<p>That wraps up our detailed, math-intensive deep dive into models of cooperativity. I think my most useful takeaway is seeing how we can incorporate cooperativity through extra terms in our energy expressions. I think this method also creates interesting opportunities for simplifying larger systems, where there are enough binding sites to make this manual approach a little too cumbersome. Thanks for reading!</p>
<h2 id="references">References</h2>
<p>[1] Phillips, R., Kondev, J., Theriot, J., Garcia, H. Physical Biology of the Cell, 2nd Edition. Garland Science, 2013.</p>