今回の記事では画像からテキストに変換できるLLaVA-1.5の実装を紹介します。

Google Colabを使用して簡単に実装できますので、ぜひ最後までご覧ください。

今回の内容

・LLaVA-1.5とは

・LLaVA-1.5の実装

LLaVA-1.5とは

LLaVA 1.5は、視覚言語のクロスモーダルコネクタとして全結合の設計を採用しています。これにより、効率的にデータを処理し、高い性能を発揮することが可能です。CLIP-ViT-L-336pxとMLPプロジェクションを使用し、アカデミックタスク指向のVQAデータとシンプルなレスポンスフォーマッティングプロンプトを追加するだけで、11のベンチマークすべてで最先端の結果を達成しています。

このモデルは、公に利用可能なデータ1.2Mだけを使用しており、8-A100ノードの単一マシンで約1日で完全なトレーニングが完了します。これにより、LMMリサーチがよりアクセス可能になることを期待しています。

LLaVAアーキテクチャは、一般的なアシスタントの構築に向けた主要な構成要素としてますます人気を集めています。LLaVAとMiniGPT-4は、自然なインストラクションに従う能力と視覚的な推論能力の両方で印象的な結果を示しています。これにより、LMMの能力をより深く理解するためのベンチマークが提案されています。

LLaVA 1.5では、MLPクロスモーダルコネクタとアカデミックタスク関連データの組み合わせによって、マルチモーダル理解の能力が向上しています。これは、何百万もの画像テキストペアデータに基づいて訓練された特別に設計された視覚リサンプラーを用いる他のモデルとは対照的です。LLaVAは、LMMのための最もシンプルなアーキテクチャ設計を採用し、わずか600Kの画像テキストペアでの訓練だけで、幅広いベンチマークで最先端の結果を達成しています。

詳細は以下のリンクからご確認ください。

LLaVA-1.5の導入

ここからはGoogle colabを使用して、LLaVA-1.5を実装していきましょう。

Open In Colab

まずはGPUを使用できるように設定をします。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

Googleドライブをマウントします。

from google.colab import drive
drive.mount('/content/drive')
%cd ./drive/MyDrive

公式リポジトリをクローンします。

!git clone https://github.com/haotian-liu/LLaVA
%cd LLaVA

必要なライブラリをインストールします。

!pip install -e .

最後に結果を出力するための関数を定義します。

import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from transformers import TextStreamer


def load_image(image_file):
    if image_file.startswith('http'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def generate_response(input_text):
    load_4bit = True
    load_8bit = not load_4bit

    disable_torch_init()

    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, load_8bit, load_4bit)

    if 'llama-2' in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"

    conv = conv_templates[conv_mode].copy()
    roles = conv.roles if "mpt" not in model_name.lower() else ('user', 'assistant')

    image = load_image(image_file)
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

    inp = input_text
    print(f"{roles[1]}: ", end="")

    if image is not None:
        if model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + inp

        conv.append_message(conv.roles[0], inp)
        image = None
    else:
        conv.append_message(conv.roles[0], inp)

    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria])

    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    conv.messages[-1][-1] = outputs

    return outputs

以上で導入は完了です。

サンプル画像の実装例

まずは公式で用意されているサンプル画像を使用して実装してみます。

image_file:入力画像を指定します。以下の例ではサンプルとして用意されている画像のURLを指定しています。

input_text:画像に対する質問を入力します。

model_path = "liuhaotian/llava-v1.5-7b"
image_file = "https://llava-vl.github.io/static/images/view.jpg"
input_text = "What are things I should be cautious about when I visit this place?"

response = generate_response(input_text)
image = load_image(image_file)

plt.imshow(image)
plt.axis('off') 
plt.show()

print(f"Q:{input_text}")
print(f"A:{response}")
Q:What are things I should be cautious about when I visit this place?
A:When visiting this serene location with a pier extending over a lake, there are several things to be cautious about. 
First, be mindful of the weather conditions, as the pier is exposed to the elements, and sudden changes in weather can make it unsafe to be on the pier. 
Second, pay attention to the water level, as it may change throughout the day or due to natural factors like tides. 
This can affect the stability of the pier and the safety of visitors. Finally, be aware of any wildlife or aquatic life in the area, as they might pose a risk to your safety or the safety of others. 
It is essential to follow any guidelines or rules provided by the location to ensure a safe and enjoyable experience.</s>

日本語訳
Q: この場所を訪れる際に気をつけるべきことは何ですか?
A: この静かな場所には、湖に突き出た桟橋がありますが、訪れる際にはいくつか気をつけるべきことがあります。
まず、桟橋は自然の影響を受けやすいので、天気の状態に注意してください。天気が急変すると、桟橋にいるのは危険です。
次に、水位にも注意してください。水位は一日のうちや自然の要因、例えば潮の流れなどで変化する可能性があり、それが桟橋の安定や訪れる人々の安全に影響を与えることがあります。
最後に、その地域の野生動物や水生生物にも注意が必要です。それらがあなた自身や他人の安全を脅かす可能性があります。
場所の指南やルールを遵守することで、安全で楽しい経験ができるようになるのが大切です。

任意の画像の実装例

次に任意の画像を使用して実装してみます。

image_file:入力画像を指定します。以下の例では任意の画像のURLを指定しています。

input_text:画像に対する質問を入力します。

model_path = "liuhaotian/llava-v1.5-7b"
image_file = "https://i0.wp.com/tt-tsukumochi.com/wp-content/uploads/2022/01/drive.jpg"
input_text = "What should I be aware of when driving a car on this road?"

response = generate_response(input_text)
image = load_image(image_file)

plt.imshow(image)
plt.axis('off') 
plt.show()

print(f"Q:{input_text}")
print(f"A:{response}")
Q:What should I be aware of when driving a car on this road?
A:When driving a car on this road, you should be aware of the presence of multiple vehicles, including trucks and cars, as well as the possibility of pedestrians crossing the street. 
The image shows a busy highway with cars and trucks driving on the road, and there are pedestrians nearby.
 Therefore, it is essential to maintain a safe distance from other vehicles, be prepared for sudden stops or changes in traffic flow, and be vigilant for pedestrians crossing the street. 
Additionally, you should follow the speed limits and traffic rules to ensure a safe and smooth driving experience for yourself and others on the road.

質問: この道路で車を運転する際には、何に気をつけるべきですか?
回答: この道路で車を運転する際には、トラックや車など複数の車両の存在、および歩行者が道路を横断する可能性に気をつける必要があります。
画像には、道路を走行する車やトラックが多数映っており、近くには歩行者もいるようです。
そのため、他の車両と安全な距離を保ち、急な停車や交通の流れの変化に備え、歩行者が横断するのを注意深く見守る必要があります。
また、速度制限と交通ルールを守り、自分自身と他の道路利用者の安全でスムーズな運転を確保する必要があります。

まとめ

最後までご覧いただきありがとうございました。

今回の記事では画像からテキストに変換できるLLaVA-1.5の実装を紹介しました。

コメントを残す