2022年8月に公開された、高性能画像生成モデルである「Stable Diffusion」を実装する方法を紹介するシリーズです。
第4回目では「Dreambooth Concepts Library」による追加学習の方法をご紹介します。
任意の画像を追加学習させたオリジナルモデルから画像を生成して遊んでみましょう。
Google colabを使用して簡単に実装することができますので、ぜひ最後までご覧ください。
今回の目標
・Dreambooth Concepts Libraryとは
・追加学習の実装
・推論
Stable Diffusionとは
Stable Diffusionは拡散モデルによる画像生成モデルで、スタートアップ企業であるStability AIが2022年8月に「Stable Diffusion」と、それを使用したサービス「DreamStudio」のβ版をリリースしました。
Stable Diffusionは、SNSなどで話題になっている「Midjourney」と同様で、テキストから画像を生成することができます。
無料で公開されており、商用利用も可能なライセンスの下でリリースされているため、様々なシーンでの活用が期待されます。
Latent Diffusionをベースとしており、非常に大規模なデータセットであるLAION-5Bを用いてトレーニングされています。
参考URL:https://huggingface.co/blog/stable_diffusion
実装例は以下のリンクよりご確認ください。
「Dreambooth Concepts Library」とは
Stable Diffusion Dreambooth Concepts Libraryは、Text to Image拡散モデルに対して、任意の画像を追加学習させる手法です。
事前に学習されたText to Imageの拡散モデルに対して、学習させたい特定の被写体が写る数枚程度の画像と、識別子となるプロンプトを与えることで、追加学習することができるようになります。
これにより、Text to Imageの拡散モデルを自分流にカスタマイズすることが可能になります。
今回の記事では、Stable Diffusionの事前学習済みモデルである「stable-diffusion-v1-4」の追加学習方法を紹介します。
詳細は以下のリンクよりご確認下さい。
Hugging Face(Stable Diffusion) : https://huggingface.co/rinna/japanese-stable-diffusion
GitHub(Stable Diffusion) :https://github.com/rinnakk/japanese-stable-diffusion
Dreambooth Concepts Library :https://huggingface.co/sd-dreambooth-library
Stable Diffusionの導入
まずは、事前学習済みモデルである「stable-diffusion-v1-4」を使用するために、以下の手順を行います。
Hugging Faceの登録とNewTokenの発行
Hugging Faceのアカウント作成
初めにHuggingFaceのアカウントを作成します。
※Hugging Faceとは米国のHugging Face社が提供している、自然言語処理に特化したディープラーニングのフレームワークです。
ソースコードは全てGitHub上で公開されており、誰でも無料で使うことができます。
HuggingFaceにアクセスし、画面右上のSignUpよりアカウントを作成することができます。
登録したメールアドレスに認証メールが届くので、メールに記載されたリンクにアクセスすれば、アカウント登録は完了です。
Access Repositoryの承諾
こちらのCompVis/stable-diffusion-v1-4にアクセスし記載の内容を確認の上、「Access Repository」をクリックすることで権限を得ることができます。
この時点で、モデルの作者にメールアドレスとユーザー名が共有されることになりますので注意してください。
モデルの使用にあたっては、意図的に違法または有害な出力やコンテンツを作成・共有することが禁止されています。
CreativeML OpenRAIL Licenseに準拠した上で、再配布や商用利用のルールなどについての記載に同意する必要があります。
Access Tokensの発行
画面右上のアカウントのアイコンから[Settings]->[Access Tokens]に移動しNewTokenを発行します。
後ほど使用しますので、メモしておきましょう。
Google colabの準備
ここからは、Google colab環境で進めていきます。
なお、今回紹介するコードは公式チュートリアルを参考にしています。
合わせてご確認ください。
まずはGPUを使用できるように設定をします。
「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
次にgoogleドライブをマウントして、今回作成した画像を保存するフォルダを作成します。
フォルダ名は変えても問題ありありません。
from google.colab import drive
drive.mount('/content/drive')
#保存フォルダを作成する
!mkdir -p '/content/drive/My Drive/Stable Diffusion_main/'
%cd '/content/drive/My Drive/Stable Diffusion_main/'
「Dreambooth Concepts Library」の実装
準備
まずは必要なライブリをインストールします。
!pip install -qq git+https://github.com/huggingface/diffusers.git accelerate tensorboard transformers ftfy gradio
!pip install -qq "ipywidgets>=7,<8"
!pip install -qq bitsandbytes
Hugging Faceにログインします。
from huggingface_hub import notebook_login
!git config --global credential.helper store
notebook_login()
実行後表示される箇所に、先ほど取得したAccess Tokenを入力します。
import argparse
import itertools
import math
import os
from contextlib import nullcontext
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import PIL
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
import bitsandbytes as bnb
def image_grid(imgs, rows, cols):
assert len(imgs) == rows*cols
w, h = imgs[0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i%cols*w, i//cols*h))
return grid
入力データ(モデルと画像)
事前学習済みモデルと入力データを定義します。
今回は「stable-diffusion-v1-4」を使用します。
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
次に学習に使用する画像を定義します。
ここでは、以下の6枚の画像を使用します。
urls = [
'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/04/test1-2-scaled.jpg',
'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/10/308-scaled.jpeg',
'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/10/737.jpeg',
'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/10/DSC_0626-scaled.jpeg',
'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/10/PAP_0608.jpeg',
'https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/10/test-scaled.jpg'
]
この画像を表示してみます。
import requests
import glob
from io import BytesIO
def download_image(url):
try:
response = requests.get(url)
except:
return None
return Image.open(BytesIO(response.content)).convert("RGB")
images = list(filter(None,[download_image(url) for url in urls]))
save_path = "./my_concept"
if not os.path.exists(save_path):
os.mkdir(save_path)
[image.save(f"{save_path}/{i}.jpeg") for i, image in enumerate(images)]
image_grid(images, 1, len(images))
入力画像を表示することができました。
プロンプト設定
次に学習させる被写体を定義します。
今回は上記の画像のように、国鉄時代の型式の車両を特徴としたモデルを作成するため、そのことを明示します。
追加学習させたい対象であることを示す「”sks”+オブジェクト名」でプロンプトを指定します。
# オブジェクト、画風の説明
instance_prompt = "a photo of sks tranicar"
#instance_prompt : sks
#オブジェクトまたは画風が何であるかの適切な説明と、イニシャライザ ワード sks を含むプロンプトを指定。
# コンセプトクラスの指定、画質が向上
prior_preservation = False
prior_preservation_class_prompt = "a photo of tranicar"
num_class_images = 12
sample_batch_size = 2
prior_loss_weight = 0.5
prior_preservation_class_folder = "./class_images"
class_data_root=prior_preservation_class_folder
class_prompt=prior_preservation_class_prompt
from pathlib import Path
from torchvision import transforms
class DreamBoothDataset(Dataset):
def __init__(
self,
instance_data_root,
instance_prompt,
tokenizer,
class_data_root=None,
class_prompt=None,
size=512,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(Path(class_data_root).iterdir())
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
padding="do_not_pad",
truncation=True,
max_length=self.tokenizer.model_max_length,
).input_ids
return example
class PromptDataset(Dataset):
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
import gc
if(prior_preservation):
class_images_dir = Path(class_data_root)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < num_class_images:
pipeline = StableDiffusionPipeline.from_pretrained(
pretrained_model_name_or_path, use_auth_token=True, revision="fp16", torch_dtype=torch.float16
).to("cuda")
pipeline.enable_attention_slicing()
pipeline.set_progress_bar_config(disable=True)
num_new_images = num_class_images - cur_class_images
print(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=sample_batch_size)
for example in tqdm(sample_dataloader, desc="Generating class images"):
with torch.autocast("cuda"):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
pipeline = None
gc.collect()
del pipeline
with torch.no_grad():
torch.cuda.empty_cache()
text_encoder = CLIPTextModel.from_pretrained(
pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=True
)
vae = AutoencoderKL.from_pretrained(
pretrained_model_name_or_path, subfolder="vae", use_auth_token=True
)
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path, subfolder="unet", use_auth_token=True
)
tokenizer = CLIPTokenizer.from_pretrained(
pretrained_model_name_or_path,
subfolder="tokenizer",
use_auth_token=True,
)
from argparse import Namespace
args = Namespace(
pretrained_model_name_or_path=pretrained_model_name_or_path,
resolution=512,
center_crop=True,
instance_data_dir=save_path,
instance_prompt=instance_prompt,
learning_rate=5e-06,
max_train_steps=450,
train_batch_size=1,
gradient_accumulation_steps=2,
max_grad_norm=1.0,
mixed_precision="no", # set to "fp16" for mixed-precision training.
gradient_checkpointing=True, # set this to True to lower the memory usage.
use_8bit_adam=True, # use 8bit optimizer from bitsandbytes
seed=3434554,
with_prior_preservation=prior_preservation,
prior_loss_weight=prior_loss_weight,
sample_batch_size=2,
class_data_dir=prior_preservation_class_folder,
class_prompt=prior_preservation_class_prompt,
num_class_images=num_class_images,
output_dir="dreambooth-concept",
)
from accelerate.utils import set_seed
def training_function(text_encoder, vae, unet):
logger = get_logger(__name__)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
)
set_seed(args.seed)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
unet.parameters(), # only optimize unet
lr=args.learning_rate,
)
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
)
def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
# concat class and instance examples for prior preservation
if args.with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}
return batch
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn
)
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
# Move text_encode and vae to gpu
text_encoder.to(accelerator.device)
vae.to(accelerator.device)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0
for epoch in range(num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# Convert images to latent space
with torch.no_grad():
latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn(latents.shape).to(latents.device)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
).long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
with torch.no_grad():
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.with_prior_preservation:
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
# Compute prior loss
prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
optimizer.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
logs = {"loss": loss.detach().item()}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
accelerator.wait_for_everyone()
# Create the pipeline using using the trained modules and save it.
if accelerator.is_main_process:
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=accelerator.unwrap_model(unet),
tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(args.output_dir)
学習
学習を実装します。
1時間程度で学習が完了します。
import accelerate
accelerate.notebook_launcher(training_function, args=(text_encoder, vae, unet))
with torch.no_grad():
torch.cuda.empty_cache()
推論
学習結果を使って、画像を生成してみます。
pipe = StableDiffusionPipeline.from_pretrained(
args.output_dir,
torch_dtype=torch.float16,
).to("cuda")
from torch import autocast
prompt = "a photo of a traincar by a cherry tree"
num = 100
for i in range(num):
with autocast("cuda"):
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)["sample"][0]
image.save(str(i) + 'c.png')
入力画像に似た画像が生成されました。
数枚の画像を入力として学習することで、オリジナルモデルを作成できることがわかりました。
まとめ
最後までご覧いただきありがとうございました。
今回は「Dreambooth Concepts Library」による追加学習の方法をご紹介しました。
任意の画像を追加学習させたオリジナルモデルから、画像を生成することができました。
なお、作成した画像をより高解像度にする方法も別の記事で紹介しています。ぜひご覧ください。