2022年8月に公開された、高性能画像生成モデルである「Stable Diffusion」を実装する方法を紹介するシリーズです。

第4回目では「Dreambooth Concepts Library」による追加学習の方法をご紹介します。

任意の画像を追加学習させたオリジナルモデルから画像を生成して遊んでみましょう。

Google colabを使用して簡単に実装することができますので、ぜひ最後までご覧ください。

今回の目標

・Dreambooth Concepts Libraryとは

・追加学習の実装

・推論

Stable Diffusionとは

Stable Diffusionは拡散モデルによる画像生成モデルで、スタートアップ企業であるStability AIが2022年8月に「Stable Diffusion」と、それを使用したサービス「DreamStudio」のβ版をリリースしました。

Stable Diffusionは、SNSなどで話題になっている「Midjourney」と同様で、テキストから画像を生成することができます。

無料で公開されており、商用利用も可能なライセンスの下でリリースされているため、様々なシーンでの活用が期待されます。

Latent Diffusionをベースとしており、非常に大規模なデータセットであるLAION-5Bを用いてトレーニングされています。

参考URL:https://huggingface.co/blog/stable_diffusion

実装例は以下のリンクよりご確認ください。

画像生成・超解像・画像修復・アニメーション

Stable Diffusion(画像生成) 画像生成として話題のStable Diffusionを紹介しています。 Super Resolution(超解像) 最新の実用的な超解像を紹介します。 画像や動画を…

「Dreambooth Concepts Library」とは

Stable Diffusion Dreambooth Concepts Libraryは、Text to Image拡散モデルに対して、任意の画像を追加学習させる手法です。

事前に学習されたText to Imageの拡散モデルに対して、学習させたい特定の被写体が写る数枚程度の画像と、識別子となるプロンプトを与えることで、追加学習することができるようになります。

これにより、Text to Imageの拡散モデルを自分流にカスタマイズすることが可能になります。

今回の記事では、Stable Diffusionの事前学習済みモデルである「stable-diffusion-v1-4」の追加学習方法を紹介します。

詳細は以下のリンクよりご確認下さい。

Stable Diffusionの導入

まずは、事前学習済みモデルである「stable-diffusion-v1-4」を使用するために、以下の手順を行います。

Hugging Faceの登録とNewTokenの発行

Hugging Faceのアカウント作成

初めにHuggingFaceのアカウントを作成します。

※Hugging Faceとは米国のHugging Face社が提供している、自然言語処理に特化したディープラーニングのフレームワークです。
ソースコードは全てGitHub上で公開されており、誰でも無料で使うことができます。

HuggingFaceにアクセスし、画面右上のSignUpよりアカウントを作成することができます。

登録したメールアドレスに認証メールが届くので、メールに記載されたリンクにアクセスすれば、アカウント登録は完了です。

STEP
1

Access Repositoryの承諾

こちらのCompVis/stable-diffusion-v1-4にアクセスし記載の内容を確認の上、「Access Repository」をクリックすることで権限を得ることができます。

この時点で、モデルの作者にメールアドレスとユーザー名が共有されることになりますので注意してください。

モデルの使用にあたっては、意図的に違法または有害な出力やコンテンツを作成・共有することが禁止されています。

CreativeML OpenRAIL Licenseに準拠した上で、再配布や商用利用のルールなどについての記載に同意する必要があります。

STEP
2

Access Tokensの発行

画面右上のアカウントのアイコンから[Settings]->[Access Tokens]に移動しNewTokenを発行します。

後ほど使用しますので、メモしておきましょう。

STEP
3

Google colabの準備

ここからは、Google colab環境で進めていきます。

なお、今回紹介するコードは公式チュートリアルを参考にしています。

合わせてご確認ください。

学習用実装:Open In Colab

推論用実装:Open In Colab

まずはGPUを使用できるように設定をします。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

次にgoogleドライブをマウントして、今回作成した画像を保存するフォルダを作成します。

