今回の記事ではSegment Anything Modelによるマスクの自動生成の実装について紹介します。
実際にマスクの自動生成を実装してみましょう。
Google Colabを使用して簡単に実装できますので、ぜひ最後までご覧ください。

Segment Anythingとは

Segment Anythingとは

Segment Anything(SA)は、画像のセグメンテーション(画像の部分ごとの分割)のための新しいタスク、モデル、データセットを提案しています。効率的なモデルをデータ収集ループで使用することにより、11Mのライセンスされたプライバシーに配慮した画像に対して10億枚以上のマスクを提供する、これまでで最も大規模なセグメンテーションデータセットが構築されました。

このモデルは、プロンプトに対応するように設計・学習されており、新しい画像の分布やタスクに対してゼロショット(事前学習のみ)で適応できます。多くのタスクで評価され、ゼロショット性能が非常に優れており、従来の完全に教師あり学習による結果と競合するか、それを上回ることがしばしば確認されています。

SAプロジェクトでは、Segment Anything Model(SAM)と、1Bのマスクと11Mの画像が含まれるデータセット(SA-1B)を提供しており、コンピュータビジョンの基盤モデルに関する研究を促進することを目的としています。

Segment Anything Modelとは

Segment Anything Model (SAM) は、目的のオブジェクトを示すプロンプトが与えられたオブジェクトのマスクを予測するモデルです。このモデルは、まず画像を画像埋め込みに変換し、プロンプトから高品質のマスクを効率的に生成することができます。

SamPredictorクラスは、モデルにプロンプトを与えるための簡単なインターフェイスを提供します。これは、ユーザーがまずset_imageメソッドを使って画像を設定し、必要な画像埋め込みを計算することができます。その後、predictメソッドによってプロンプトを与え、そのプロンプトからマスクを効率よく予測することができます。このモデルは、点と箱の両方のプロンプトと、前回の予測の繰り返しによるマスクを入力として受け取ることができます。

オブジェクトマスクの自動生成

SAMはプロンプトを効率的に処理できるため、画像全体のマスクを画像上の多数のプロンプトをサンプリングすることで生成することができます。この方法は、データセットSA-1Bを生成するために使用されました。

クラスSamAutomaticMaskGeneratorは、この機能を実装しています。これは、画像上のグリッドで単一点入力プロンプトをサンプリングすることによって機能し、それぞれからSAMが複数のマスクを予測できます。次に、マスクの品質がフィルタリングされ、非最大抑制を使用して重複が除去されます。追加のオプションにより、マスクの品質と量をさらに向上させることができます。たとえば、画像の複数のクロップに対して予測を実行するか、小さな切り離された領域や穴を取り除くためにマスクを後処理することができます。

詳細は以下のリンクからご覧ください。

導入

インストール

ここからはGoogle colabを使用して実装していきます。
(Google colabの使用方法はこちら⇨使い方

※今回紹介するコードは以下のリンクからもご覧いただけます。

Open In Colab

まずはGPUを使用できるように設定をします。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

GPUの設定が終わったら、Googleドライブをマウントします。

from google.colab import drive
drive.mount('/content/drive')
%cd ./drive/MyDrive

パッケージをクローンします。

!git clone https://github.com/facebookresearch/segment-anything.git
%cd segment-anything

必要なライブラリをインストールします。

!pip install -e .
!pip install opencv-python pycocotools matplotlib onnxruntime onnx

モデルとサンプルデータのダウンロード

今回使用するモデルとサンプル画像を格納するフォルダを作成してから、それぞれをダウンロードします。

import os
os.makedirs("/content/drive/My Drive/segment-anything/checkpoint", exist_ok=True)
os.makedirs("/content/drive/My Drive/segment-anything/images", exist_ok=True)
!wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg    
!wget -P checkpoint  https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

以上で準備が完了しました。

実装

サンプル画像を確認する

まずは必要な関数を定義して、サンプル画像を表示してみます。

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))

サンプル画像を表示します。

image = cv2.imread('images/dog.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()

自動マスク生成

自動マスク生成を実行するには、SamAutomaticMaskGeneratorクラスにSAMモデルを提供してください。以下のパスをSAMのチェックポイントに設定してください。CUDAで実行し、デフォルトのモデルを使用することをお勧めします。

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "checkpoint/sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)

マスクを生成するには、画像でgenerateを実行するだけです。

masks = mask_generator.generate(image)

マスク生成は、マスクに関するさまざまなデータを含む辞書であるマスクのリストを返します。これらのキーは次のとおりです。

segmentation:マスク area:ピクセル単位でのマスクの面積
bbox:XYWH形式のマスクの境界ボックス
predicted_iou:マスクの品質に対するモデル自身の予測
point_coords:このマスクを生成したサンプリングされた入力点
stability_score:マスク品質の追加指標
crop_box:XYWH形式で、このマスクを生成するために使用された画像のクロップ

print(len(masks))
print(masks[0].keys())
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show() 

自動マスク生成のオプション

自動マスク生成には、どれだけ密にポイントがサンプリングされるかや、低品質または重複マスクを除去するためのしきい値を制御するいくつかの調整可能なパラメータがあります。さらに、画像のクロップで生成を自動的に実行して小さなオブジェクトのパフォーマンスを向上させることができ、また、ポストプロセッシングで紛れ込んだピクセルや穴を除去することができます。ここに、より多くのマスクをサンプリングする例の設定があります。

mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Requires open-cv to run post-processing
)
masks2 = mask_generator_2.generate(image)
len(masks2)
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show() 

まとめ

最後までご覧いただきありがとうございました。

今回の記事ではSegment Anything Modelによるマスクの自動生成の実装を紹介しました。