Upscaling Images on TPU VMs to 6k Resolution w/ Stable Diffusion

Acknowledgements

My access to TPUs is graciously provided by the TPU Research Cloud! This blog post and my research efforts would not be possible without their support!

In this post, we'll discuss our implementation of a memory-efficient procedure for upscaling images to 6k resolution using Stable Diffusion. Stable Diffusion is a state-of-the-art image generation technology that generates images from text prompts. Throughout, we'll assume general familiarity with the theoretical foundations of how Stable Diffusion works.

HuggingFace's leading implementation of the Stable Diffusion upscaler scales images up by a factor of 4x, i.e. it takes an image of resolution 128 x 128 to one of resolution 512 x 512. The quality of the upscaled images is quite good, but the Stable Diffusion based upscaling algorithm is memory-heavy. Indeed, I was surprised to find that the memory requirements in the decoder scale as the fourth power of the height/width! (The decoder's attention cost is quadratic in the number of pixels, and the pixel count is itself quadratic in the height/width.) This quartic dependence creates a memory wall that is hard to bypass even when scaling modest-sized images on state-of-the-art hardware. For example, a 40 GB A100 GPU runs out of memory when attempting to upscale a 768 x 768 image to 3072 x 3072.
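For context, here is roughly how the stock diffusers upscaler is invoked; the model ID and call signature follow the diffusers documentation, and the file names are placeholders:

```python
# Minimal sketch of the stock 4x upscaler in HuggingFace diffusers.
import torch
from diffusers import StableDiffusionUpscalePipeline
from PIL import Image

pipe = StableDiffusionUpscalePipeline.from_pretrained(
    "stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

low_res = Image.open("input_768.png")  # a 768 x 768 image
prompt = "a detailed photograph"

# Requesting a 768 -> 3072 upscale with the stock pipeline is exactly
# the configuration that runs out of memory on a 40 GB A100.
upscaled = pipe(prompt=prompt, image=low_res).images[0]
upscaled.save("output_3072.png")
```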

This blog discusses how to adjust the upscaling implementation to support scaling 1536 x 1536 images to 6144 x 6144 within a single 16 GB TPU device process. In other words, we can generate eight 6k-resolution upscaled images essentially simultaneously, one on each of the 8 processes of a v3 TPU VM.
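As a hedged sketch of what the per-core driver might look like on a v3-8 VM (upscale_image here is a hypothetical wrapper around our modified pipeline, not a library function):

```python
# One process per TPU core via torch_xla multiprocessing.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()  # the 16 GB TPU core owned by this process
    # upscale_image is a hypothetical wrapper around the modified
    # pipeline; each of the 8 processes upscales its own input.
    upscale_image(f"input_{index}.png", f"output_{index}.png", device)

if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)  # 8 cores on a v3-8 VM
```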

The out-of-memory failures when scaling images to 6k resolution stem from two bottlenecks: one in the denoising loop and one in the decoding step. We'll address the denoising loop first and then turn to the decoding step.

Scaling the Denoising Loop to Support 6k Resolution Images

The key idea is to tile the latent during the denoising loop. To tile it, we split the 1536 x 1536 latent into four non-overlapping quadrants of equal size, 768 x 768 each: top left, top right, bottom left, and bottom right.

Then we apply the denoising step to each of these quadrants separately, as sketched below.
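Here is a simplified sketch of one tiled denoising step. In it, unet, scheduler, and cond stand in for the pipeline's UNet, noise scheduler, and text conditioning; the actual upscaler additionally concatenates the matching quadrant of the low-resolution image to the latent and passes a noise level, which we omit for clarity:

```python
import torch

def denoise_step_tiled(latents, t, unet, scheduler, cond):
    """Apply one denoising step to each 768 x 768 quadrant separately."""
    _, _, h, w = latents.shape            # e.g. (1, 4, 1536, 1536)
    qh, qw = h // 2, w // 2               # quadrant height/width (768)
    out = torch.empty_like(latents)
    for i in (0, 1):                      # quadrant row
        for j in (0, 1):                  # quadrant column
            rs = slice(i * qh, (i + 1) * qh)
            cs = slice(j * qw, (j + 1) * qw)
            tile = latents[:, :, rs, cs]
            # Predict and remove noise for this quadrant only, so peak
            # memory is that of a 768 x 768 latent, not 1536 x 1536.
            noise_pred = unet(tile, t, encoder_hidden_states=cond).sample
            out[:, :, rs, cs] = scheduler.step(noise_pred, t, tile).prev_sample
    return out
```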

Proceeding in this way, we avoid running out of memory and can generate latents corresponding to 6k-resolution images.

Scaling the Decoding to Support 6k Resolution Images

Decoding requires a lot of memory and must be done on the VM CPU. Even with the relatively large RAM of the v3 VM CPU, we need to make a few key adjustments. One of them is to replace the attention mechanism in the variational autoencoder (VAE) with a linear approximation of the standard attention mechanism. This matters because we are dealing with images of large height/width: the memory required to compute attention in the original VAE implementation scales quadratically in the number of pixels, and therefore as the fourth power of the height/width! Replacing the standard attention with a linear approximation reduces the dependence to merely quadratic in the height/width, which is enough to avoid out-of-memory errors on the v3 VM CPU. In particular, we use the linear attention mechanism implemented in the performer-pytorch library.
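As an illustrative sketch of the swap (PerformerAttnBlock is a name used here for exposition; the actual integration lives in the vae.py linked below), the idea is to flatten the feature map into a token sequence and run Performer self-attention over it:

```python
import torch
import torch.nn as nn
from performer_pytorch import SelfAttention

class PerformerAttnBlock(nn.Module):
    """Drop-in replacement for the VAE's quadratic attention block."""
    def __init__(self, channels):
        super().__init__()
        # 32 groups matches the GroupNorm used in the diffusers VAE.
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels)
        # FAVOR+ linear attention: memory scales with the sequence length
        # H*W instead of (H*W)^2, i.e. quadratically in height/width.
        self.attn = SelfAttention(dim=channels, heads=1, causal=False)

    def forward(self, x):
        b, c, h, w = x.shape
        residual = x
        x = self.norm(x)
        x = x.reshape(b, c, h * w).transpose(1, 2)  # (B, H*W, C) tokens
        x = self.attn(x)
        x = x.transpose(1, 2).reshape(b, c, h, w)
        return x + residual
```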

Code

Code for implementing linear attention in the autoencoder is located here: https://github.com/lk-wq/diffusers/blob/lk-wq-nbc-ns/src/diffusers/models/vae.py

Below are some of our upscaled results:

Image Pair 1:

Original Image @ 768 x 768 resolution

Upscaled Image @ 6k resolution (resized to 2900 x 2900 to fit under the 20 MB limit of this blog platform)

Image Pair 2:

Original Image @ 768 x 768 resolution

Upscaled Image @ 6144 x 6144 resolution (resized to 3072 x 3072 to fit under the 20 MB limit of this blog platform)

Michael Jemison