Training multi-concept DreamBooth stable diffusion models

Over the past few months I've trained several DreamBooth models for clients to inject new concepts into various base stable diffusion models. In each case, the client wanted to insert several new concepts into the model, as opposed to the typical case of inserting a single concept like a new face. In this post, I'll describe the process I used to train these models and the heuristics I've developed for hyperparameters.

What is DreamBooth?

DreamBooth was introduced two months ago in the Google paper "DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation", which presents a novel approach for personalizing text-to-image diffusion models. By fine-tuning a pretrained model on just a few (3-5) images of a subject, it enables the generation of photorealistic images of that subject in new contexts, preserving its key features across diverse scenes, poses, and lighting conditions.

DreamBooth paper

The paper describes the process of fine-tuning a pretrained model for a single new concept, but in practice, clients often want to inject multiple new concepts into a single model. For example, Venngage had several images of several different objects and wished to have one fine-tuned model that could be prompted to generate any of the objects.

Initial experiments

I began with photos of my wife and me, since I know our faces well. Here is a sample of the training images:

Jimsam training

I knew I wanted to use the diffusers library, since it lets you swap base SD checkpoints easily. On Discord, I found that Shivam Shrirao had a fork of the diffusers library that adapted the idea of the DreamBooth paper to multiple concepts. The results looked very promising, so I decided to use this fork for the rest of my work.

Jimsam generations
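To give a sense of why diffusers makes this easy: swapping the base checkpoint (or pointing at a fine-tuned DreamBooth output) is just a matter of changing the path passed to from_pretrained. The model IDs, output path, and prompt below are illustrative; this is a minimal sketch of the kind of sampling script I used, not the exact code.

import torch
from diffusers import StableDiffusionPipeline

# Swapping the base checkpoint only changes this string, e.g.
# "runwayml/stable-diffusion-v1-5" vs. "stabilityai/stable-diffusion-2-1",
# or a local DreamBooth output directory (hypothetical path below).
model_id = "runwayml/stable-diffusion-v1-5"
# model_id = "output/jimsam-dreambooth"

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

image = pipe("photo of a man hiking in the mountains",
             num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("sample.png")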

The main difference from the original diffusers DreamBooth implementation is that the fork takes a concepts list as input and fine-tunes the model on each concept. Each concept has an instance prompt containing a rare token from the CLIP text encoder's vocabulary, which serves as the concept's name (zwx is the rare token in the example below). As in the original DreamBooth paper, we assign each new concept to a class (e.g. man, dog, cat) and supply a set of regularization images of that class generated from the base SD checkpoint.

[
    {
        "instance_prompt":      "photo of zwx dog",
        "class_prompt":         "photo of a dog",
        "instance_data_dir":    "../../../data/alvan",
        "class_data_dir":       "../../../data/dog"
    }
]
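Extending this to multiple concepts just means adding more entries to the list, each with its own rare token, class, and data directories. Below is a hedged sketch of how such a file can be put together in Python; the second concept, its token, and all the paths are made up for illustration, and the check at the end simply verifies that a candidate rare token isn't split into several pieces by the CLIP tokenizer.

import json
from transformers import CLIPTokenizer

# Hypothetical two-concept list: each concept gets its own rare token,
# class prompt, instance images, and regularization (class) images.
concepts_list = [
    {
        "instance_prompt":   "photo of zwx dog",
        "class_prompt":      "photo of a dog",
        "instance_data_dir": "data/alvan",
        "class_data_dir":    "data/dog",
    },
    {
        "instance_prompt":   "photo of sks backpack",
        "class_prompt":      "photo of a backpack",
        "instance_data_dir": "data/backpack",
        "class_data_dir":    "data/backpack_class",
    },
]

with open("concepts_list.json", "w") as f:
    json.dump(concepts_list, f, indent=4)

# Sanity check: a good rare token should map to a single id in the CLIP
# tokenizer's vocabulary rather than being broken into sub-tokens.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
for candidate in ["zwx", "sks"]:
    ids = tokenizer.encode(candidate, add_special_tokens=False)
    print(candidate, ids, "single token" if len(ids) == 1 else "splits into multiple tokens")

The fork's training script then consumes this JSON file as its concepts list and fine-tunes on all of the concepts in a single run.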

This was the start of a long journey where I built custom DreamBooth models for several companies, including one graphic design company that wanted to generate new artwork similar to their existing human-created designs.

Some of the original human-created SVGs looked like this; they included things like construction workers, but no dogs. The goal was to mimic their style.

DreamBooth icons for training

Here are some of the generated images for eight seeds of the prompt "A construction worker walking a dog", with training progressing from left to right:

DreamBooth icons generated
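Grids like this can be produced by fixing a set of seeds and re-running the same prompt against each saved checkpoint; with fixed seeds, each column stays comparable as training progresses. A minimal sketch (the checkpoint path is hypothetical):

import torch
from diffusers import StableDiffusionPipeline

# Hypothetical path to a pipeline saved partway through training.
pipe = StableDiffusionPipeline.from_pretrained("output/checkpoint-2000",
                                               torch_dtype=torch.float16).to("cuda")

prompt = "A construction worker walking a dog"
for seed in range(8):
    # The fixed seed keeps the same column comparable across checkpoints.
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5,
                 generator=generator).images[0]
    image.save(f"worker_dog_seed{seed}.png")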

The attention maps for these were really coherent, with the words "worker" and "dog" matching up with the correct parts of the image.

DreamBooth attention map
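Cross-attention maps are indexed by text-token position, so the first step in inspecting them is knowing which positions "worker" and "dog" occupy once the prompt passes through the CLIP tokenizer. A small sketch of that lookup (this assumes the SD 1.x text encoder's tokenizer):

from transformers import CLIPTokenizer

# Tokenizer used by the SD 1.x text encoder.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

prompt = "A construction worker walking a dog"
ids = tokenizer(prompt).input_ids
tokens = tokenizer.convert_ids_to_tokens(ids)

# Each position here indexes one per-token cross-attention map in the UNet.
for position, token in enumerate(tokens):
    print(position, token)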

This came in handy when we wanted to modify good images slightly without changing the whole scene. On the left is the result for the original prompt, and on the right is the result for "A construction worker walking a cat".

DreamBooth prompt to prompt

Without constraining the attention maps, the model generates the image on the right, which has a much larger construction worker and a different background:

DreamBooth no prompt to prompt
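For context, the unconstrained image presumably amounts to re-running the sampler with the edited prompt (and, for a fair comparison, the same seed and settings); the prompt-to-prompt result additionally reuses the original prompt's cross-attention maps during generation, which I won't sketch here. The naive version looks roughly like this (path and settings are illustrative):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("output/final",  # hypothetical path
                                               torch_dtype=torch.float16).to("cuda")

# Same seed and settings, only one prompt word changes; without any attention
# constraints the whole composition is free to shift.
for prompt in ["A construction worker walking a dog",
               "A construction worker walking a cat"]:
    generator = torch.Generator(device="cuda").manual_seed(42)
    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5,
                 generator=generator).images[0]
    image.save(prompt.replace(" ", "_") + ".png")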