(Shower Thoughts) Stochastic Search is All You Need—Guidance for Discrete Diffusion with Non-Differentiable Objectives
Quick note on the blog
I have not added to the blog for quite some time, which I find kind of a shame. Part of the bottleneck is that in my head I've kept a high standard for what is acceptable: a post should be very detailed and in depth while explaining the main ideas clearly. While I hope to keep this standard for most of my research blog posts, I also think there is value in sharing ideas that are off the cuff and less polished. This is one example of that: I wrote this up to summarize an idea I had for a hackathon project.
Introduction
Diffusion models are powerful generative models that learn a data distribution in order to generate samples from it. However, in many settings the space of possible outputs is too large, so we need to control the generation process to get a desired output. Luckily, diffusion models are controllable by nature: at inference time, the model outputs a sample in the data space at each timestep, which can be plugged into a guidance or value function to get a signal. This is the idea behind classifier guidance, which is powerful because it allows controllable generation with no additional training (the guidance happens at inference time). Moreover, the guidance function can easily be swapped in and out to suit different settings. Classifier guidance works out nicely in the standard setting (e.g. Gaussian diffusion models incorporate the gradient of a differentiable classifier into the mean of the Gaussian sampling step), but it is less obvious what to do in many realistic settings, such as when the data is discrete (gradients are tricky to get with discrete inputs) or when the guidance function is non-differentiable and not easily approximated (e.g. running some software or simulation code). This post summarizes a naive idea for working around the discrete, non-differentiable guidance setting: if you can't do gradient descent, why not use stochastic search to find the perturbation that leads to higher values?
What we would like to do
In the continuous setting with differentiable guidance functions, it is easy to add guidance to a diffusion process at inference time. At a high level, we simply modify the diffusion process by incorporating the gradient of the guidance function. At heart, we have the following equation, which follows from Bayes' rule:

$$\nabla_x \log p(x \mid y) = \nabla_x \log p(y \mid x) + \nabla_x \log p(x)$$

Under particular assumptions, there is a direct connection between diffusion models and score-based models: the diffusion model is essentially learning the second term, and the reverse diffusion process moves the generated sample to regions of higher data likelihood by adding the score $\nabla_x \log p(x)$. We can also ensure that the generated sample moves to regions with higher value: given some classifier or guidance function $p(y \mid x)$, we also add its score $\nabla_x \log p(y \mid x)$. This requires the guidance function to be differentiable. In the typical continuous Gaussian diffusion setting, the math works out nicely and it is enough to add the score of the guidance function, scaled by the covariance, to the Gaussian mean parameter:

$$x_{t-1} \sim \mathcal{N}\big(\mu_\theta(x_t) + s\,\Sigma_\theta(x_t)\,\nabla_{x_t} \log p(y \mid x_t),\ \Sigma_\theta(x_t)\big)$$

where $s$ is a guidance scale.
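To make the mean shift concrete, here is a minimal sketch of one guided Gaussian sampling step. Everything in it is an assumption for illustration: the function names, the diagonal covariance with a single shared variance, the toy quadratic guidance function, and the specific numbers. It is not taken from any particular diffusion library.

```python
import random

def guided_gaussian_step(mean, variance, grad_log_p_y, scale=1.0, rng=None):
    """Sample from N(mean + scale * variance * grad, variance), coordinate-wise.

    Assumes a diagonal covariance with the same variance in every coordinate.
    """
    rng = rng or random.Random()
    # Shift the mean along the guidance gradient, scaled by the variance.
    shifted_mean = [m + scale * variance * g for m, g in zip(mean, grad_log_p_y)]
    sample = [rng.gauss(mu, variance ** 0.5) for mu in shifted_mean]
    return shifted_mean, sample

def toy_grad_log_p_y(x, target):
    """Gradient of the toy guidance log p(y|x) = -0.5 * ||x - target||^2."""
    return [t - xi for xi, t in zip(x, target)]

mean = [0.0, 0.0]      # stands in for the diffusion model's predicted mean
target = [1.0, -2.0]   # where the toy guidance function wants samples to go
grad = toy_grad_log_p_y(mean, target)  # evaluated at the mean for simplicity
shifted_mean, sample = guided_gaussian_step(
    mean, variance=0.1, grad_log_p_y=grad, scale=2.0, rng=random.Random(0)
)
print(shifted_mean)
```

The only change relative to an unguided step is the shifted mean, which is why this form of guidance needs a gradient and nothing else from the guidance function.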
Our workaround
But if our guidance function is non-differentiable, then we can't just calculate the score through a gradient calculation. For example, in protein sequence generation, the diffusion model outputs protein sequences (discrete data) at each reverse timestep, and these are what we want to guide. Now suppose that we want our generated proteins to bind to some other protein, and that we have software that takes in a sequence, runs some simulations, and outputs a binding affinity score.
Here we have the guidance function at our disposal, but we can't get a gradient from it. So how can we guide our discrete sequences?
It might be difficult to work directly with discrete sequences (for example, it's hard to tell which sequences are close to a given sequence), so we can instead consider the continuous logits that are decoded into a sequence.
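Universal Guidance for Diffusion Models (Bansal et al. 2023) adds guidance to images by doing the typical classifier guidance setup and also doing a local optimization around the denoised diffusion output. Concretely, if at some timestep the diffusion model outputs a clean image prediction $\hat{x}_0$, it uses gradient descent to do this local search and calculate

$$\Delta^* = \arg\min_{\Delta}\ \ell\big(c, f(\hat{x}_0 + \Delta)\big)$$

where $f$ is the guidance function, $c$ is the target, and $\ell$ is a loss.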
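As a rough illustration of this kind of gradient-based local search, the sketch below runs plain gradient descent on an offset added to a denoised output. The quadratic loss, the step size, the number of steps, and all names are hypothetical stand-ins chosen for the example. This is not the implementation from Bansal et al.

```python
def local_gradient_search(x0_hat, grad_loss, num_steps=50, step_size=0.1):
    """Gradient descent on an offset `delta` so that x0_hat + delta lowers the loss."""
    delta = [0.0] * len(x0_hat)
    for _ in range(num_steps):
        x = [xi + di for xi, di in zip(x0_hat, delta)]
        g = grad_loss(x)  # requires a differentiable guidance loss
        delta = [di - step_size * gi for di, gi in zip(delta, g)]
    return delta

# Toy differentiable loss: squared distance to a hypothetical target.
target = [1.0, -1.0, 0.5]

def toy_loss(x):
    return sum((xi - ti) ** 2 for xi, ti in zip(x, target))

def toy_grad_loss(x):
    return [2.0 * (xi - ti) for xi, ti in zip(x, target)]

x0_hat = [0.0, 0.0, 0.0]  # stands in for the denoised diffusion output
delta = local_gradient_search(x0_hat, toy_grad_loss)
refined = [xi + di for xi, di in zip(x0_hat, delta)]
print(toy_loss(x0_hat), toy_loss(refined))
```

The call to `grad_loss` is the step that has no counterpart when the guidance function is a black box.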
Now if our guidance function is non-differentiable, we can't do gradient descent, but this motivates a similar idea: why not do a local brute-force search around the given diffusion output $x$ and return the best-performing perturbation?
Specifically, one naive way of doing this is, for any output $x$, to stochastically sample noise $\epsilon_1, \dots, \epsilon_k$ (e.g. Gaussian noise $\epsilon_i \sim \mathcal{N}(0, \sigma^2 I)$), generate the perturbations $x + \epsilon_i$, and return the argmax under the guidance function $f$:

$$x^* = \arg\max_{x + \epsilon_i,\ i = 1, \dots, k} f(x + \epsilon_i)$$

This is a simple first step. Further improvements could be to ensure that the perturbations are diverse (though in a high-dimensional space, random samples are distinct with high probability), or to do something smarter by having the argmax collapse the perturbations not in one step but across multiple steps, so that we search through a tree instead of a list of perturbations.
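The one-step search described above can be sketched in a few lines. The function names, the default number of perturbations, the noise scale, and the step-shaped toy value function are all assumptions made for the example. The value function is deliberately non-differentiable to show that only evaluations are needed.

```python
import random

def stochastic_search(x, value_fn, num_perturbations=16, noise_std=1.0, rng=None):
    """Sample Gaussian perturbations of x and return the best one under value_fn."""
    rng = rng or random.Random()
    best_x, best_value = None, float("-inf")
    for _ in range(num_perturbations):
        candidate = [xi + rng.gauss(0.0, noise_std) for xi in x]
        value = value_fn(candidate)  # black-box call, no gradient needed
        if value > best_value:
            best_x, best_value = candidate, value
    return best_x, best_value

def step_value(x):
    """Toy non-differentiable value: number of coordinates above a threshold."""
    return sum(1 for xi in x if xi > 0.5)

best, best_val = stochastic_search(
    [0.0] * 4, step_value, num_perturbations=32, noise_std=1.0, rng=random.Random(0)
)
print(best_val)
```

This version takes the argmax over the perturbations only. Adding the unperturbed `x` to the candidate list would be a small variation that guarantees the returned value is never worse than the starting point.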
Experimental results on protein sequences
For the hackathon, I built on an existing repository that contained discrete diffusion code for protein sequences. The particular discrete diffusion model that I trained was a masked language model discrete diffusion model, which is essentially a BERT model trained to unmask masked tokens from some corrupted sequence. I also trained a value function that predicts a protein characteristic called "SASA" (solvent-accessible surface area), which is the quantity we optimize. While it would be more compelling if the value function were not a neural network (since the main use of stochastic search is to work with non-differentiable guidance functions), the method applies to both differentiable and non-differentiable functions.
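For the experiment, we add guidance at each timestep by taking the argmax of $f(\mathrm{decode}(z_t + \epsilon_i))$ over perturbations $\epsilon_i$, where $\epsilon_i$ is Gaussian noise, $f$ is the value function, and $z_t$ is the logits that are decoded into the sequence $x_t$. For each $z_t$ we create $k$ perturbations. From an eyeball check, these logits typically have magnitude larger than 10, so this perturbation is not unreasonable.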
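A minimal sketch of this logit-space guidance step is below. It is not the hackathon code: the greedy decoder, the toy value function (fraction of charged residues, standing in for the trained SASA predictor), the random logits, and the values of `num_perturbations` and `noise_std` are all hypothetical choices made so the example runs on its own.

```python
import random

ALPHABET = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard amino acids

def decode(logits):
    """Greedy decode: pick the highest-logit token at every position."""
    return "".join(
        ALPHABET[max(range(len(row)), key=row.__getitem__)] for row in logits
    )

def toy_value(sequence):
    """Hypothetical black-box score: fraction of charged residues (D, E, K, R)."""
    return sum(sequence.count(a) for a in "DEKR") / len(sequence)

def guide_logits(logits, value_fn, num_perturbations, noise_std, rng):
    """Perturb the logits, decode each candidate, and keep the best-scoring logits."""
    best_logits, best_value = None, float("-inf")
    for _ in range(num_perturbations):
        candidate = [[z + rng.gauss(0.0, noise_std) for z in row] for row in logits]
        value = value_fn(decode(candidate))  # evaluated on the discrete sequence
        if value > best_value:
            best_logits, best_value = candidate, value
    return best_logits, best_value

rng = random.Random(0)
seq_len = 12
# Random logits stand in for the diffusion model's output at one timestep.
logits = [[rng.uniform(-1.0, 1.0) for _ in ALPHABET] for _ in range(seq_len)]
guided_logits, guided_value = guide_logits(
    logits, toy_value, num_perturbations=32, noise_std=1.0, rng=rng
)
print(decode(logits), toy_value(decode(logits)))
print(decode(guided_logits), guided_value)
```

The search happens in the continuous logit space, but the value function only ever sees decoded sequences, so it can be any black box that scores a string.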
Model | SASA Value Score |
---|---|
With guidance | 1.286 |
Without guidance | 0.829 |
(Results averaged across batch size 1024)
This doesn't say much: with guidance, we take the argmax across multiple perturbations so we expect the value to increase.
For a robust and fair comparison, we should consider the alternative: generating samples without guidance, evaluating each one, and taking the argmax. The guidance process adds $T \times k$ evaluation calls, where $T$ is the number of diffusion timesteps and $k$ is the number of perturbations per timestep. Thus, if evaluation is expensive relative to the diffusion process, this method might not make sense. It would be better to save the evaluation compute for the end: generate many diffusion samples, evaluate each one, and choose the argmax.
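The trade-off can be made concrete with a small cost model. The sketch below assumes T timesteps, k perturbations per timestep, a fixed cost per diffusion step, and a fixed cost per evaluation. The function names and all the numbers are illustrative assumptions, not measurements from the experiment.

```python
def guided_cost(num_timesteps, perturbations_per_step, step_cost, eval_cost):
    """Cost of one guided sample: T diffusion steps plus T * k evaluations."""
    diffusion = num_timesteps * step_cost
    evaluation = num_timesteps * perturbations_per_step * eval_cost
    return diffusion + evaluation

def best_of_n_for_budget(budget, num_timesteps, step_cost, eval_cost):
    """How many unguided samples (each evaluated once) fit in the same budget."""
    per_sample = num_timesteps * step_cost + eval_cost
    return budget // per_sample

# Expensive evaluation: each call costs 10x a diffusion step.
expensive = guided_cost(100, 8, step_cost=1, eval_cost=10)
n_expensive = best_of_n_for_budget(expensive, 100, step_cost=1, eval_cost=10)

# Cheap evaluation: each call costs a tenth of a diffusion step.
cheap = guided_cost(100, 8, step_cost=10, eval_cost=1)
n_cheap = best_of_n_for_budget(cheap, 100, step_cost=10, eval_cost=1)

print(expensive, n_expensive)
print(cheap, n_cheap)
```

Under these made-up costs, the budget of one guided sample buys 73 unguided samples when evaluation is expensive but only 1 when evaluation is cheap, which is the regime where spending evaluations on intermediate steps looks more attractive.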
However, if evaluation isn't terribly expensive relative to the diffusion process, then it makes intuitive sense to add guidance at each intermediate step rather than generating a bunch of outputs and hoping that some end up good. For example, a random walk or Brownian motion typically stays within some ball not too far from its starting point because it keeps walking back on itself, whereas if you can bias the walk at each step, you may be able to guide it out of the ball in a particular desired direction.
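The random walk intuition is easy to check with a toy simulation. The sketch below is only an analogy under assumed settings (a one-dimensional walk with unit steps, 1000 steps, 4 candidate steps per move, and "position" as the value to maximize). It says nothing quantitative about diffusion models.

```python
import random

def unbiased_walk(num_steps, rng):
    """Plain +/-1 random walk starting at 0."""
    position = 0
    for _ in range(num_steps):
        position += rng.choice((-1, 1))
    return position

def guided_walk(num_steps, candidates_per_step, value_fn, rng):
    """At each step, sample several random moves and keep the best under value_fn."""
    position = 0
    for _ in range(num_steps):
        options = [
            position + rng.choice((-1, 1)) for _ in range(candidates_per_step)
        ]
        position = max(options, key=value_fn)
    return position

rng = random.Random(0)
free_end = unbiased_walk(1000, rng)
guided_end = guided_walk(1000, candidates_per_step=4, value_fn=lambda p: p, rng=rng)
print(free_end, guided_end)
```

Each guided move is still a random step, yet choosing the best of a few candidates at every step compounds into a large drift, whereas the unguided walk typically ends within a few multiples of the square root of the step count from the origin.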
Conclusion
Obviously there is a lot left to investigate. Possibly the most valuable thing is simply the base intuition of the method: if we can't do gradient descent but we have access to compute, then through sampling we can search for a better local optimum regardless of the differentiability of the guidance function.
The hackathon source code is here.
There is also a recording of the presentation I gave.
References
(Gruver et al. 2023) Protein Design with Guided Discrete Diffusion
(Bansal et al. 2023) Universal Guidance for Diffusion Models
(Ho et al. 2020) Denoising Diffusion Probabilistic Models
(Austin et al. 2021) Structured Denoising Diffusion Models in Discrete State-Spaces