フォルダ名は変えても問題ありありません。

from google.colab import drive 
drive.mount('/content/drive')

#保存フォルダを作成する
 !mkdir -p '/content/drive/My Drive/Stable Diffusion_main/'
%cd '/content/drive/My Drive/Stable Diffusion_main/'

「Dreambooth Concepts Library」の実装

準備

まずは必要なライブリをインストールします。

!pip install -qq git+https://github.com/huggingface/diffusers.git accelerate tensorboard transformers ftfy gradio
!pip install -qq "ipywidgets>=7,<8"
!pip install -qq bitsandbytes

Hugging Faceにログインします。

from huggingface_hub import notebook_login
!git config --global credential.helper store
 
notebook_login()

実行後表示される箇所に、先ほど取得したAccess Tokenを入力します。

import argparse
import itertools
import math
import os
from contextlib import nullcontext
import random
 
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
 
import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
import bitsandbytes as bnb

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows*cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    grid_w, grid_h = grid.size
    
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

入力データ(モデルと画像)

事前学習済みモデルと入力データを定義します。

今回は「stable-diffusion-v1-4」を使用します。

pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"

次に学習に使用する画像を定義します。

ここでは、以下の6枚の画像を使用します。

urls = [
    'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/04/test1-2-scaled.jpg',
    'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/10/308-scaled.jpeg',
    'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/10/737.jpeg',
    'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/10/DSC_0626-scaled.jpeg',
    'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/10/PAP_0608.jpeg',
    'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/10/test-scaled.jpg'
          ]

この画像を表示してみます。

import requests
import glob
from io import BytesIO

def download_image(url):
  try:
    response = requests.get(url)
  except:
    return None
  return Image.open(BytesIO(response.content)).convert("RGB")

images = list(filter(None,[download_image(url) for url in urls]))
save_path = "./my_concept"
if not os.path.exists(save_path):
  os.mkdir(save_path)
[image.save(f"{save_path}/{i}.jpeg") for i, image in enumerate(images)]
image_grid(images, 1, len(images))

入力画像を表示することができました。

プロンプト設定

次に学習させる被写体を定義します。

今回は上記の画像のように、国鉄時代の型式の車両を特徴としたモデルを作成するため、そのことを明示します。

追加学習させたい対象であることを示す「”sks”+オブジェクト名」でプロンプトを指定します。

# オブジェクト、画風の説明
instance_prompt = "a photo of sks tranicar"

#instance_prompt : sks
#オブジェクトまたは画風が何であるかの適切な説明と、イニシャライザ ワード sks を含むプロンプトを指定。

# コンセプトクラスの指定、画質が向上
prior_preservation = False
prior_preservation_class_prompt = "a photo of tranicar"


num_class_images = 12 
sample_batch_size = 2
prior_loss_weight = 0.5
prior_preservation_class_folder = "./class_images"
class_data_root=prior_preservation_class_folder
class_prompt=prior_preservation_class_prompt
from pathlib import Path
from torchvision import transforms

class DreamBoothDataset(Dataset):
    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(Path(class_data_root).iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            padding="do_not_pad",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                padding="do_not_pad",
                truncation=True,
                max_length=self.tokenizer.model_max_length,
            ).input_ids
        
        return example


class PromptDataset(Dataset):
    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example
