今回の記事ではSegment Anythingについて紹介します。
実際にゼロショットのセグメンテーションを実装してみましょう。
Google Colabを使用して簡単に実装できますので、ぜひ最後までご覧ください。
Segment Anythingとは
Segment Anythingとは
Segment Anything(SA)は、画像のセグメンテーション(画像の部分ごとの分割)のための新しいタスク、モデル、データセットを提案しています。効率的なモデルをデータ収集ループで使用することにより、11Mのライセンスされたプライバシーに配慮した画像に対して10億枚以上のマスクを提供する、これまでで最も大規模なセグメンテーションデータセットが構築されました。
このモデルは、プロンプトに対応するように設計・学習されており、新しい画像の分布やタスクに対してゼロショット(事前学習のみ)で適応できます。多くのタスクで評価され、ゼロショット性能が非常に優れており、従来の完全に教師あり学習による結果と競合するか、それを上回ることがしばしば確認されています。
SAプロジェクトでは、Segment Anything Model(SAM)と、1Bのマスクと11Mの画像が含まれるデータセット(SA-1B)を提供しており、コンピュータビジョンの基盤モデルに関する研究を促進することを目的としています。
Segment Anything Modelとは
Segment Anything Model (SAM) は、目的のオブジェクトを示すプロンプトが与えられたオブジェクトのマスクを予測するモデルです。このモデルは、まず画像を画像埋め込みに変換し、プロンプトから高品質のマスクを効率的に生成することができます。
SamPredictorクラスは、モデルにプロンプトを与えるための簡単なインターフェイスを提供します。これは、ユーザーがまずset_imageメソッドを使って画像を設定し、必要な画像埋め込みを計算することができます。その後、predictメソッドによってプロンプトを与え、そのプロンプトからマスクを効率よく予測することができます。このモデルは、点と箱の両方のプロンプトと、前回の予測の繰り返しによるマスクを入力として受け取ることができます。
オブジェクトマスクの自動生成
SAMはプロンプトを効率的に処理できるため、画像全体のマスクを画像上の多数のプロンプトをサンプリングすることで生成することができます。この方法は、データセットSA-1Bを生成するために使用されました。
クラスSamAutomaticMaskGeneratorは、この機能を実装しています。これは、画像上のグリッドで単一点入力プロンプトをサンプリングすることによって機能し、それぞれからSAMが複数のマスクを予測できます。次に、マスクの品質がフィルタリングされ、非最大抑制を使用して重複が除去されます。追加のオプションにより、マスクの品質と量をさらに向上させることができます。たとえば、画像の複数のクロップに対して予測を実行するか、小さな切り離された領域や穴を取り除くためにマスクを後処理することができます。
詳細は以下のリンクからご覧ください。
導入
インストール
ここからはGoogle colabを使用して実装していきます。
(Google colabの使用方法はこちら⇨使い方)
※今回紹介するコードは以下のリンクからもご覧いただけます。
まずはGPUを使用できるように設定をします。
「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
GPUの設定が終わったら、Googleドライブをマウントします。
from google.colab import drive
drive.mount('/content/drive')
%cd ./drive/MyDrive
パッケージをクローンします。
!git clone https://github.com/facebookresearch/segment-anything.git
%cd segment-anything
必要なライブラリをインストールします。
!pip install -e .
!pip install opencv-python pycocotools matplotlib onnxruntime onnx
モデルとサンプルデータのダウンロード
今回使用するモデルとサンプル画像を格納するフォルダを作成してから、それぞれをダウンロードします。
import os
os.makedirs("/content/drive/My Drive/segment-anything/checkpoint", exist_ok=True)
os.makedirs("/content/drive/My Drive/segment-anything/images", exist_ok=True)
!wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
!wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg
!wget -P checkpoint https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
以上で準備が完了しました。
実装
サンプル画像を確認する
まずは必要な関数を定義して、サンプル画像を表示してみます。
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
サンプル画像を表示します。
image = cv2.imread('images/truck.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10,10))
plt.imshow(image)
plt.axis('on')
plt.show()
まず、SAMモデルと予測器を読み込みます。以下のパスを変更して、SAMのチェックポイントを指すようにしてください。最適な結果を得るために、CUDA上で実行し、デフォルトのモデルを使用することをお勧めします。
SAMを使ってオブジェクトを選択する
import sys
from segment_anything import sam_model_registry, SamPredictor
sam_checkpoint = "checkpoint/sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
SamPredictor.set_imageを呼び出して画像を処理し、画像埋め込みを生成します。SamPredictorはこの埋め込みを記憶しており、その後のマスク予測に使用します。
predictor.set_image(image)
トラックを選択するには、トラック上の点を選択します。点は(x,y)形式でモデルに入力され、ラベルは1(前景点)または0(背景点)である。複数の点を入力することができますが、ここでは1点のみを使用します。選択した点は、画像上に星印として表示されます。
input_point = np.array([[500, 375]])
input_label = np.array([1])
plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
SamPredictor.predictで予測します。このモデルは、マスク、それらのマスクの品質予測、および次の反復予測に渡すことができる低解像度マスクのロジットを返します。
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
デフォルト設定であるmultimask_output=Trueを使用すると、SAMは3つのマスクを出力し、scoresはこれらのマスクの品質をモデル自身が評価したものです。この設定は曖昧な入力プロンプトに対応するためのもので、プロンプトに一貫した異なるオブジェクトを判別するのに役立ちます。Falseの場合、1つのマスクのみが返されます。単一の点のような曖昧なプロンプトの場合、たとえ1つのマスクだけが必要であっても、multimask_output=Trueを使用することが推奨されます。scoresで返される最も高いスコアを持つマスクを選択することで、最適な単一のマスクが選択できます。これにより、より良いマスクが得られることがよくあります。
masks.shape # (number_of_masks) x H x W
for i, (mask, score) in enumerate(zip(masks, scores)):
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(mask, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
plt.axis('off')
plt.show()
3つのパターンの画像が出力されました。
特定のオブジェクトを追加の点で指定する
単一の入力点が曖昧であり、モデルはそれと一致する複数のオブジェクトを返しています。単一のオブジェクトを取得するために、複数の点を提供することができます。利用可能であれば、前回の反復からのマスクも、予測を助けるためにモデルに提供することができます。複数のプロンプトで単一のオブジェクトを指定する場合、multimask_output=Falseを設定することで単一のマスクをリクエストできます。
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 1])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
masks.shape
plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
車を除外し、窓のみを指定するには、背景点(ここでは赤で示されるラベル0)を提供できます。
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 0])
mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
mask_input=mask_input[None, :, :],
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
ボックスを使って特定のオブジェクトを指定する
input_box = np.array([425, 600, 700, 875])
masks, _, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()
点とボックスの組み合わせ
点とボックスは、両方のタイプのプロンプトを予測器に含めるだけで組み合わせることができます。ここでは、これを使用してトラックのタイヤのみを選択し、ホイール全体ではなくすることができます。
input_box = np.array([425, 600, 700, 875])
input_point = np.array([[575, 750]])
input_label = np.array([0])
masks, _, _ = predictor.predict(
point_coords=input_point,
point_labels=input_label,
box=input_box,
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
バッチ化されたプロンプト入力
SamPredictorは、同じ画像に対して複数の入力プロンプトを受け取ることができ、predict_torchメソッドを使用します。このメソッドは、入力点がすでにtorchテンソルであり、入力フレームに変換済みであることを前提としています。例えば、オブジェクト検出器からの複数のボックス出力があるとします。
input_boxes = torch.tensor([
[75, 275, 1725, 850],
[425, 600, 700, 875],
[1375, 550, 1650, 800],
[1240, 675, 1400, 750],
], device=predictor.device)
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)
masks.shape # (batch_size) x (num_predicted_masks_per_input) x H x W
plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
エンドツーエンドのバッチ推論
すべてのプロンプトが事前に利用可能である場合、SAMをエンドツーエンドの方法で直接実行することができます。これにより、画像間でバッチ処理も可能になります。
image1 = image # truck.jpg from above
image1_boxes = torch.tensor([
[75, 275, 1725, 850],
[425, 600, 700, 875],
[1375, 550, 1650, 800],
[1240, 675, 1400, 750],
], device=sam.device)
image2 = cv2.imread('images/groceries.jpg')
image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
image2_boxes = torch.tensor([
[450, 170, 520, 350],
[350, 190, 450, 350],
[500, 170, 580, 350],
[580, 170, 640, 350],
], device=sam.device)
画像とプロンプトは、正しいフレームに変換済みのPyTorchテンソルとして入力されます。入力は、画像のリストとしてパッケージ化され、各要素は次のキーを持つ辞書となります。
image
:CHW形式のPyTorchテンソルとしての入力画像。
original_size
:SAMへの入力用に変換する前の画像サイズ(H、W形式)。
point_coords
:バッチ化されたポイントプロンプトの座標。
point_labels
:バッチ化されたポイントプロンプトのラベル。
boxes
:バッチ化された入力ボックス。
mask_inputs
:バッチ化された入力マスク。 プロンプトが存在しない場合、キーを除外することができます。
from segment_anything.utils.transforms import ResizeLongestSide
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)
def prepare_image(image, transform, device):
image = transform.apply_image(image)
image = torch.as_tensor(image, device=device.device)
return image.permute(2, 0, 1).contiguous()
batched_input = [
{
'image': prepare_image(image1, resize_transform, sam),
'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),
'original_size': image1.shape[:2]
},
{
'image': prepare_image(image2, resize_transform, sam),
'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),
'original_size': image2.shape[:2]
}
]
batched_output = sam(batched_input, multimask_output=False)
出力は、各入力画像の結果に対するリストであり、リストの要素は次のキーを持つ辞書です。
masks
:予測されたバイナリマスクのバッチ化されたtorchテンソルで、オリジナル画像のサイズです。
iou_predictions
:各マスクの品質に対するモデルの予測。 low_res_logits:各マスクの低解像度のlogitsで、後の反復でマスク入力としてモデルに渡すことができます。
batched_output[0].keys()
fig, ax = plt.subplots(1, 2, figsize=(20, 20))
ax[0].imshow(image1)
for mask in batched_output[0]['masks']:
show_mask(mask.cpu().numpy(), ax[0], random_color=True)
for box in image1_boxes:
show_box(box.cpu().numpy(), ax[0])
ax[0].axis('off')
ax[1].imshow(image2)
for mask in batched_output[1]['masks']:
show_mask(mask.cpu().numpy(), ax[1], random_color=True)
for box in image2_boxes:
show_box(box.cpu().numpy(), ax[1])
ax[1].axis('off')
plt.tight_layout()
plt.show()
まとめ
最後までご覧いただきありがとうございました。
今回の記事ではSegment Anything Modelの実装を紹介しました。