Published on

(Shower Thoughts) Stochastic Search is All You Need—Guidance for Discrete Diffusion with Non-Differentiable Objectives


Quick note on the blog

I have not added to the blog for quite some time, which I find kind of a shame. Part of the bottleneck is that in my head I've kept a high standard for what is acceptable: a post should be very detailed and in depth while explaining the main ideas clearly. While I hope to keep this standard for most of my research blog posts, I also think there is value in sharing ideas that are off the cuff and less polished. This is one example of that: I wrote this write-up to summarize the idea I had for a hackathon project.

Introduction

Diffusion models are powerful generative models that learn a data distribution and generate samples from it. However, in many settings the space of possible outputs is too large, so we need to control the generation process to get a desired output. Luckily, diffusion models are controllable by nature: at inference time, the model produces at each timestep a sample in the output data space, which can be plugged into a guidance or value function to get a signal. This is the idea behind classifier guidance, which is powerful because it allows controllable generation with no additional training of the diffusion model (the guidance happens at inference time). Moreover, the guidance function can easily be swapped in and out to suit different settings. Classifier guidance works out nicely in the standard setting (e.g., Gaussian diffusion models incorporate the gradient of a differentiable classifier into the mean of the Gaussian sampling step), but it is less obvious how to apply it in many realistic settings, such as when the data is discrete (gradients are tricky to get with discrete inputs) or when the guidance function is non-differentiable and not easily approximated (e.g., running some software or simulation code). This post summarizes a naive idea for working around this discrete, non-differentiable guidance setting: if you can't do gradient descent, why not use stochastic search to find the perturbation that leads to higher values?

What would we like to do

In the continuous setting with differentiable guidance functions, it is easy to add guidance to a diffusion process at inference time. At a high level, we simply modify the diffusion process by incorporating the gradient of the guidance function. At heart, we have the following equation:

$$\nabla_{x} \log p(x \mid y) = \nabla_{x} \log p(y \mid x) + \nabla_{x} \log p(x)$$
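As a short reminder of where this comes from (a sketch, assuming only that $p(y)$ does not depend on $x$): take the log of Bayes' rule and differentiate with respect to $x$, so the evidence term drops out.

```latex
\begin{aligned}
\log p(x \mid y) &= \log p(y \mid x) + \log p(x) - \log p(y) \\
\nabla_{x} \log p(x \mid y) &= \nabla_{x} \log p(y \mid x) + \nabla_{x} \log p(x) - \underbrace{\nabla_{x} \log p(y)}_{=\,0}
\end{aligned}
```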

Under particular assumptions, there is a direct connection between diffusion models and score-based models: the diffusion model is essentially learning the second term $\nabla_{x} \log p(x)$, and the reverse diffusion process moves the generated sample $x_{t}$ toward regions of higher data likelihood by following the score $\nabla_{x} \log p(x)$. We can also push the generated sample toward regions of higher value: given some classifier or guidance function $p(y \mid x)$, we add its score $\nabla_{x} \log p(y \mid x)$ as well. This requires the guidance function $p(y \mid x)$ to be differentiable. In the typical continuous Gaussian diffusion setting, the math works out nicely and it's enough to add the score of the guidance function directly to the Gaussian mean parameter:

[Figure: Classifier guidance for Gaussian diffusion]
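A common way to write this Gaussian-mean update is sketched below. The guidance scale $s$ and the symbols $\mu_{\theta}$ and $\Sigma_{\theta}$ (the model's predicted mean and covariance at step $t$) are notation assumed here for illustration; they don't appear elsewhere in this post.

```latex
x_{t-1} \sim \mathcal{N}\!\left(
  \mu_{\theta}(x_{t}, t) + s \, \Sigma_{\theta}(x_{t}, t) \, \nabla_{x_{t}} \log p(y \mid x_{t}),\;
  \Sigma_{\theta}(x_{t}, t)
\right)
```

Setting $s = 0$ recovers the unguided reverse step, which makes it easy to see that guidance only shifts the mean.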

Our workaround

But if our guidance function is non-differentiable, we can't calculate the score by taking a gradient. For example, in protein sequence generation, the diffusion model outputs a protein sequence (discrete data) at each reverse timestep, and this is what we want to guide. Now suppose that we want our generated proteins to bind to some other protein, and we have software that takes in a sequence, runs some simulations, and outputs a binding affinity score.

[Figure: Protein docking software]

Here we have the guidance function at our disposal, but we can't get a gradient from it. So how can we guide our discrete sequences?

It might be difficult to work directly with discrete sequences (for example, it's hard to tell which sequences are close to a given sequence), so we can instead consider the continuous logits that are decoded into a sequence.

Universal Guidance for Diffusion Models (Bansal et al. 2023) adds guidance to images by doing the typical classifier guidance setup $\hat{\epsilon}(z_{t}, t) = \epsilon(z_{t}, t) - \sqrt{1 - \alpha_{t}} \nabla \log p(c \mid z_{t})$ and also doing a local optimization around the denoised diffusion output. Concretely, if at some timestep the diffusion model outputs a clean image $\hat{z_{0}}$, it uses gradient descent to do this local search and calculate

$$\Delta z_{0} = \arg\min_{\Delta} \ell(c, f(\hat{z_{0}} + \Delta))$$

Now if our guidance function $f$ is non-differentiable, we can't do gradient descent, but this motivates a similar idea: why not do a local brute-force search around the given diffusion output $\hat{z_{0}}$ and return the best-performing perturbation $\Delta^{*}$, i.e. the one with the highest value $f(\hat{z_{0}} + \Delta^{*})$?

Specifically, one naive way of doing this is, for any output $\hat{z_{0}}$, to stochastically sample many perturbations $\Delta$ (e.g., from Gaussian noise $\Delta \sim N(0, \sigma)$), evaluate $f(\hat{z_{0}} + \Delta)$ for each, and return the argmax. This is a simple first step. Further improvements could be to ensure that the perturbations are diverse (although in a high-dimensional space, independent random samples are very unlikely to land close to each other) or to do something smarter by having the argmax collapse the perturbations not in one step but across multiple steps, so that we search through a tree instead of a flat list of perturbations.
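The one-step search described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the hackathon code: the function name `stochastic_search`, its parameters, and the toy value function are all hypothetical, and `sigma` is treated as a standard deviation.

```python
import numpy as np

def stochastic_search(z0_hat, value_fn, num_perturbations=100, sigma=1.0, rng=None):
    """Sample Gaussian perturbations around z0_hat and keep the best one.

    value_fn is treated as a black box: it is only ever called, never
    differentiated. Returns (best candidate, its perturbation, its value).
    """
    rng = np.random.default_rng() if rng is None else rng
    # One perturbation per row: shape (num_perturbations, *z0_hat.shape).
    deltas = rng.normal(0.0, sigma, size=(num_perturbations,) + z0_hat.shape)
    candidates = z0_hat[None, ...] + deltas
    # Evaluate every candidate with the (possibly non-differentiable) function.
    values = np.array([value_fn(c) for c in candidates])
    best = int(np.argmax(values))
    return candidates[best], deltas[best], float(values[best])

# Toy usage: a piecewise-constant score, so gradients would be useless here.
toy_value = lambda z: float(np.sum(np.round(z) == 1))
best_z, best_delta, best_value = stochastic_search(
    np.zeros(8), toy_value, num_perturbations=50, sigma=1.0,
    rng=np.random.default_rng(0),
)
```

Note that this sketch does not include the unperturbed $\hat{z_{0}}$ among the candidates, so the returned value is the best among the perturbations only.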

Experimental results on protein sequences

For the hackathon, I built on an existing repository that contained discrete diffusion code for protein sequences. The particular discrete diffusion model that I trained was a masked language model (MLM) discrete diffusion model, which is essentially a BERT model trained to unmask masked tokens from some corrupted sequence. I also trained a value function to optimize a protein characteristic called "SASA" (solvent-accessible surface area). It would have been a better demonstration if the value function were not a neural network (the main point of stochastic search is to work with non-differentiable guidance functions), but the method applies to both differentiable and non-differentiable functions.

For the experiment, we add guidance at each timestep by taking the argmax of $f(h_{t} + \Delta)$ over sampled perturbations $\Delta \sim N(0, \frac{1}{2})$, where $h_{t}$ are the logits that are decoded into the sequence $x_{t}$. For each $h_{t}$ we create $M = 100$ perturbations. From an eyeball check, these logits typically have magnitude larger than 10, so a perturbation of this size is not unreasonable.
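To make the per-timestep procedure concrete, here is a toy sketch of the loop. Everything in it is a hypothetical stand-in: `toy_denoiser` replaces the trained model, `toy_value` replaces the SASA value function, the sequence length and vocabulary size are made up, and I am assuming the $\frac{1}{2}$ in $N(0, \frac{1}{2})$ is a variance. The toy logits are also much smaller than the magnitude-10 logits mentioned above, so perturbations flip tokens far more often here.

```python
import numpy as np

LENGTH, VOCAB = 12, 20  # toy sequence length and vocabulary size (assumed)

def toy_denoiser(tokens, rng):
    """Hypothetical stand-in for the trained model: logits of shape (LENGTH, VOCAB)."""
    logits = rng.normal(0.0, 1.0, size=(LENGTH, VOCAB))
    logits[np.arange(LENGTH), tokens] += 2.0  # mildly favor the current tokens
    return logits

def toy_value(logits):
    """Hypothetical non-differentiable score: positions that decode to token 0."""
    return float(np.sum(logits.argmax(axis=-1) == 0))

def guided_sample(rng, num_steps=5, num_perturbations=100, variance=0.5):
    """Toy reverse process that replaces h_t by its best perturbation at each step."""
    std = np.sqrt(variance)  # assumption: the 1/2 in N(0, 1/2) is a variance
    tokens = rng.integers(0, VOCAB, size=LENGTH)
    num_evals = 0
    for _ in range(num_steps):
        h_t = toy_denoiser(tokens, rng)
        deltas = rng.normal(0.0, std, size=(num_perturbations, LENGTH, VOCAB))
        scores = np.array([toy_value(h_t + d) for d in deltas])
        num_evals += num_perturbations
        h_t = h_t + deltas[int(scores.argmax())]  # keep the argmax perturbation
        tokens = h_t.argmax(axis=-1)              # decode logits into a sequence
    return tokens, num_evals

tokens, num_evals = guided_sample(np.random.default_rng(0))
```

The returned `num_evals` is the number of value-function calls the guidance added, which is the cost that matters when the value function is an expensive simulation.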

| Model            | SASA value score |
|------------------|------------------|
| With guidance    | 1.286            |
| Without guidance | 0.829            |

(Results averaged across batch size 1024)

This doesn't say much: with guidance, we take the argmax across multiple perturbations so we expect the value to increase.

To do a robust and fair comparison, we should consider the alternative: generating samples without guidance, evaluating them, and taking the argmax. The guidance process adds $T*M$ evaluation calls, where $T$ is the number of diffusion timesteps and $M$ is the number of perturbations per timestep. Thus, if evaluation is expensive relative to the diffusion process, this method might not make sense. Rather, it would be better to save the evaluation compute for the end: generate $T*M$ diffusion samples and then choose the argmax by evaluating each generated sample.
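The alternative baseline described above is a best-of-N selection with a matched evaluation budget. A minimal sketch, assuming a hypothetical `sample_fn` that returns one unguided sample and a hypothetical `value_fn` that scores it (neither is from the hackathon code):

```python
import numpy as np

def best_of_n(sample_fn, value_fn, num_steps, num_perturbations, rng):
    """Draw T*M unguided samples and return the one with the highest value."""
    budget = num_steps * num_perturbations  # same number of value_fn calls as guidance
    samples = [sample_fn(rng) for _ in range(budget)]
    values = np.array([value_fn(s) for s in samples])
    best = int(values.argmax())
    return samples[best], float(values[best]), budget

# Toy usage with stand-ins for the sampler and the value function.
rng = np.random.default_rng(0)
sample_fn = lambda r: r.integers(0, 20, size=12)  # pretend unguided diffusion sample
value_fn = lambda seq: float(np.sum(seq == 0))    # non-differentiable toy score
best_seq, best_val, n_evals = best_of_n(
    sample_fn, value_fn, num_steps=5, num_perturbations=100, rng=rng
)
```

This matches the number of evaluation calls, not the number of denoiser calls: the baseline runs the full diffusion process $T*M$ times, whereas guidance runs it once, so the two are only comparable when sampling is cheap or the comparison accounts for that extra cost.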

However, if evaluation isn't terribly expensive relative to the diffusion process, then it makes intuitive sense to add guidance at each intermediate step rather than generating a large batch of outputs and hoping that some turn out good. As an analogy, an unbiased random walk or Brownian motion typically stays within some ball not too far from its starting point because its steps keep doubling back on one another, whereas if you can bias the walk at each step, you may be able to guide it out of the ball in a particular desired direction.
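The random-walk intuition can be checked with a small simulation. This is only an analogy under assumed settings (16 dimensions, 100 steps, 10 candidate steps per timestep, and "value" defined as progress along one axis); it says nothing about the protein experiment itself.

```python
import numpy as np

def walk(num_steps, dim, num_candidates, direction, rng):
    """Gaussian random walk; with num_candidates > 1, each step keeps the
    candidate whose resulting position scores highest along `direction`."""
    x = np.zeros(dim)
    for _ in range(num_steps):
        steps = rng.normal(0.0, 1.0, size=(num_candidates, dim))
        progress = (x + steps) @ direction  # value of each candidate position
        x = x + steps[int(progress.argmax())]
    return x

rng = np.random.default_rng(0)
dim, num_steps = 16, 100
direction = np.zeros(dim)
direction[0] = 1.0

unguided = walk(num_steps, dim, num_candidates=1, direction=direction, rng=rng)
guided = walk(num_steps, dim, num_candidates=10, direction=direction, rng=rng)
print("unguided progress:", unguided @ direction)
print("guided progress:  ", guided @ direction)
```

The unguided walk's progress along the chosen axis grows only like the square root of the number of steps, while picking the best of a few candidates at every step adds a roughly constant drift per step. The value here happens to be linear, but the search only ever evaluates it, so the same loop works for a non-differentiable value.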

Conclusion

Obviously there is a lot more left to investigate. Possibly the most valuable thing is simply the base intuition of the method: if we can't do gradient descent but we have access to compute, then through sampling we can search for better local optima regardless of the differentiability of the guidance function.

The hackathon source code is here.

There is also a recording of the presentation I gave.

References

  • (Gruver 2023) Protein Design with Guided Discrete Diffusion

  • (Bansal 2023) Universal Guidance for Diffusion Models

  • (Ho 2020) Denoising Diffusion Probabilistic Models

  • (Austin 2021) Structured Denoising Diffusion Models in Discrete State-Spaces