Lo-Fi Makes Data Parallel Fine-Tuning Across Multiple TPU VMs Painless
Google’s TPU Research Cloud (TRC) graciously granted me access to TPU VMs for deep learning research and development projects. To take full advantage of these resources, I wanted to use distributed data parallel training across all of the VMs I had access to. Lo-Fi (Wortsman 2022b), a recent approach that enables low-communication data parallel training, let me pool compute across multiple TPU VMs, making model training much faster without the headaches that usually come with distributed training.
Lo-Fi Reduces Data Transfer Overhead by More Than 1000x on TPUs
Data parallel training promises to increase the speed at which a model can be trained on a given dataset: ideally, N data parallel processes reduce training time by a factor of N. (In practice this holds only up to a point; see, for example, You 2017.) However, traditional data parallel schemes synchronize all workers after every weight update, which for a large model could mean exchanging hundreds of GB of data on a minute-to-minute basis.
Given my limited time on TRC Cloud TPUs, it was unclear to me how to quickly set up the communication between separate VMs that traditional data parallel training requires. So instead of setting up a traditional data parallel scheme, I decided to try Lo-Fi, a novel scheme introduced in Wortsman 2022b for distributed data parallel training across multiple computational devices with minimal communication.
Roughly, Lo-Fi works as follows: several independent computational devices each fine-tune a copy of the same pretrained base model on their own shard of a larger dataset, and the final model weights are produced by averaging the weights from these devices. The authors of Lo-Fi found that, across a variety of tasks, models fine-tuned this way performed as well as models fine-tuned with traditional data parallel methods. This approach drastically reduces communication compared to traditional data parallel training, which requires synchronized communication after every weight update. Training our models on TPUs took hundreds of thousands of steps; a traditional data parallel scheme would have required sharing data among all of our TPU VMs at every one of those steps. With Lo-Fi we reconciled parameters only 5 times over roughly 5 days of training, reducing the number of synchronizations, and with it our data transfer, by a factor of many thousands.
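To make the procedure concrete, here is a minimal, self-contained sketch of Lo-Fi-style training on a toy problem. It is not the author's code: a two-parameter linear model fit with plain SGD stands in for the diffusion model, and the function names `sgd_linear` and `lofi`, the "pretrained" starting point (1.5, 0.5), and all hyperparameters are hypothetical. Each worker starts from the same weights, trains only on its own shard with no communication, and the workers' weights are then averaged. The optional `rounds` argument restarts every worker from the current average, loosely mimicking the periodic averaging described in the Practical Details section rather than averaging only once at the end.

```python
import random

def sgd_linear(params, shard, lr=0.05, epochs=20, seed=0):
    """Fine-tune (w, b) for y ~ w*x + b with plain SGD on a single shard (toy stand-in)."""
    w, b = params
    rng = random.Random(seed)
    shard = list(shard)  # copy so shuffling does not mutate the caller's list
    for _ in range(epochs):
        rng.shuffle(shard)
        for x, y in shard:
            err = (w * x + b) - y  # gradient of 0.5 * err**2 w.r.t. the prediction
            w -= lr * err * x
            b -= lr * err
    return w, b

def lofi(init, dataset, num_workers=5, rounds=1, **sgd_kwargs):
    """Lo-Fi-style training: independent fine-tuning per shard, then weight averaging."""
    # Strided sharding gives each worker a disjoint, roughly equal slice of the data.
    shards = [dataset[i::num_workers] for i in range(num_workers)]
    params = init
    for r in range(rounds):
        # Every worker starts from the same weights and never communicates mid-round.
        results = [
            sgd_linear(params, shard, seed=r * num_workers + i, **sgd_kwargs)
            for i, shard in enumerate(shards)
        ]
        # The only communication: average each parameter across the workers.
        params = tuple(sum(p) / num_workers for p in zip(*results))
    return params

# Toy "fine-tuning" task (hypothetical): recover y = 2x + 1 from noisy samples.
data_rng = random.Random(0)
xs = [data_rng.uniform(-1, 1) for _ in range(200)]
dataset = [(x, 2 * x + 1 + data_rng.gauss(0, 0.1)) for x in xs]

w, b = lofi((1.5, 0.5), dataset, num_workers=5, rounds=1)
print(f"w={w:.3f}, b={b:.3f}")
```

With `rounds=1` this is the pure "train independently, average once" variant; a larger `rounds` trades a little more communication for workers whose weights stay closer together.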
Intuition
My interpretation of Lo-Fi's success is inspired by the large-scale geometric picture of the loss landscape discussed in Fort, Jastrzebski 2019. In this paper, the authors propose a geometric model of deep learning loss landscapes in which, as a neural network evolves during training, it enters one of many high-dimensional (low codimension in parameter space) valleys with a low-loss center. During training, a model may then bounce around the lowest, best-performing region of that valley.
Models fine-tuned from the same pretrained model all stay within a single valley during fine-tuning, in the sense that they are linearly mode connected (see Wortsman 2022a). The general goal of fine-tuning is to land the model in a very low-loss region of the valley associated with the fine-tuning task. In the Lo-Fi procedure, the independent models end up at different positions in the valley because of the randomness inherent in data sharding: a model trained on any given shard may land on one side of the valley or the other. Averaged together, these scattered positions in weight space yield a point closer to the low-loss center of the valley. This intuition is also consistent with the success of other weight averaging methods such as stochastic weight averaging (Izmailov 2018) and model soups (Wortsman 2022a).
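A toy numerical illustration of this intuition, under loudly stated assumptions: the "valley" is a convex quadratic that is flat along half of its directions and curved along the rest, and the five fine-tuned models are simulated as the valley center plus independent Gaussian noise. The names `valley_loss` and `average` and all numbers are hypothetical. For any convex loss, Jensen's inequality guarantees that the loss at the averaged weights is no larger than the mean of the individual losses; real loss landscapes are not convex, so this only shows why averaging helps inside a roughly bowl-shaped region.

```python
import random

def valley_loss(w, center, curvature):
    """Convex quadratic 'valley': flat (curvature 0) along some directions, steep along others."""
    return sum(a * (wi - ci) ** 2 for wi, ci, a in zip(w, center, curvature))

def average(points):
    """Coordinate-wise mean of equal-length weight vectors."""
    n = len(points)
    return [sum(coords) / n for coords in zip(*points)]

rng = random.Random(0)
dim = 50
center = [0.0] * dim
# Hypothetical curvature: half of the directions are flat, half are steep.
curvature = [0.0] * (dim // 2) + [1.0] * (dim - dim // 2)

# Five simulated "fine-tuned models": the valley center plus independent noise,
# standing in for the scatter caused by training on different shards.
models = [[c + rng.gauss(0, 0.5) for c in center] for _ in range(5)]

individual = [valley_loss(m, center, curvature) for m in models]
averaged = valley_loss(average(models), center, curvature)
print("individual losses:", [round(loss, 2) for loss in individual])
print("loss of averaged weights:", round(averaged, 2))
```

Independent noise partly cancels along the curved directions (its variance shrinks roughly in proportion to the number of models averaged), while displacement along the flat directions costs nothing either way.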
Implementation
I used Hugging Face’s open-source Diffusers library to fine-tune a pretrained open-source Stable Diffusion model on an art dataset of high-resolution image/prompt pairs. Google’s TRC graciously allotted me five TPU v3 VMs. To use Lo-Fi in this setting, I split the dataset into 5 shards of equal size and assigned one shard to each VM. On each VM I ran my own fork of Hugging Face’s JAX Stable Diffusion fine-tuning code for TPUs. My code is largely the same as the code in the official Hugging Face repo; however, I added functionality to periodically push model weights to a GCP bucket. This is important for retaining progress in the event of a VM crash or preemption, but it is critical for enabling simple data parallelism through Lo-Fi.
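Splitting the dataset into equal shards can be sketched as below. This is a hypothetical helper, `make_shard`, not taken from the author's fork; shuffling once with a seed shared by every VM before splitting is an assumption, intended to make each shard a random sample of the whole dataset rather than a contiguous block. (The Hugging Face `datasets` library's `Dataset.shard(num_shards, index)` method offers similar functionality for `Dataset` objects.)

```python
import random

def make_shard(examples, num_shards, shard_id, seed=0):
    """Return the shard_id-th of num_shards disjoint, near-equal slices of examples.

    Assumption: the list is shuffled once with a fixed seed (identical on every VM)
    so that all VMs compute the same permutation and their shards never overlap.
    """
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    order = list(range(len(examples)))
    random.Random(seed).shuffle(order)  # same seed everywhere -> same permutation
    # Contiguous split of the permutation; shard sizes differ by at most one.
    base, extra = divmod(len(order), num_shards)
    start = shard_id * base + min(shard_id, extra)
    stop = start + base + (1 if shard_id < extra else 0)
    return [examples[i] for i in order[start:stop]]

# Example: hypothetical image/prompt pairs split across 5 TPU VMs.
pairs = [(f"img_{i}.png", f"prompt {i}") for i in range(12)]
shards = [make_shard(pairs, 5, k) for k in range(5)]
print([len(s) for s in shards])  # → [3, 3, 2, 2, 2]
```

Each VM would call `make_shard` with its own `shard_id`; because the permutation depends only on the shared seed, no coordination between VMs is needed to agree on the split.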
Practical Details
In practice, the Lo-Fi communication was almost painless. About once every 30 minutes, each VM pushed its current model weights to a GCP bucket. Then, about once a day, a VM pulled down the 5 versions of the model weights from the bucket and averaged them. Finally, it pushed the average back to the GCP bucket, where the other VMs could download it and resume training.
Training continued until sampled images began to look degraded, which I took as a sign of overfitting.
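The daily reconciliation step can be sketched as follows. This is a minimal sketch under stated assumptions, not the author's `lofi3.py`: a local directory stands in for the GCP bucket (a real setup would upload and download through Cloud Storage, for example with `gsutil cp`), and the helper names `push`, `pull`, `reconcile`, and `average_trees`, as well as the file names `vm{k}.pkl` and `average.pkl`, are hypothetical. Parameters are treated as nested dicts, the shape Flax parameter trees usually take; in JAX the same average could be written as `jax.tree_util.tree_map(lambda *xs: sum(xs) / len(xs), *trees)`.

```python
import pickle
import tempfile
from pathlib import Path

def average_trees(trees):
    """Element-wise mean of parameter trees that share one nested-dict structure.

    Leaves may be floats or NumPy/JAX arrays; sum()/len() works for both.
    """
    first = trees[0]
    if isinstance(first, dict):
        if any(set(t) != set(first) for t in trees):
            raise ValueError("parameter trees have different keys")
        return {k: average_trees([t[k] for t in trees]) for k in first}
    return sum(trees) / len(trees)

def push(store, name, params):
    """Hypothetical stand-in for uploading a checkpoint to the shared bucket."""
    (store / f"{name}.pkl").write_bytes(pickle.dumps(params))

def pull(store, name):
    """Hypothetical stand-in for downloading a checkpoint from the shared bucket."""
    return pickle.loads((store / f"{name}.pkl").read_bytes())

def reconcile(store, num_vms):
    """Pull every VM's latest weights, average them, and push the result back."""
    trees = [pull(store, f"vm{k}") for k in range(num_vms)]
    avg = average_trees(trees)
    push(store, "average", avg)  # other VMs download this and resume training
    return avg

# Demo with a local directory standing in for the bucket and tiny fake "weights".
with tempfile.TemporaryDirectory() as tmp:
    store = Path(tmp)
    for k in range(5):
        push(store, f"vm{k}", {"unet": {"kernel": float(k), "bias": 1.0}})
    avg = reconcile(store, num_vms=5)
    print(avg)  # → {'unet': {'kernel': 2.0, 'bias': 1.0}}
```

Because each VM only ever writes its own file and the average is written to a separate file, no locking is needed; the worst case is that a VM's checkpoint is up to one push interval stale when the average is taken.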
Code
The code for training is located at https://github.com/lk-wq/diffusers/blob/main/examples/text_to_image/dpl.py.
The code that implements Lo-Fi is located at https://github.com/lk-wq/diffusers/blob/main/examples/text_to_image/lofi3.py.
Citations
Wortsman et al. 2022a, Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time, https://arxiv.org/abs/2203.05482
Wortsman et al. 2022b, lo-fi: distributed fine-tuning without communication, https://arxiv.org/abs/2210.11948
Izmailov et al. 2018, Averaging Weights Leads to Wider Optima and Better Generalization, https://arxiv.org/abs/1803.05407
You et al. 2017, Large Batch Training of Convolutional Networks, https://arxiv.org/abs/1708.03888
Fort, Jastrzebski 2019, Large Scale Structure of Neural Network Loss Landscapes, https://arxiv.org/abs/1906.04724
Results
Below are some images generated by the fine-tuned diffusion model. In addition to Lo-Fi, I used a novel upsampling technique to generate large images (up to 6144 x 6144) with Stable Diffusion while staying within the memory constraints of a TPU v3 VM. This is challenging because a single TPU process ordinarily cannot handle scaling an image to 6K resolution; it hits a memory wall even below 3072 x 3072. I describe how we overcame this bottleneck on TPUs in another post. The images below were upscaled with Stable Diffusion to 6K resolution and then resized down to 3K to fit under this blog platform's 20 MB image size limit.
Here are some examples of image outputs:
Prompt: moody romantic bohemian floral painting in neutral colors
Prompt: moody romantic bohemian floral painting in neutral colors, including blue and tan
Prompt: moody romantic bohemian floral painting in the style of dr seuss