(Research Idea) Discrete Diffusion for Inverse Protein Folding



Introduction

Deep learning made a huge impression on the science community when DeepMind released AlphaFold, a deep network that predicts a protein's structure from its sequence. AlphaFold outperformed all other protein folding methods and stands as the baseline to beat. The converse problem of inverse protein folding, predicting a protein's sequence from its structure, is arguably even more challenging, since many different sequences can fold to a particular protein structure. Both physics-based methods and deep learning methods have been proposed for inverse protein folding. One particularly promising and quite novel approach is the use of denoising diffusion models.

Background: (denoising) diffusion models

For a more extensive treatment, Lilian Weng has a great introductory post on diffusion models (Weng 2021). Calvin Luo also provides great intuition, and some of his insights are included below (Luo 2022).

In short, a diffusion model (more precisely, a denoising diffusion probabilistic model) is a parameterized Markov chain that learns to generate samples that match the input data. The diffusion model has a forward process which takes a true data sample $x_{0}\sim q(x)$ and progressively corrupts it with noise over many timesteps; notationally, the data $x_{0}$ undergoes many forward transitions $q(x_{t}|x_{t-1})$, which for continuous data is commonly an isotropic Gaussian $q(x_{t}|x_{t-1})=N(\sqrt{1-\beta_{t}}x_{t-1},\beta_{t}I)$ (here $\beta_{t}$ defines the variance and can either be fixed as a hyperparameter or learned by reparameterization). After many forward transitions, the original data ends up corrupted to pure noise.
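To make this concrete, here is a minimal PyTorch sketch of the Gaussian forward process, assuming a fixed linear $\beta_{t}$ schedule and using the standard closed form $q(x_{t}|x_{0})=N(\sqrt{\bar{\alpha}_{t}}x_{0},(1-\bar{\alpha}_{t})I)$ with $\bar{\alpha}_{t}=\prod_{s\le t}(1-\beta_{s})$ from (Ho et al. 2020); the schedule values and names are illustrative.

```python
import torch

# A minimal sketch of the Gaussian forward (noising) process with a fixed linear schedule.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # beta_t, fixed as hyperparameters
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)    # \bar{alpha}_t = prod_{s <= t} (1 - beta_s)

def q_sample(x0, t, noise=None):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I)."""
    if noise is None:
        noise = torch.randn_like(x0)
    a_bar = alpha_bars[t]
    return torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise
```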

The diffusion model learns a reverse process to undo this forward noising process; in this way, it can take a sample drawn from noise and denoise it to a sample that lies in the original data distribution $q(x)$. For example, if the forward transition is Gaussian, our diffusion model would learn the mean and covariance of a reverse Gaussian transition, i.e. $p_{\theta}(x_{t-1}|x_{t})=N(\mu_{\theta}(x_{t},t),\Sigma_{\theta}(x_{t},t))$. Once we learn the reverse process, we can take a randomly sampled noise $x_{T}$ and apply the learned reverse transitions many times to get an $\hat{x}_{0}$ which hopefully lies in the true data distribution $q(x)$.
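Once the reverse process is learned (how to learn it is discussed next), a minimal ancestral sampling loop might look as follows; this sketch reuses `T` and `betas` from the forward-process sketch, assumes a trained `reverse_model(x_t, t)` that outputs the reverse mean $\mu_{\theta}(x_{t},t)$, and fixes the reverse variance to $\beta_{t}$ for simplicity.

```python
import torch

@torch.no_grad()
def sample(reverse_model, shape):
    x = torch.randn(shape)                   # x_T ~ N(0, I): pure noise
    for t in reversed(range(T)):
        mu = reverse_model(x, t)             # learned reverse mean mu_theta(x_t, t)
        if t > 0:
            x = mu + torch.sqrt(betas[t]) * torch.randn_like(x)   # fixed variance beta_t
        else:
            x = mu                           # last step: return the predicted mean as x_0_hat
    return x
```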

How do we learn the reverse process? Ultimately the distribution that results from the reverse process should match the distribution of the real data. An intuitive objective is to maximize the likelihood of the real data $x_{0}$ under the learned reverse distribution. Because maximizing the likelihood directly is intractable, we maximize a variational lower bound of the log likelihood. Equivalently, we can turn this into a cost function by taking the negative: we minimize an upper bound on the negative log likelihood $-\mathbb{E}_{q(x_{0})}[\log p_{\theta}(x_{0})]$, which spelled out is:

$$L_{vb}=\mathbb{E}_{q(x_{0})}\Big[D_{KL}(q(x_{T}|x_{0})\,\|\,p(x_{T}))+\sum_{t=2}^{T}\mathbb{E}_{q(x_{t}|x_{0})}\big[D_{KL}(q(x_{t-1}|x_{t},x_{0})\,\|\,p_{\theta}(x_{t-1}|x_{t}))\big]-\mathbb{E}_{q(x_{1}|x_{0})}\big[\log p_{\theta}(x_{0}|x_{1})\big]\Big]$$

With enough timesteps, the forward process $q(x_{t}|x_{t-1})$ corrupts the data such that $q(x_{T}|x_{0})$ is essentially noise; thus the first term, the KL divergence between $q(x_{T}|x_{0})$ and the assumed noise prior $p(x_{T})$, is basically zero. The last term represents a reconstruction loss where we try to maximize the likelihood of predicting the original data given the first-step latent. With the middle term, we see that at each time step we try to match our predicted transition $p_{\theta}(x_{t-1}|x_{t})$ with the actual reverse transition $q(x_{t-1}|x_{t},x_{0})$. (Note: we also condition on $x_{0}$ to make this calculation tractable, as explained in (Sohl-Dickstein 2015).)

To maximize this lower bound (or minimize the cost function), we first parameterize the transition $q(x_{t-1}|x_{t},x_{0})$ to properly construct our estimated reverse transition $p_{\theta}(x_{t-1}|x_{t})$. In the Gaussian setting defined in (Ho et al. 2020), $q(x_{t-1}|x_{t},x_{0})$ ends up being parameterized as an isotropic Gaussian $N(\mu(x_{t},x_{0},t),\sigma_{t}^{2}I)$. Based on this parameterization, we learn an isotropic Gaussian $p_{\theta}(x_{t-1}|x_{t})=N(\mu_{\theta}(x_{t},t),\sigma_{t}^{2}I)$ to match this transition and minimize the KL divergence between the two distributions. It turns out that our objective reduces to

$$\mathbb{E}_{q}\Big[\frac{1}{2\sigma_{t}^{2}}\|\mu(x_{t},x_{0})-\mu_{\theta}(x_{t},t)\|^{2}\Big]+C$$

i.e. we try to predict the reverse-process mean at each time step $t$.

(Ho et al. 2020) show that it is effective to train a model $\epsilon_{\theta}$ (i.e. a neural network) to learn the added noise rather than the mean and covariance of the diffused data. Thus they use the simplified objective function $L(\theta)=\mathbb{E}_{t,x_{0},\epsilon}\big[\|\epsilon-\epsilon_{\theta}(x_{t},t)\|^{2}\big]$.
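A minimal training step under this simplified objective, reusing the schedule from the forward-process sketch above, might look as follows; `model(x_t, t)` is any network that predicts the added noise.

```python
import torch

def training_step(model, x0, optimizer):
    t = torch.randint(0, T, (x0.shape[0],))                   # a random timestep per example
    noise = torch.randn_like(x0)                              # epsilon ~ N(0, I)
    a_bar = alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))   # broadcast over data dims
    x_t = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * noise
    loss = ((noise - model(x_t, t)) ** 2).mean()              # ||eps - eps_theta(x_t, t)||^2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```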

Why does this simplified objective make sense?

In the setting of (Ho et al. 2020), the noise $\epsilon\sim N(0,I)$ is drawn from a Gaussian that is independent of both the time $t$ and the data $x_{t}$. So we might wonder why we are trying to predict noise as a function of $x_{t}$ and $t$. While this noise-prediction objective may seem counter-intuitive, it in fact derives from the objective of predicting the forward process posterior mean at time step $t$ (which is a function of $x_{t}$ and $t$). With reparameterization and some simplifications (spelled out in (Ho et al. 2020)), the posterior mean prediction objective becomes the noise prediction objective.

Inverse protein folding

The problem of inverse protein folding is to predict a protein's sequence from its structure. Intuitively, our prediction model should have the property of roto-translational invariance. Given some protein structure, even if we apply a bunch of translations or rotations to it, it should still map to the same sequence that defines it.

How do we achieve roto-translation invariance?

One strategy is to take our protein structure $y$ and encode it into a roto-translation invariant representation. If $y\in\mathbb{R}^{N\times3}$ is just the 3D coordinates of each residue, then it is not rotation or translation invariant, but if we instead use geometric variables such as distances or angles we can get an invariant representation of the structure. For example, (Wu et al. 2022) use "inter-residue angles", i.e. angles that represent the position of the next residue relative to the current residue, to get an invariant representation. Additionally, instead of manually crafting an invariant representation, we can use a learned network to extract an invariant latent representation. For example, (Hsu et al. 2022) use a geometric vector perceptron graph neural network (GVP-GNN) encoder-decoder to get a geometrically invariant latent vector from the protein structure.
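As a small illustration of a hand-crafted invariant representation (pairwise distances rather than the angle features of (Wu et al. 2022)), the following sketch checks that the feature is unchanged under a random rigid transform.

```python
import torch

def pairwise_distances(coords):
    """coords: (N, 3) residue coordinates -> (N, N) matrix of pairwise distances."""
    return torch.cdist(coords, coords)

coords = torch.randn(10, 3)                      # toy structure with 10 residues
Q, _ = torch.linalg.qr(torch.randn(3, 3))        # random orthogonal transform (rotation/reflection)
shift = torch.randn(3)
transformed = coords @ Q.T + shift               # rigidly transform the structure
assert torch.allclose(pairwise_distances(coords),
                      pairwise_distances(transformed), atol=1e-4)
```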

Diffusion models for inverse protein folding

Learning to generate sequences conditioned on structure data

We know that a diffusion model can learn how to generate realistic samples that approximately lie in the real data distribution. Moreover, our diffusion model (which is just a neural network that predicts the reverse process) can be conditioned on additional information, in our case the protein structure. Then, we can simply pass a protein sequence sampled from noise to the reverse diffusion process while conditioning on the protein structure, and generate a realistic protein sequence corresponding to the real protein structure.

To unpack this, say we have some real protein sequence and structure data $\{x_{0}^{i},y_{0}^{i}\}_{i\in[m]}$, where $x$ is the protein sequence and $y$ is the structure. A conditional diffusion model that diffuses the protein sequence minimizes the objective function (here we use the simplified noise objective function)

$$\mathbb{E}\big[\|\epsilon^{x}-\epsilon_{\theta}(x_{t},y_{0},t)\|^{2}\big]$$

where $y_{0}$ is the real protein structure and $\epsilon^{x}$ is the noise on the sequence data. We see that our diffusion model is learning the noise conditional on the current noisy sequence $x_{t}$ and the actual protein structure $y_{0}$ (and the timestep $t$). To make this explicit, under the squared objective we learn the conditional mean $\mathbb{E}[\epsilon^{x}|x_{t},y_{0},t]$.

How can we implement this in practice?

I was initially confused about how to learn this conditional distribution with multiple inputs in practice. But it's no different than the toy example of predicting house prices conditional on room sizes and house characteristics; in that example, we could concatenate the room sizes with the house characteristics to create our input and then simply pass that to the model. Similarly, we could simply concatenate the two conditioning inputs $x_{t}$ and $y_{0}$: for example, if we used a transformer we would pass in a sequence whose first chunk is the $x_{t}$ sequence information and whose second chunk is the $y_{0}$ structure information, and then predict the noise with a regression head.
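A sketch of this concatenation idea, with illustrative module names and dimensions (not any specific published architecture):

```python
import torch
import torch.nn as nn

class ConditionalNoiseModel(nn.Module):
    """Concatenate noisy-sequence tokens and structure tokens, then predict the sequence noise."""
    def __init__(self, d_model=128, n_residue_types=20, d_struct=16):
        super().__init__()
        self.seq_embed = nn.Linear(n_residue_types, d_model)    # embed (noisy) sequence features
        self.struct_embed = nn.Linear(d_struct, d_model)        # embed per-residue structure features
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=2)
        self.head = nn.Linear(d_model, n_residue_types)         # regression head for the noise

    def forward(self, x_t, y0):
        # x_t: (B, L, 20) noisy sequence; y0: (B, L, d_struct) structure conditioning
        tokens = torch.cat([self.seq_embed(x_t), self.struct_embed(y0)], dim=1)
        h = self.encoder(tokens)
        return self.head(h[:, : x_t.shape[1]])                  # predict noise on the sequence chunk
```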

Conditioning on information has been shown to possibly lead to better diffusion performance; this is the central idea of classifier-free guidance, introduced by (Ho et al. 2022). The idea is that the diffusion model combines a conditional and an unconditional noise estimate linearly,

$$\tilde{\epsilon}_{\theta}(z_{\lambda},c)=(1+w)\epsilon_{\theta}(z_{\lambda},c)-w\epsilon_{\theta}(z_{\lambda})$$

where $c$ is the conditioning information and $w$ toggles how much to incorporate the conditional information.

Interestingly, (Ho et al. 2022) learn both the conditional and unconditional noise estimates using a single neural network for simplicity. To learn the unconditional noise estimate, they simply input a null token $c=\varnothing$ as the conditional input, so the network really is only conditioned on the information $z_{\lambda}$.
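A minimal sketch of the guided estimate at sampling time, assuming a single `model(z, c)` trained with conditioning dropout so that `c=None` plays the role of the null token:

```python
import torch

def guided_noise_estimate(model, z, c, w=1.0):
    eps_cond = model(z, c)          # conditional noise estimate eps_theta(z, c)
    eps_uncond = model(z, None)     # unconditional estimate eps_theta(z, c = null)
    return (1 + w) * eps_cond - w * eps_uncond
```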

Learning to generate sequence and structures jointly

Instead of learning to generate realistic protein sequences conditioned on structure information, we can instead learn to generate realistic protein sequence and structure pairs. Now, instead of just reverse diffusing (denoising) a noisy sampled sequence, we reverse diffuse a noisy protein sequence and structure pair to get a realistic sequence and structure pair. Under our objective we learn the noise $\mathbb{E}[\epsilon^{x},\epsilon^{y}|x_{t},y_{t},t]$, where we denote $\epsilon^{x}$ as the noise on the sequence data and $\epsilon^{y}$ as the noise on the structure data (Bao et al. 2023). In contrast to before, instead of conditioning on the initial true protein structure, we condition on the noisy sampled structure $y_{t}$ corresponding to that timestep.

Speeding up with a latent vector (latent diffusion)

Figure: latent diffusion (Rombach et al. 2022).

One problem with diffusion models is that the forward and reverse processes can be very expensive in time and compute. Instead of diffusing directly on the input data, the latent diffusion approach is to compress the input into a smaller latent representation using an autoencoder (Rombach et al. 2022). Diffusion is run on the latent representation (i.e. we forward diffuse the latents, learn the reverse diffusion process on the latents, then sample a random noisy latent and reverse diffuse it into a realistic latent that lies in the latent data distribution), and the generated latent is then decoded with the autoencoder to produce a realistic sample in the original input space.
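Sketching the latent diffusion recipe by reusing the earlier Gaussian sketches (`training_step`, `sample`) on a latent $z$ instead of the raw input; the `encoder`/`decoder` here stand in for a pretrained autoencoder and are assumptions of this sketch:

```python
import torch

def train_latent_step(encoder, noise_model, x0, optimizer):
    with torch.no_grad():
        z0 = encoder(x0)                              # compress the input to a latent
    return training_step(noise_model, z0, optimizer)  # run the usual diffusion objective in latent space

@torch.no_grad()
def generate(decoder, reverse_model, latent_shape):
    z0_hat = sample(reverse_model, latent_shape)      # reverse diffuse a random latent
    return decoder(z0_hat)                            # decode back into the original data space
```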

Diffusion with non-continuous data (e.g. protein sequences)

In learning to generate a protein sequence conditioned on the structure, or in learning to generate a protein sequence and structure pair, diffusion needs to forward diffuse (noise) a sequence. But a protein sequence is a sequence of discrete variables, where each variable takes one of twenty protein residue categories. We can't simply add Gaussian noise to our discrete variables. One solution is to force the categorical protein residue to be embedded as a continuous variable, but this can be an unnatural approach. Rather, we can have our transitions be tailored for discrete variables. This is the idea of Discrete Denoising Diffusion Probabilistic Models, introduced by (Austin et al. 2021).

One example of a transition function for discrete spaces is the uniform transition: a variable stays in the same state with a certain probability $p$, and the remaining probability mass $1-p$ is uniformly distributed over all $K$ possible states. Another example is a discretized Gaussian, where the Gaussian is truncated and discretized so that it mimics the continuous Gaussian transition. To define our transition function more generally, we can consider a transition matrix $[Q_{t}]_{ij}=q(x_{t}=j|x_{t-1}=i)$ whose $(i,j)$ entry defines the transition from state $i$ to state $j$. Notationally, our categorical variable $x_{t-1}\in\mathbb{R}^{1\times K}$ is a one-hot row vector (i.e. 1 for the corresponding protein residue, 0 elsewhere) and we can define our transition function as the categorical distribution $q(x_{t}|x_{t-1})=\mathrm{Cat}(x_{t-1}Q_{t})$, where $Q_{t}\in\mathbb{R}^{K\times K}$ is our transition matrix. For our Markov chain to be well defined, we need to define $Q_{t}$ such that for large $t$ the distribution defined by $Q_{1}Q_{2}\dots Q_{t}$ converges to a stationary distribution.
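For concreteness, a sketch of the uniform transition and one forward step on a one-hot sequence, following the $Q_{t}$ convention above (names are illustrative):

```python
import torch
import torch.nn.functional as F

def uniform_Q(beta_t, K):
    """Uniform transition: keep the state w.p. 1 - beta_t, otherwise move uniformly over K states."""
    return (1 - beta_t) * torch.eye(K) + beta_t * torch.ones(K, K) / K

def q_step(x_onehot, Q_t):
    """One forward step: x_{t-1} one-hot (L, K) -> sample x_t ~ Cat(x_{t-1} Q_t)."""
    probs = x_onehot @ Q_t
    sampled = torch.multinomial(probs, 1).squeeze(-1)
    return F.one_hot(sampled, num_classes=Q_t.shape[0]).float()
```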

As far as I know, there is no heuristic that determines which transition functions will let the diffusion model generate better samples. In the case of generating images, (Austin et al. 2021) show that the discretized Gaussian performs best relative to the uniform transition and an absorbing transition (defined below).

Discrete diffusion objective function

In the case of discrete diffusion, it is clearer to go back to the fully spelled out variational log likelihood objective.

(Austin et al. 2021) use a loss function that is the sum of the variational likelihood objective and a cross entropy term, $L_{vb}+\lambda\mathbb{E}_{q(x_{0})}\mathbb{E}_{q(x_{t}|x_{0})}[-\log\tilde{p}_{\theta}(x_{0}|x_{t})]$, or in full

$$\mathbb{E}_{q(x_{0})}\Big[D_{KL}(q(x_{T}|x_{0})\,\|\,p(x_{T}))+\sum_{t=2}^{T}\mathbb{E}_{q(x_{t}|x_{0})}\big[D_{KL}(q(x_{t-1}|x_{t},x_{0})\,\|\,p_{\theta}(x_{t-1}|x_{t}))\big]-\mathbb{E}_{q(x_{1}|x_{0})}\big[\log p_{\theta}(x_{0}|x_{1})\big]\Big]+\lambda\,\mathbb{E}_{q(x_{0})}\mathbb{E}_{q(x_{t}|x_{0})}\big[-\log\tilde{p}_{\theta}(x_{0}|x_{t})\big]$$

To deal with the main term, the divergence between our predicted reverse transition $p_{\theta}(x_{t-1}|x_{t})$ and the reverse transition $q(x_{t-1}|x_{t},x_{0})$, we need to parameterize our predicted reverse transition $p_{\theta}(x_{t-1}|x_{t})$. (Austin et al. 2021) parameterize $p_{\theta}(x_{t-1}|x_{t})\propto\sum_{\tilde{x}_{0}}q(x_{t-1},x_{t}|\tilde{x}_{0})\tilde{p}_{\theta}(\tilde{x}_{0}|x_{t})$, where we sum over all possible original samples $\tilde{x}_{0}$.

(Note: the predicted function is denoted $\tilde{p}$ rather than $p$ to show that $\tilde{p}_{\theta}(x_{0}|x_{t})$ is a neural network that predicts the distribution of the final denoised output $x_{0}$ given $x_{t}$. The $p_{\theta}(x_{t-1}|x_{t})$ defined above incorporates $\tilde{p}_{\theta}(\tilde{x}_{0}|x_{t})$ to predict the denoised output one step away.)

The real kicker: masked language models as discrete diffusion models

Masked language models, like BERT, learn to generate text in a sequence based on a given context. Specifically, during training, input sequences are randomly corrupted (i.e. some tokens become [MASK]) and the language model tries to predict and recover the masked tokens given the unmasked context. (Austin et al. 2021) in fact prove that masked language models under a cross entropy objective are discrete diffusion models. Specifically, they show that masked language models are absorbing-state discrete diffusion models. In an absorbing-state discrete diffusion model, each token either stays in the same state or transitions to an absorbing state (e.g. [MASK]) and gets stuck there forever. If we say that the probability of transitioning to the absorbing state is $\beta_{t}$, then we can define the transition probability $q(x_{t}=j|x_{t-1}=i)$ by the matrix entries

$$[Q_{t}]_{ij}=\begin{cases} 1 & i=j=m\\ 1-\beta_{t} & i=j\ne m\\ \beta_{t} & j=m,\ i\ne m \end{cases}$$

Our transition matrix $Q_{t}$ is such that for large $t$, $\bar{Q}_{t}=Q_{1}Q_{2}\cdots Q_{t}$ converges to the stationary distribution in which all probability mass is on the absorbing state [MASK].
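A sketch of this absorbing-state transition matrix, assuming the [MASK] state is the last index:

```python
import torch

def absorbing_Q(beta_t, K):
    """K residue states plus one absorbing [MASK] state at index m = K."""
    m = K
    Q = (1 - beta_t) * torch.eye(K + 1)          # stay in the same state with prob 1 - beta_t
    Q[:, m] += beta_t                            # move to [MASK] with prob beta_t
    Q[m, m] = 1.0                                # [MASK] is absorbing: it never leaves
    return Q
```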

Implementing masked language models (MLMs) for inverse protein folding

To train the MLM diffusion model, (Anand et al. 2022) randomly mask a fraction $t$ of the protein sequence and have the model predict the masked parts of the sequence. The mask fraction $t$ is linearly interpolated over $[0,1]$, so the language model learns to predict from more and more noise. Then, to generate a predicted sequence, the model starts with a completely masked sequence $x_{T}$ (the equivalent of pure noise). Here we are sampling our noisy sequence from the stationary distribution of the forward diffusion process, which has all probability mass on the [MASK] state. To progressively denoise (reverse diffuse) the sequence, at each round $t$ they mask the generated denoised $x_{t}$ with probability $\frac{t}{T}$ and predict the masked protein residues to generate the next $x_{t-1}$, repeating the process to get back to a generated protein sequence $\hat{x}_{0}$.
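A rough sketch of this training and sampling recipe, with illustrative names and a toy number of rounds (not Anand et al.'s actual implementation):

```python
import torch
import torch.nn.functional as F

MASK_ID, NUM_ROUNDS = 20, 100    # residue ids 0..19, plus an extra [MASK] token id

def mlm_train_step(mlm, seq_tokens, optimizer):
    """seq_tokens: (L,) LongTensor of residue ids; mlm returns (L, 20) logits."""
    frac = torch.rand(())                                     # mask fraction ~ Unif[0, 1]
    mask = torch.rand(seq_tokens.shape) < frac
    if not mask.any():                                        # make sure something is masked
        mask[torch.randint(seq_tokens.shape[0], (1,))] = True
    corrupted = torch.where(mask, torch.full_like(seq_tokens, MASK_ID), seq_tokens)
    loss = F.cross_entropy(mlm(corrupted)[mask], seq_tokens[mask])
    optimizer.zero_grad(); loss.backward(); optimizer.step()
    return loss.item()

@torch.no_grad()
def mlm_sample(mlm, length):
    x = torch.full((length,), MASK_ID)                        # x_T: the all-[MASK] sequence
    for t in reversed(range(1, NUM_ROUNDS + 1)):
        remask = torch.rand(length) < t / NUM_ROUNDS          # mask x_t with probability t / T
        x = torch.where(remask, torch.full_like(x, MASK_ID), x)
        masked = x == MASK_ID
        if masked.any():
            logits = mlm(x)                                   # predict residues at masked positions
            x[masked] = torch.distributions.Categorical(logits=logits[masked]).sample()
    return x
```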

Order-agnostic autoregressive models (OARM) as discrete diffusion models

It turns out that the other kind of model typical in language modeling, the autoregressive model, can also be cast as a discrete diffusion model. (Hoogeboom et al. 2022) introduce order-agnostic autoregressive diffusion models, which are derived from autoregressive models and can learn to generate in any order.

The typical autoregressive likelihood objective is precisely $\log p(x)=\sum_{t=1}^{D}\log p(x_{t}|x_{<t})$, where $D$ is the length of the sequence. In the autoregressive setting, earlier parts of the sequence contain information that is used to predict the later elements. However, in cases where there is no pre-defined or natural ordering of the sequence, we can instead use an order-agnostic autoregressive objective, where we first randomly permute the sequence before prediction. Our likelihood objective has the following lower bound

$$\log p(x)=\log\mathbb{E}_{\sigma\sim\mathrm{Unif}}[p(x|\sigma)]\geq\mathbb{E}_{\sigma\sim\mathrm{Unif}}[\log p(x|\sigma)]$$

where the inequality is due to Jensen's inequality with a concave function. We can further re-express our lower bound as

$$=\mathbb{E}_{\sigma\sim\mathrm{Unif}}\Big[\log\prod_{t=1}^{D}p(x_{\sigma(t)}|x_{\sigma(<t)})\Big]=\mathbb{E}_{\sigma\sim\mathrm{Unif}}\Big[\sum_{t=1}^{D}\log p(x_{\sigma(t)}|x_{\sigma(<t)})\Big]=D\,\mathbb{E}_{\sigma\sim\mathrm{Unif}}\,\mathbb{E}_{t\sim\mathrm{Unif}([D])}\big[\log p(x_{\sigma(t)}|x_{\sigma(<t)})\big]$$

This can be expressed as (full details in Hoogeboom et al. 2022)

$$=D\,\mathbb{E}_{t\sim\mathrm{Unif}([D])}\Big[\mathbb{E}_{\sigma\sim\mathrm{Unif}}\Big[\frac{1}{D-t+1}\sum_{k\in\sigma(\geq t)}\log p(x_{k}|x_{\sigma(<t)})\Big]\Big]$$

If we denote $\mathfrak{L}_{t}:=\mathbb{E}_{\sigma\sim\mathrm{Unif}}\big[\frac{1}{D-t+1}\sum_{k\in\sigma(\geq t)}\log p(x_{k}|x_{\sigma(<t)})\big]$, our likelihood lower bound objective is $D\,\mathbb{E}_{t\sim\mathrm{Unif}([D])}[\mathfrak{L}_{t}]$.

We see in the $\mathfrak{L}_{t}$ term that the first $t-1$ elements of some randomly permuted sequence are used to predict the rest. The way we optimize $\mathfrak{L}_{t}$ in practice is with a BERT-style objective, where we mask the tail $D-t+1$ elements of the permuted sequence and predict them using the unmasked head elements.

Training our model in practice

To optimize our likelihood bound, we need to predict the probability distribution of the masked elements. To do this we train a model that predicts a probability distribution over each element in the sequence. Notationally, let $x\in\mathcal{X}=\{1,\dots,K\}^{D}$ be our sequence of $K$ categories and let a single neural network $f:\mathcal{X}\to\mathbb{R}^{D\times K}$ be a model that outputs a probability distribution over the categories for each element in the sequence. To predict the probability distributions of all the masked tokens in the sequence, given a mask $m$ that corrupts certain tokens to [MASK] and leaves the rest unchanged, we pass our masked sequence $x\odot m$ to the network $f$. Then we can extract from the output $f(x\odot m)$ the predicted probability vectors for the masked tokens.
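A sketch of one training step under this objective, where masking is implemented by zeroing out the one-hot rows (the $x\odot m$ above) and the cross entropy over the masked tail is rescaled to estimate $-D\,\mathfrak{L}_{t}$; names are illustrative.

```python
import torch
import torch.nn.functional as F

def oarm_train_step(f, x_tokens, optimizer, K=20):
    """x_tokens: (D,) LongTensor. f maps a (D, K) masked one-hot input to (D, K) logits."""
    D = x_tokens.shape[0]
    sigma = torch.randperm(D)                                 # random generation order
    t = int(torch.randint(1, D + 1, ()))                      # t ~ Unif([D])
    masked_pos = sigma[t - 1:]                                # the tail D - t + 1 positions
    x_in = F.one_hot(x_tokens, K).float()
    x_in[masked_pos] = 0.0                                    # masking = zeroing out (x ⊙ m)
    logits = f(x_in)
    ce = F.cross_entropy(logits[masked_pos], x_tokens[masked_pos])   # mean over masked positions
    loss = D * ce                                             # Monte Carlo estimate of -D * L_t
    optimizer.zero_grad(); loss.backward(); optimizer.step()
    return loss.item()
```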

Sampling our OARM model to generate sequences

Similar to the absorbing diffusion model case, to generate a predicted output sequence we first start with an all-masked sequence and then use our model to predict the distribution of the elements, from which we can sample the next sequence. At each timestep $t$, as $t$ goes from $T\to1$, we mask the sequence less and less; whereas in (Anand et al. 2022) the probability of masking decreased, here we mask the tail $t$ tokens.

To write the sampling process explicitly (a code sketch follows the list):

  • Sample some random permutation $\sigma\sim\mathrm{Unif}$ and start with an all-masked sequence $x_{T}$

  • for $i$ in $[1,D]$:

    • Create our mask $m$, which masks everything but the first $i-1$ tokens in the permuted sequence

    • Sample an $x_{t}'$ drawn from the predicted probability distribution output by our trained network, $x_{t}'\sim\mathrm{Cat}(f(m\odot x_{t}))$, by inputting our masked sequence

    • Combine the non-masked parts from $x_{t}$, i.e. $m^{c}\odot x_{t}$, and the predicted masked parts from $m\odot x_{t}'$ to generate the next timestep $x_{t-1}=m^{c}\odot x_{t}+m\odot x_{t}'$
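A sketch of this sampling loop, using the network `f` from the previous training sketch:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def oarm_sample(f, D, K=20):
    sigma = torch.randperm(D)                                 # sigma ~ Unif over permutations
    x = torch.zeros(D, dtype=torch.long)                      # all-masked start (values hidden by the mask)
    for i in range(1, D + 1):
        m = torch.ones(D, dtype=torch.bool)                   # mask everything ...
        m[sigma[: i - 1]] = False                             # ... except the first i - 1 permuted tokens
        x_in = F.one_hot(x, K).float()
        x_in[m] = 0.0                                         # m ⊙ x_t
        x_prime = torch.distributions.Categorical(logits=f(x_in)).sample()   # x_t' ~ Cat(f(m ⊙ x_t))
        x = torch.where(m, x_prime, x)                        # x_{t-1} = m^c ⊙ x_t + m ⊙ x_t'
    return x
```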

Proposed method: diffusion sequence inpainting

The inverse folding problem is to predict a protein sequence given a protein structure; thus, a natural approach is to directly condition the diffusion model on the protein structure to generate a protein sequence. For example, (Anand et al. 2022) do this with an absorbing discrete diffusion model and (Yi et al. 2023) do this with a graph diffusion model.

However, instead of directly conditioning the diffusion model on the protein structure, we can follow the RePaint image inpainting framework of (Lugmayr et al. 2022) and condition on noised protein structure data. To provide some intuition, in the image inpainting problem part of the image is missing/unknown and we want to use the known parts of the image to inpaint, or fill in, the missing region. In our setting, we can imagine that a protein sequence and structure pair makes up an image and the protein sequence is the missing information that we want to "inpaint". One way to do this, as mentioned before, is to directly condition on the known data to predict our missing data.

Alternatively, we can imagine noising the entire image and learning to reverse diffuse the entire image while conditioning on noisy known data, to ensure that our final generated image contains the correct known region and an inpainted unknown region that is consistent with it. (We can imagine diffusion on the entire image as a joint diffusion on the known and unknown regions, where the noises $\epsilon^{unknown},\epsilon^{known}$ are identical in distribution.) In our setting, we would use noised protein structure data to condition our diffusion model to generate a corresponding protein sequence.

Figure: the RePaint method (Lugmayr et al. 2022).

The RePaint framework

RePaint differs from normal diffusion on the entire image in that it conditions on a preset noisy version of the known region and also employs a resampling step. Notationally, given some ground truth image $x$, known pixels $m\odot x$, and unknown pixels $(1-m)\odot x$, they define the one-step reverse diffused output $x_{t-1}$ to be a combination of the unknown pixel region of the reverse diffused image (as in typical diffusion) and the known pixel region of a noised version of the ground truth $x_{0}$:

$$x_{t-1}^{known}\sim N(\sqrt{\bar{\alpha}_{t}}\,x_{0},(1-\bar{\alpha}_{t})\boldsymbol{I})$$

$$x_{t-1}^{unknown}\sim N(\mu_{\theta}(x_{t},t),\Sigma_{\theta}(x_{t},t))$$

$$x_{t-1}=m\odot x_{t-1}^{known}+(1-m)\odot x_{t-1}^{unknown}$$

What (Lugmayr et al. 2022) remark is that the final predicted inpainted region $x_{0}^{unknown}$ does not harmonize well with the rest of the image, even though it is incorporating the conditioning known data. Thus they use a resampling step with some jump size $j$, where the reverse diffusion output $x_{t-1}$ is re-noised $j$ times to $x_{t-1+j}$ by undergoing a series of $j$ noising steps defined by $x_{t}\sim N(\sqrt{1-\beta_{t}}x_{t-1},\beta_{t}\boldsymbol{I})$.
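A sketch of one RePaint-style reverse step with resampling, reusing the Gaussian schedule and fixed-variance reverse step from the earlier sketches; for simplicity the jump size is taken to be a single step rather than the larger jumps used in the paper.

```python
import torch

@torch.no_grad()
def repaint_step(reverse_model, x_t, x0_known, mask, t, n_resample=5):
    """One RePaint-style reverse step at time t >= 1. mask is 1 on the known region."""
    for _ in range(n_resample):
        # Known region: freshly noised ground truth at noise level t - 1.
        x_known = (torch.sqrt(alpha_bars[t - 1]) * x0_known
                   + torch.sqrt(1.0 - alpha_bars[t - 1]) * torch.randn_like(x0_known))
        # Unknown region: one learned reverse step from x_t (fixed variance beta_t).
        x_unknown = reverse_model(x_t, t) + torch.sqrt(betas[t]) * torch.randn_like(x_t)
        x_prev = mask * x_known + (1 - mask) * x_unknown
        # Resample: re-noise x_{t-1} back to x_t and repeat, to harmonize the two regions.
        x_t = torch.sqrt(1.0 - betas[t]) * x_prev + torch.sqrt(betas[t]) * torch.randn_like(x_prev)
    return x_prev
```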

Inpainting for protein sequence recovery

We can adapt the inpainting method to our setting by considering the joint distribution of the protein structure and the protein sequence, where we wish to inpaint the missing protein sequence data.

For simplicity, we first consider the continuous diffusion setting. We follow (McPartlon 2023) and use an autoencoder to extract a continuous latent which we can run diffusion on. Specifically, we can train an autoencoder to encode our graph structure into some latent vector $z^{known}$, and train another autoencoder to encode our protein sequence into $z^{unknown}$. (Here we can consider using similar autoencoders so that the latent representations lie in the same data space.) We then follow (Lugmayr et al. 2022) and train a diffusion model on the latent data $z:=(z^{known},z^{unknown})$ using Gaussian noise and the typical diffusion objective function. Finally, we can recover an inpainted latent sequence $z_{0}^{unknown}$, which we can decode using our sequence autoencoder to generate our predicted protein sequence.

However, we might consider it unnatural to encode a discrete sequence into a continuous latent. Thus we can use a discrete diffusion model for our protein sequence generation. For example, as in (Yi et al. 2023), we can use a discrete diffusion model based on a doubly stochastic transition matrix derived from the empirical sequence mutation data in BLOSUM. This ensures our transition matrix is grounded in empirical data. For the protein structure data, which is naturally continuous, we can either use some hand-crafted geometric representation as in (Wu et al. 2022), or follow (Hsu et al. 2022) and use a latent representation from the GVP-GNN graph neural network of (Jing et al. 2021) to get a geometrically invariant representation of the protein structure.

For the diffusion process we follow the RePaint method described above. The continuous protein structure data follows the typical Gaussian setting:

$$x_{t-1}^{known}\sim N(\sqrt{\bar{\alpha}_{t}}\,x_{0},(1-\bar{\alpha}_{t})\boldsymbol{I})$$

To get our reverse diffused sequence data $x_{t-1}^{unknown}$, we learn the reverse process $p_{\theta}(x_{t-1}^{unknown}|x_{t})$, which, following (Austin et al. 2021), we parameterize as $p_{\theta}(x_{t-1}^{unknown}|x_{t})\propto\sum_{\tilde{x}_{0}}q(x_{t-1}^{unknown},x_{t}|\tilde{x}_{0}^{unknown})\tilde{p}_{\theta}(\tilde{x}_{0}^{unknown}|x_{t})$.

Thus we train a network to predict the distribution of the final denoised sequence output, $\tilde{p}_{\theta}(\tilde{x}_{0}^{unknown}|x_{t})$, conditioned on the previous noisy output, which we note is $x_{t}=m\odot x_{t}^{known}+(1-m)\odot x_{t}^{unknown}$, the concatenation of the noised protein structure data $m\odot x_{t}^{known}$ and the partially denoised sequence data $(1-m)\odot x_{t}^{unknown}$. (Note: we keep the notation of (Lugmayr et al. 2022); here $m\odot x_{t}^{known}$ is the entire noisy protein structure data and $(1-m)\odot x_{t}^{unknown}$ is the entire noisy protein sequence data.) The training objective is the likelihood lower bound objective for discrete diffusion models stated previously.

To generate our predicted inpainted protein sequence, we would sample a pure noise sequence from a uniform distribution, $x_{T}^{unknown}\sim\mathrm{Cat}(\mathrm{Unif})$, and then run the reverse diffusion process with our trained diffusion model. (Here we sample from the uniform distribution because a doubly stochastic transition matrix converges to a uniform stationary distribution.)
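A speculative sketch of one reverse step of this hybrid scheme: the structure latent is freshly noised with the Gaussian schedule from the earlier sketches, and the sequence takes one discrete reverse step using the (Austin et al. 2021) parameterization. Here `p0_net`, the one-step matrix `Q_t` (e.g. BLOSUM-derived), and the cumulative matrix `Q_bar_prev` $=\bar{Q}_{t-1}$ are all assumed components, not an existing implementation.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def hybrid_reverse_step(p0_net, seq_t, struct_latent0, t, Q_t, Q_bar_prev, K=20):
    """seq_t: (L,) LongTensor of noisy residue ids; struct_latent0: clean structure latent."""
    # "Known" part: the structure latent, freshly noised to level t - 1 (as in RePaint).
    struct_t = (torch.sqrt(alpha_bars[t - 1]) * struct_latent0
                + torch.sqrt(1.0 - alpha_bars[t - 1]) * torch.randn_like(struct_latent0))
    # "Unknown" part: predict p~_theta(x0 | x_t) for the sequence given the noisy structure.
    x0_probs = torch.softmax(p0_net(seq_t, struct_t, t), dim=-1)          # (L, K)
    # Discrete reverse step: p(x_{t-1}|x_t) ∝ q(x_t|x_{t-1}) * sum_{x0} q(x_{t-1}|x0) p~(x0|x_t)
    xt_onehot = F.one_hot(seq_t, K).float()
    term_forward = xt_onehot @ Q_t.T                                      # q(x_t | x_{t-1}) as a row over x_{t-1}
    term_prior = x0_probs @ Q_bar_prev                                    # sum_{x0} p~(x0|x_t) q(x_{t-1}|x0)
    probs = term_forward * term_prior
    probs = probs / probs.sum(dim=-1, keepdim=True)
    seq_prev = torch.distributions.Categorical(probs=probs).sample()
    return seq_prev, struct_t
```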

Choice of discrete diffusion model

So far we have seen an absorbing-state discrete diffusion model, a more general discrete diffusion model defined by some transition matrix $Q_{t}$, and an order-agnostic autoregressive discrete diffusion model. (Anand et al. 2022) use an absorbing discrete diffusion model, while (Yi et al. 2023) use a discrete diffusion model whose transition matrix comes from BLOSUM. (Alamdari et al. 2023) compare the performance of an order-agnostic autoregressive discrete diffusion model, a BLOSUM-based discrete diffusion model, and an absorbing-state discrete diffusion model, and find that the order-agnostic autoregressive model performs best for sequence generation in protein design. Following this, we might instead decide to use an order-agnostic autoregressive model.

Other ideas: joint distributions and classifier-free guidance

The inpainting method of (Lugmayr et al. 2022) can be seen as an intermediate between a direct-conditioning diffusion model, where we learn the noise on the sequence $\mathbb{E}[\epsilon^{seq}|x_{t}^{seq},x_{0}^{struct}]$ conditioned on the true protein structure, and a joint-distribution diffusion model, where we learn $\mathbb{E}[\epsilon^{seq},\epsilon^{struct}|x_{t}^{seq},x_{t}^{struct}]$, both the noise on the sequence and the noise on the structure.

We could consider the joint diffusion setting where we do not explicitly condition on the structure information and instead jointly reverse diffuse a pure noise sequence and structure sample to generate a realistic sequence and structure pair. However we would need to construct some way of enforcing the generated protein structure to match the true input protein structure for our inverse folding setting.

(Ho et al. 2022), in their work on classifier-free guidance, show better results when combining a conditional and an unconditional model. Following this, we might consider learning both a conditional and an unconditional discrete diffusion model to generate sequences $x_{t}^{seq}$. For example, our unconditional model parameterized by weights $\theta_{1}$ can learn the reverse process $p_{\theta_{1}}(x_{t-1}^{seq}|x_{t}^{seq})$ by training a network to predict the distribution of the final denoised sequence simply from the previous noisy sequence, i.e. $\tilde{p}_{\theta_{1}}(\tilde{x}_{0}^{seq}|x_{t}^{seq})$. For the conditional model parameterized by weights $\theta_{2}$, we can encode our graph structure information into some geometrically invariant representation $x_{0}^{struct}$ and then learn the reverse diffusion process $p_{\theta_{2}}(x_{t-1}^{seq}|x_{t}^{seq},x_{0}^{struct})$ by training a network to predict the distribution of the final denoised sequence given the previous noisy sequence and the true protein structure, i.e. $\tilde{p}_{\theta_{2}}(\tilde{x}_{0}^{seq}|x_{t}^{seq},x_{0}^{struct})$.

To generate our sequences, we can linearly combine the predicted distributions of the two models with some weighting parameter $w$ to get $w\,p_{\theta_{1}}(x_{t-1}^{seq}|x_{t}^{seq})+(1-w)\,p_{\theta_{2}}(x_{t-1}^{seq}|x_{t}^{seq},x_{0}^{struct})$, which we can use in the reverse diffusion process to generate our predicted sequences.
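A small sketch of this mixing step, assuming both (hypothetical) networks return per-position logits over residue types:

```python
import torch

def combined_reverse_probs(p0_uncond, p0_cond, seq_t, struct0, t, w=0.5):
    """Mix the unconditional and conditional per-position distributions with weight w."""
    probs_uncond = torch.softmax(p0_uncond(seq_t, t), dim=-1)
    probs_cond = torch.softmax(p0_cond(seq_t, struct0, t), dim=-1)
    return w * probs_uncond + (1 - w) * probs_cond    # still a valid distribution for w in [0, 1]
```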

Data augmentation ideas

The lack of sequence and structure pair data currently bottlenecks inverse protein folding models. Following (Hsu et al. 2022), we can use AlphaFold2 to predict protein structures for sequences from a dataset such as UniRef50 to augment our protein sequence and structure dataset. We may also consider adding perturbed AlphaFold structures, as (Dauparas et al. 2022) show that adding small Gaussian noise to the AlphaFold protein structures improves performance. We can also consider noisy student self-distillation: following (Jumper et al. 2021), we can take protein structures lacking sequence information, predict sequences using our inverse folding model, and then add the structure and predicted sequence pairs to our dataset and retrain our model.

References

  • (Luo 2022) Understanding Diffusion Models: A Unified Perspective

  • (Weng 2021) What are diffusion models?

  • (Sohl-Dickstein 2015) Deep Unsupervised Learning using Nonequilibrium Thermodynamics

  • (Ho 2020) Denoising Diffusion Probabilistic Models

  • (Wu 2022) Protein structure generation via folding diffusion

  • (Hsu 2022) Learning inverse folding from millions of predicted structures

  • (Ho 2022) Classifier-Free Diffusion Guidance

  • (Bao 2023) One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale

  • (Rombach 2022) High-Resolution Image Synthesis with Latent Diffusion Models

  • (Austin 2021) Structured Denoising Diffusion Models in Discrete State-Spaces

  • (Anand 2022) Protein Structure and Sequence Generation with Equivariant Denoising Diffusion Probabilistic Models

  • (Hoogeboom 2022) Autoregressive Diffusion Models

  • (Yi 2023) Graph Denoising Diffusion for Inverse Protein Folding

  • (Lugmayr 2022) RePaint: Inpainting using Denoising Diffusion Probabilistic Models

  • (Jing 2021) Learning from Protein Structure with Geometric Vector Perceptrons

  • (Alamdari 2023) Protein generation with evolutionary diffusion: sequence is all you need

  • (Dauparas 2022) Robust deep learning-based protein sequence design using ProteinMPNN

  • (Jumper 2021) Highly accurate protein structure prediction with AlphaFold