import gc
if(prior_preservation):
    class_images_dir = Path(class_data_root)
    if not class_images_dir.exists():
        class_images_dir.mkdir(parents=True)
    cur_class_images = len(list(class_images_dir.iterdir()))

    if cur_class_images < num_class_images:
        pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_name_or_path, use_auth_token=True, revision="fp16", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.enable_attention_slicing()
        pipeline.set_progress_bar_config(disable=True)

        num_new_images = num_class_images - cur_class_images
        print(f"Number of class images to sample: {num_new_images}.")

        sample_dataset = PromptDataset(class_prompt, num_new_images)
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)

        for example in tqdm(sample_dataloader, desc="Generating class images"):
            with torch.autocast("cuda"):
                images = pipeline(example["prompt"]).images

            for i, image in enumerate(images):
                image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
        pipeline = None
        gc.collect()
        del pipeline
        with torch.no_grad():
          torch.cuda.empty_cache()
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=True
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", use_auth_token=True
)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", use_auth_token=True
)
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_name_or_path,
    subfolder="tokenizer",
    use_auth_token=True,
)
from argparse import Namespace
args = Namespace(
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    resolution=512,
    center_crop=True,
    instance_data_dir=save_path,
    instance_prompt=instance_prompt,
    learning_rate=5e-06,
    max_train_steps=450,
    train_batch_size=1,
    gradient_accumulation_steps=2,
    max_grad_norm=1.0,
    mixed_precision="no", # set to "fp16" for mixed-precision training.
    gradient_checkpointing=True, # set this to True to lower the memory usage.
    use_8bit_adam=True, # use 8bit optimizer from bitsandbytes
    seed=3434554,
    with_prior_preservation=prior_preservation, 
    prior_loss_weight=prior_loss_weight,
    sample_batch_size=2,
    class_data_dir=prior_preservation_class_folder, 
    class_prompt=prior_preservation_class_prompt, 
    num_class_images=num_class_images, 
    output_dir="dreambooth-concept",
)
from accelerate.utils import set_seed
def training_function(text_encoder, vae, unet):
    logger = get_logger(__name__)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

    set_seed(args.seed)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    optimizer = optimizer_class(
        unet.parameters(),  # only optimize unet
        lr=args.learning_rate,
    )

    noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
    
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # concat class and instance examples for prior preservation
        if args.with_prior_preservation:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

        batch = {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
        }
        return batch
    
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
    )

    unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)

    # Move text_encode and vae to gpu
    text_encoder.to(accelerator.device)
    vae.to(accelerator.device)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
  
    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0

    for epoch in range(num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                with torch.no_grad():
                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                    latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn(latents.shape).to(latents.device)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
                ).long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                with torch.no_grad():
                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if args.with_prior_preservation:
                    # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
                    noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
                    noise, noise_prior = torch.chunk(noise, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                    # Compute prior loss
                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        accelerator.wait_for_everyone()
    
    # Create the pipeline using using the trained modules and save it.
    if accelerator.is_main_process:
        pipeline = StableDiffusionPipeline(
            text_encoder=text_encoder,
            vae=vae,
            unet=accelerator.unwrap_model(unet),
            tokenizer=tokenizer,
            scheduler=PNDMScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
            ),
            safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
        )
        pipeline.save_pretrained(args.output_dir)

学習

学習を実装します。

1時間程度で学習が完了します。

import accelerate
accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet))
with torch.no_grad():
    torch.cuda.empty_cache()

推論

学習結果を使って、画像を生成してみます。

pipe = StableDiffusionPipeline.from_pretrained(
        args.output_dir,
        torch_dtype=torch.float16,
    ).to("cuda")
from torch import autocast
prompt = "a photo of a traincar by a cherry tree"
num = 100

for i in range(num):
  with autocast("cuda"):
    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)["sample"][0]
    image.save(str(i) + 'c.png')

入力画像に似た画像が生成されました。

数枚の画像を入力として学習することで、オリジナルモデルを作成できることがわかりました。

まとめ

最後までご覧いただきありがとうございました。

今回は「Dreambooth Concepts Library」による追加学習の方法をご紹介しました。

任意の画像を追加学習させたオリジナルモデルから、画像を生成することができました。

なお、作成した画像をより高解像度にする方法も別の記事で紹介しています。ぜひご覧ください。

画像生成・超解像・画像修復・アニメーション

Stable Diffusion(画像生成) 画像生成として話題のStable Diffusionを紹介しています。 Super Resolution(超解像) 最新の実用的な超解像を紹介します。 画像や動画を…