On and off over the last couple of months, I have been working on implementing the Adaptive Computation Time (ACT) work by Alex Graves (2016). I think it is a fairly simple but novel concept that has since seen several further developments (1, 2). In this post, I will detail how ACT works, its motivation, and my thoughts on it.

Regular RNNs

Graves' ACTs were first applied to RNNs, so that is a good place to start.

The above graphic should be familiar to most readers. RNNs can be formulated in a number of ways, from one-to-one to many-to-many, but they all share a similar underlying organization. Generally, the RNN consumes one input token/vector and the previous hidden state at each step, and outputs a hidden state/output vector. The generated hidden/output vector can be directly tied to the target of the problem, or simply consumed by another neural network, possibly another RNN. As seen above, not all hidden/output vectors are used in all cases.

Within the above formulation, one key underlying assumption is that the RNN consumes one input vector at each step. Put another way, the RNN maps the input and previous hidden state to the output vector in a single step:

$$ h_t = RNN\left(h_{t-1}, x_t\right) $$

where $h_t$ are the sequence of hidden states, and $x_t$ the input. The $RNN$ can be a vanilla RNN or an LSTM/GRU.
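
To make the recurrence concrete, here is a minimal sketch of a single vanilla RNN step in PyTorch (the class and layer names are my own, and the tanh activation is just the standard choice):

```python
import torch
import torch.nn as nn

class VanillaRNNCell(nn.Module):
    # One step of a vanilla RNN: h_t = tanh(W_x x_t + W_h h_{t-1} + b)
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_map = nn.Linear(input_size, hidden_size)
        self.hidden_map = nn.Linear(hidden_size, hidden_size)

    def forward(self, x_t, h_prev):
        return torch.tanh(self.input_map(x_t) + self.hidden_map(h_prev))
```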

This formulation is actually fairly constrained - it means that all the information from $x_t$ must be consumed and processed within a single call of the $RNN$ function, particularly if the generated $h_t$ are directly consumed (e.g. mapped directly to a target output). Remember that vanilla RNNs are essentially single-layer feed-forward networks with a non-linear activation. LSTMs/GRUs have additional machinery to gate/propagate past hidden states, but the core computation is still performed in a single step.

Of course, with a sufficiently large RNN, any function could be roughly approximated, but part of the strength of deep neural networks is that deeper networks can be more efficient than wider ones - which is why we have architectures like ResNet, rather than a 65536-dimensional, single-layer CNN. Hence the question is: how can we make the RNN computation artificially deeper at each step?

There are several ways of mitigating this. One approach is to stack several RNN layers on top of each other - this allows more steps of computation in the interim, as well as more interaction across hidden states. This expands the number of computations per input vector, but the number of computations is ultimately still deterministically tied to the length of the input sequence.
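
In PyTorch, for instance, stacking is just a constructor argument; a quick sketch (the sizes here are arbitrary):

```python
import torch.nn as nn

# Three stacked LSTM layers: each input step now triggers exactly three
# layer computations - deeper, but still a fixed amount per input.
stacked_lstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=3)
```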

Moving away from RNNs, networks like ResNets approximate a solution. Architectures like the 1001-layer ResNet do not necessarily expect to use all 1001 layers - rather, the layers are initialized close to the identity (i.e. little computation is performed), and the 1001 layers provide the capacity for additional learned computation if necessary.

Adaptive Computation Time

ACT provides an explicit mechanism for adaptively adjusting the number of RNN computations at each input step, addressing the above constraint. The formulation is a little involved, but the idea is actually fairly straightforward, so stick with me.

Intermediate Computation Steps

First, we need to introduce how we will allow for multiple computation steps at the same input step. Where before we used $h_t$ to denote the hidden state at time $t$, we now introduce $h_t^n$, the intermediate hidden state for input-step $t$ and intermediate step $n$, where $n$ goes from $1\cdots N(t)$, and $N(t)$ is the number of intermediate computation steps for that input step. Formally, the $h_t^n$ are computed as follows:

\begin{equation} h_t^n=\begin{cases} RNN(h_{t-1}, x_t^1) & \text{if $n=1$}.\\ RNN(h_t^{n-1}, x_t^n) & \text{otherwise}. \end{cases} \end{equation}

where $x_t^n$ is $x_t$ augmented with a binary flag to indicate whether this is the initial intermediate computation step, or a subsequent intermediate computation step with the same input. You can follow along the formulation to observe that the $h_t^n$ are created in sequence $h_t^1 \cdots h_t^{N(t)}$, where each intermediate state depends on the last, and the first $h_t^1$ depends on $h_{t-1}$, the final hidden state from the previous input step. Two things to note here:

  • We have not yet specified how $N(t)$ or $h_t$ are derived. We will do so soon.
  • Strictly speaking, we may not need the $x_t^n$. One could imagine that we could simply provide the $x_t$ on the first intermediate computation step $n=1$, and subsequently provide some filler value (e.g. all zeros). We would still need the binary flag to indicate if this is a "repeated" input value. It is unclear whether this approach is better than the above, but the presumption is that by providing a fresh copy of the input each time, further computation steps can work directly off the input $x_t$ rather than having to carry all that information within the hidden states.

(Separately, I mentioned above that $x_t^n$ is $x_t$ augmented with a binary flag, but I actually had good results augmenting it with $n$ instead.)
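
As a sketch of what this augmentation might look like in PyTorch (the function name and shapes are my own, illustrative choices):

```python
import torch

def augment_input(x_t, n):
    # x_t: (batch, input_size). Append a flag that is 1 on the first
    # intermediate step and 0 afterwards. (Appending n itself instead,
    # as mentioned above, also worked well for me.)
    flag_value = 1.0 if n == 1 else 0.0
    flag = torch.full((x_t.shape[0], 1), flag_value)
    return torch.cat([x_t, flag], dim=1)
```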

Determining the Number of Intermediate Computation Steps

Next, we need to determine $N(t)$, the number of intermediate computation steps for each input step. The key idea is to introduce a mechanism that computes a halting probability, which is used to determine when the RNN is done processing the current input and can proceed to the next. We introduce and apply the following sigmoidal computation at each step of the RNN computation:

$$ \text{halt}_t^n = \sigma\left(W_h h_t^n + b_h\right) $$

This is a simple linear function of the current hidden state, mapped to a probability via a sigmoid. While we now have a sequence of $\text{halt}_t^n$, we need to convert this into a probability distribution to make it a proper halting probability. We will take the following approach:

  1. Generate $\text{halt}_t^n$ at each intermediate step $n$
  2. When we reach $M$ such that the cumulative halting probability $\sum_{n=1}^M\text{halt}_t^n$ exceeds 1, we have something close to a probability distribution. We set $N(t) = M$. (In actuality, we will use $1-\epsilon$ instead of 1. See below for details.)
  3. We define the halting probabilities $p_t^n$ as the corresponding $\text{halt}_t^n$, except for the last halting probability $\text{halt}_t^{N(t)}$, which we set to $p_t^{N(t)} = R(t)=1-\sum_{n=1}^{N(t)-1} p_t^n$, in other words the remainder from all the other halting probabilities. This ensures that the $p_t^n$ sum to 1, so we have a valid probability distribution.

This is summarized in the following:

\begin{equation} p_t^n =\begin{cases} R(t) & \text{if $n=N(t)$}.\\ \text{halt}_t^n & \text{otherwise}.\\ \end{cases} \end{equation}

$$ R(t) = 1-\sum_{n=1}^{N(t)-1}\text{halt}_t^n $$

$$ N(t) = \min\left\{m: \sum_{n=1}^m \text{halt}_t^n \geq 1-\epsilon \right\} $$

A couple more notes:

  • The reason we use $1-\epsilon$ rather than $1$ as the halting threshold is that the sigmoid function only reaches $1$ in the limit. In other words, if we set the threshold to $1$, it would be impossible for computation to halt in a single step, and thus impossible for the ACT to reduce to an RNN. Since our goal is to have the ACT be a generalization of the RNN, we allow a margin of $\epsilon$ so that the ACT can halt within 1 step if needed.
  • In this case, the author chose to truncate the $\text{halt}_t^{N(t)}$ to form $R(t)$ to make the probabilities sum to 1. We could also conceive of an alternative where, instead of truncating, we renormalize the probabilities to sum to 1. This makes more intuitive sense to me, but I have not experimented with this formulation.

Completing the ACT

Finally, we need to compute the $h_t$, which is used in the first intermediate computation step of the next input step. Intuitively, since we mentioned "halting probabilities", we ought to be sampling from the $h_t^n$ via $p_t^n$. However, sampling introduces various forms of noise and error into the process, and so the author opted to instead compute $h_t$ as a weighted mean of the $h_t^n$:

$$ h_t = \sum_{n=1}^{N(t)} h_t^n p_t^n $$

While this may seem like a simple shortcut, this actually introduces a neat analogy between the ACT and attention mechanisms, which are similarly weighted sums! I will discuss this below, but also note that there has been later work to make the ACT more probabilistic (Figurnov et al, 2017).
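
Putting the halting computation and the weighted mean together, here is a simplified, single-example sketch of one ACT input step. The names are placeholders: `rnn_cell` is assumed to accept the flag-augmented input (e.g. an `nn.RNNCell` built with `input_size + 1` inputs), and `halt_layer` is an `nn.Linear(hidden_size, 1)`. My actual implementation differs in how it handles batching and the hard cap:

```python
import torch

def act_step(rnn_cell, halt_layer, x_t, h_prev, eps=0.01, max_steps=100):
    # One ACT input step for a single example: run intermediate RNN steps
    # until the cumulative halting probability reaches 1 - eps, then return
    # the weighted mean of the intermediate hidden states.
    states, probs = [], []
    cumulative = torch.zeros(())
    h = h_prev
    for n in range(1, max_steps + 1):
        flag = torch.tensor([[1.0 if n == 1 else 0.0]])
        h = rnn_cell(torch.cat([x_t, flag], dim=1), h)   # h_t^n
        halt = torch.sigmoid(halt_layer(h)).reshape(())  # halt_t^n
        states.append(h)
        if cumulative + halt >= 1 - eps or n == max_steps:
            remainder = 1 - cumulative                   # p_t^{N(t)} = R(t)
            probs.append(remainder)
            break
        probs.append(halt)
        cumulative = cumulative + halt
    # h_t = sum_n p_t^n * h_t^n (a weighted mean rather than sampling)
    h_t = sum(p * s for p, s in zip(probs, states))
    return h_t, remainder  # the remainder feeds into the ponder cost
```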

Limiting and Penalizing Computation Time

We would like to limit the additional intermediate computation time used by the ACT for both practical (reducing training time) and theoretical (limiting model complexity) reasons. This is done via 2 measures:

  1. Adding a penalty term to the loss for additional computation time.
  2. Setting a hard cap on $N(t)$ at some maximum $N$. This is a more blunt measure, and most practically useful near the start of training to bound the number of intermediate steps while the model is still ill-formed.

The latter is fairly straightforward, but the former warrants a little more explanation. The author defines the ponder cost as:

$$ \rho_t = N(t) + R(t) $$

$\rho_t$ is defined this way because we want to reduce both the number of intermediate computation steps $N(t)$ and the "remainder" halting probability $R(t)$. (This formulation is actually a little awkward, as it is not clear why these two terms should be added together, but as we shall see shortly, this does not actually matter.)

The $\rho_t$ are summed over the input time steps to form the ponder penalty term.

$$ \mathcal{P}(x) = \sum_{t=1}^T \rho_t $$

As with most additional penalty terms, it introduces an additional hyper-parameter ($\tau$) for the relative weight of the penalty.

$$ \hat{\mathcal{L}}(x,y) = \mathcal{L}(x, y) + \tau \mathcal{P}(x) $$
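
In code, this amounts to collecting the remainders from each input step and scaling their sum by $\tau$; a minimal sketch, where `task_loss`, `remainders`, and `tau` are placeholder names:

```python
def total_loss(task_loss, remainders, tau):
    # Ponder penalty: sum of R(t) over the input steps, scaled by tau.
    # (In practice only these R(t) terms carry gradient; the N(t) part
    # of the ponder cost is dropped, as discussed below.)
    ponder_cost = sum(remainders)
    return task_loss + tau * ponder_cost
```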

A good deal of the paper is actually devoted to experiments on what value of $\tau$ to use, which impacts how much additional intermediate computation is performed by the ACT models. The author concludes that the performance of the models is highly sensitive to $\tau$, and the behavior varies across different problems, so it remains a fairly significant issue with the model.

One additional note - referring back to the formulation of $\rho_t$, we note that $R(t)$ is differentiable, but $N(t)$ is not. This means that we cannot simply perform gradient descent through the $\rho_t$. The author compromises by retaining only the $R(t)$ and ignoring the $N(t)$ in the penalty, and the model seems to work reasonably well in practice. Simply passing the gradients through $R(t)$ has an interesting intuition. Remember that:

$$ R(t) = 1-\sum_{n=1}^{N(t)-1}\text{halt}_t^n $$

If we incentivize the optimizer to reduce $R(t)$, what it does is increase the $\text{halt}_t^n$ for $n = 1 \cdots N(t)-1$ - in other words, increasing the halting probability equally for all prior steps. This means that the idea to only consider $R(t)$ in the penalty is actually sound - we penalize the model for its final intermediate computation, and in doing so incentivize it to halt earlier in subsequent runs. (This also justifies the use of a remainder $R(t)$, rather than renormalizing the halting probabilities, which I questioned above.)

On the other hand, this is also a fairly slow way of penalizing the computation steps, working only on the margin and equally penalizing all prior steps. While it is theoretically sound, I wonder if there is a better way to formulate this penalty to better tune the model's need for additional computations.

Implementation and Experiments

I implemented the Adaptive Computation Time model in PyTorch, and the source code can be found at https://github.com/zphang/adaptive-computation-time-pytorch.

The implementation with PyTorch was not necessarily difficult, but it does involve some non-conventional operations in allowing for batch-computations, particularly with different $N(t)$ for different samples in a batch. Furthermore, because we have to directly modify the implementation of RNNs, we lose the performance advantage of cuDNN implementations.
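
As a rough sketch of the masking idea (not an excerpt from the repo), one can precompute halting values up to the hard cap and derive per-sample halting probabilities with boolean masks; in a real implementation you would also stop iterating once every sample has halted:

```python
import torch

def batched_halting_probs(halts, eps=0.01):
    # halts: (max_steps, batch) tensor of halt_t^n values for one input step.
    # Returns (max_steps, batch) probabilities p_t^n, with each sample's
    # remainder placed at its own N(t) and zeros after it halts.
    max_steps, batch = halts.shape
    cumulative = torch.zeros(batch)
    probs = torch.zeros(max_steps, batch)
    still_running = torch.ones(batch, dtype=torch.bool)
    for n in range(max_steps):
        if n == max_steps - 1:
            force_halt = still_running            # enforce the hard cap
        else:
            force_halt = torch.zeros(batch, dtype=torch.bool)
        halting_now = (((cumulative + halts[n]) >= 1 - eps) & still_running) | force_halt
        continuing = still_running & ~halting_now
        probs[n] = torch.where(halting_now, 1 - cumulative,
                               halts[n] * continuing.float())
        cumulative = cumulative + halts[n] * continuing.float()
        still_running = continuing
    return probs
```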

I also intended to replicate the specific experiments described in the paper, although this is still ongoing work.

Below, we see the plot that corresponds to Figure 6 in the paper, replicating the first toy task - parity determination. The task is a little artificial, and is as follows: We are given a single time step of a 64-dimensional vector. This vector contains 1s and -1s, followed by a number of zeros (the number of 1s/-1s is referred to as the "difficulty" of that vector, and is uniformly distributed). We wish to determine the parity of the 1s/-1s, i.e. whether there is an odd or even number of 1s.
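
For concreteness, here is a sketch of how such an input/target pair could be generated (the exact sampling in my repo may differ slightly):

```python
import torch

def make_parity_example(dim=64):
    # One 64-dim input: `difficulty` random +/-1 entries, the rest zeros.
    # Target: 1 if the number of +1 entries is odd, 0 otherwise.
    difficulty = torch.randint(1, dim + 1, (1,)).item()
    signs = torch.where(torch.rand(difficulty) < 0.5,
                        torch.ones(difficulty), -torch.ones(difficulty))
    x = torch.zeros(dim)
    x[:difficulty] = signs
    y = int((signs == 1).sum().item() % 2)
    return x, y
```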

Because this is formulated as a single time step, the RNN basically reduces to a single-layer feed-forward network, whereas the ACT is able to dynamically create more intermediate computation steps. Hence this task is artificially constructed to ensure that the ACT does indeed take advantage of more intermediate computation steps when necessary.

We apply this task to a basic RNN, and the ACT over a variety of ponder cost penalties. Intuitively, we should expect better performance for lower ponder cost penalties, and we should expect that for higher ponder cost penalties, the performance should approach that of the RNN.

This is in fact what we observe. (Note that our models are trained slightly differently from what is described in the paper, as we do not use the Hogwild! algorithm.) Note that the X-axis is the difficulty (number of 1s/-1s), while the Y-axis is the error rate. We run the experiment over 20 trained models for each configuration and average the results. The faint lines are results from individual runs, while the bold lines are the averages.

We observe that for low difficulty, all models perform well with close to 0 error rate, while at high difficulties, the RNNs and high ponder-cost ACT models start to underperform, eventually reaching a 50% (random chance) error rate. On the other hand, the low ponder-cost ACTs continue to perform very well even as the difficulty increases. This confirms that the ACT is working as expected, both in applying more intermediate computation steps as required and in trading off the additional ponder time against the cost of the penalty.

Discussion

Relationship with Attention

Recall that we have 2 forms of hidden states:

  1. $h_t^n$, the hidden states from the intermediate computation steps, of which there are $N(t)$ for each of the $T$ input time steps
  2. $h_t$, which is the canonical hidden state, for which there are $T$ of them, one per input time step

We also specified that $h_t$ is a weighted average of the $h_t^n$ by the $p_t^n$, rather than a sampled version. The $h_t$ is also used to seed the hidden state for the next time step's computation.

Given that the $h_t$ is a weighted average of $h_t^n$, it is natural to draw comparisons to the attention mechanism, which also involves weighted sums of hidden states. As we will see, they are actually quite different, but it would not be infeasible to formulate a broader generalization that captures both mechanisms.

One first thing to note is that, ultimately, the consumers of the output of the ACT will only see the $h_t$ and not the $h_t^n$. This means that the $h_t$ are not just important for seeding the ACT computations for the next time step - they are the only output from the ACT that is actually consumed.

Here is a quick run-down of the attention mechanism:

  • Start with $h_t^E$, the set of $T$ encoder (input) hidden states
  • At each time step $s$ of the decoder, we have the decoder hidden state $h_s^D$ from the previous time step
  • We compute an attention score $a_{s, t}$ as a function of the current decoder hidden state $h_s^D$, every encoder hidden state $h_t^E$, and the previous decoder output $y_{s-1}$. This represents the affinity between the current decoder state and the available encoder states. The exact formulation of this attention score varies.
  • The attention scores $a_{s, t}$ are renormalized to sum to 1 via a softmax. This gives us a "distribution" over encoder hidden states.
  • We take a weighted sum of the encoder hidden states $h_t^E$, weighted by the normalized attention scores. This gives us a context vector $c_s$.
  • We compute the next decoder hidden state as a function of $h_s^D$, $c_s$ and $y_{s-1}$.

This runs a little long, but in summary, we compute a weighted sum of the encoder hidden states, weighted by their affinity with the current decoder hidden state, and use that as an input to our RNN.
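
For comparison with the ACT weighted mean, here is a minimal sketch of the attention weighted sum, assuming a simple dot-product score (as noted above, the scoring function varies across formulations):

```python
import torch
import torch.nn.functional as F

def attention_context(decoder_state, encoder_states):
    # decoder_state: (hidden,); encoder_states: (T, hidden).
    # Dot-product scores, softmax-normalized, then a weighted sum over *all*
    # encoder states - unlike ACT, which weights only the intermediate states
    # of the current input step.
    scores = encoder_states @ decoder_state                        # (T,)
    weights = F.softmax(scores, dim=0)                             # sums to 1
    context = (weights.unsqueeze(1) * encoder_states).sum(dim=0)   # (hidden,)
    return context, weights
```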

Now, let us note the key similarities and differences between the ACT and attention:

  1. Both mechanisms take weighted sums of encoder/RNN hidden states to produce a set of "hidden states" (in the case of attention, it produces the context vector $c_s$). Let us call these "hidden-output states"
  2. The ACT produces $T$ hidden-output states while attention produces $S$ hidden-output states, where $T$ is the length of the input sequence, and $S$ is the length of the output (decoder) sequence. Note that in the case of ACT, there may be no notion of an output sequence at all (there may be only a single RNN, rather than an encoder-decoder paradigm).
  3. The weights in ACT are distributed over the intermediate computation steps of the ACT, while the weights in attention are distributed over all encoder hidden states.
  4. The weights in ACT are calculated via a sigmoidal function over intermediate hidden states, and stops when the sum exceeds $1-\epsilon$, whereas the attention weights are computed as a function of the decoder and every encoder hidden state, and then normalized by a softmax.

Hence, we should see that while both mechanisms compute weighted averages over hidden states, they are really not that similar after all. In particular, the ACT weighting mechanism is more auto-regressive and local, while the attention mechanism is more global, covering all encoder hidden states at each step. The goal of ACT is to produce a dynamically determined, larger number of computation steps, whereas the attention mechanism performs the same set of attention computations at every decoder step (although the number of decoder steps is usually variable).

Nevertheless, it is worth considering how the two mechanisms could interact. Could the ACT tick on decoder timesteps in an encoder-decoder format (or even be detached from either set of time steps)? Could we simply use the final halting state and omit any weighting at all? Or could we compute a separate set of "attention" weights, rather than using $\text{halt}_t^n$ both to determine the halting probability and as the weight in the weighted sum? These are all avenues for future research.

Conclusion

Adaptive Computation Time appears to be a reasonable paradigm for dynamically determining additional computation steps on inputs, but its implementation is pretty involved, and I wonder if there may be ways to improve it. In particular, the author states that the model is highly sensitive to the ponder cost penalty weight, which is a little concerning. Nevertheless, it demonstrates the ability to adaptively increase the number of intermediate computations, and two extensions of ACT have already been proposed and shown positive results. I look forward to further work on the topic.