このシリーズでは、自然言語処理において主流であるTransformerを中心に、環境構築から学習の方法までまとめます。

この記事では単語の分散表現の理解のため、前回紹介したWord2Vecと比較しながら、BERTによる単語のベクトル化の流れを紹介します。

Google colabを使用して、簡単に最新の自然言語処理モデルを実装することができますので、ぜひ最後までご覧ください。

【ChatGPT】自然言語処理まとめ【Huggingface Transformers】

自然言語処理に関するおすすめの書籍 ChatGPT ChatGPTを中心とした、GPT系の関連技術を紹介します。 ChatGPTの概要 ・ChatGPTとは・ChatGPTができること・ChatGPTの問題点…

今回の内容

・単語の分散表現とは

・BERTによる単語ベクトル

・単語・文章の類似度計算

単語の分散表現とは

単語を固定長のベクトルで表現することを「単語の分散表現」と言います。

自然言語処理において機械学習を活用するためには、単語の持つ性質や意味を反映したベクトル表現を獲得することが重要となります。

単語をベクトルで表現することができれば、単語の意味を定量的に把握することができるため、様々な処理に応用することができます。

単語をベクトルで表現する方法として、one-hotベクトル、Word2Vec、fastTextといった手法が提案されてきました。

【🔰自然言語処理】単語の分散表現とWord2Vec

このシリーズでは、自然言語処理において主流であるTransformerを中心に、環境構築から学習の方法までまとめます。 今回の記事では、単語の分散表現の概要と、Word2Vecの…

BERTによる分散表現

BERT(Bidirectional Encoder Representations from Transformers)は、2018年10月11日にGoogleによってが公開されました。

2017年頃から主流となったTransformer/Attention方式をベースにしており、穴埋め問題を解くような学習を加えることで精度が向上し、自然言語系の様々なタスクでより高精度を獲得しています。

このBERTを用いることで、1つの単語から複数の分散表現を獲得することが出来るようになリました。

BERTは双方向のTransformerが複数積層された構造になっており、ある単語を入力したとき、BERTのモデル内の重みが最も反映されるのは出力直前の最終層となります。

Transformerとは

概要

「Transformer」は2017年にGoogleが「Attention is all you need」で発表した深層学習モデルです。

現在では、自然言語処理に利用する深層学習モデルの主流になっています。

これまでの自然言語処理分野で多く使われていた「RNN」(Recurrent Neural Network)や「CNN」(Convolutional Neural Network)を利用せず、Attentionのみを用いたEncoder-Decoder型のモデルとなっています。

「Transformer」が登場して以降、多くの自然言語処理モデルが再構築され、過去を上回る成果を上げています。

最近では自然言語処理だけでなく、ViTやDETRなどといった画像認識にも応用されています。

詳細は以下の記事をご覧ください。

【🔰Huggingface Transformers入門④】 pipelineによるタスク実装紹介

このシリーズでは、自然言語処理において主流であるTransformerを中心に、環境構築から学習の方法までまとめます。 今回の記事ではHuggingface Transformersの入門として…

Huggingface Transformersとは

概要

「Hugging Face」とは米国のHugging Face社が提供している、自然言語処理に特化したディープラーニングのフレームワークです。

「Huggingface Transformers」は、先ほど紹介したTransformerを実装するためのフレームワークであり、「自然言語理解」と「自然言語生成」の最先端の汎用アーキテクチャ(BERT、GPT-2など)と、何千もの事前学習済みモデルを提供しています。

ソースコードは全てGitHub上で公開されており、誰でも無料で使うことができます。

事前学習済みモデル

Hugging Faceではタスクに応じた様々な事前学習済みモデルが提供されています。

こちらからご確認ください。

Google colabによる導入

ここからはGoogle colabを使用して実装していきます。

まずはGPUを使用できるように設定をします。

「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更

今回紹介するコードは以下のボタンからコピーして使用していただくことも可能です。

Open In Colab

まずはGoogleドライブをマウントして、作業フォルダを作成します。

from google.colab import drive 
drive.mount('/content/drive')
!mkdir -p '/content/drive/My Drive/huggingface_transformers_demo/'
%cd '/content/drive/My Drive/huggingface_transformers_demo/'
!git clone https://github.com/huggingface/transformers
%cd transformers

次に必要なライブラリをインストールします。

!pip install transformers[ja]

単語ベクトルを作成する

まずは日本語のBERTモデルをダウンロードします。

from transformers import BertJapaneseTokenizer
from transformers import BertModel
import torch
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
bert_model = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

今回は、「男」という単語のベクトルを表示してみます。

word = "男"
input = tokenizer(word, return_tensors="pt")
outputs = bert_model(**input) 

BERTから最終層を取得して、出力してみます。

# BERTの最終層を取得
last_hidden_state = outputs.last_hidden_state

# 最終層のテンソルのshape
print(last_hidden_state.shape)
# 最終層のテンソル
print(last_hidden_state)
torch.Size([2, 8, 768])
tensor([[[-0.1707,  0.5272, -0.2290,  ..., -0.4597, -0.3741,  0.4881],
         [ 0.1445,  0.1414, -0.2429,  ..., -0.4400, -0.1156, -0.8959],
         [ 0.2496, -0.1137,  0.3556,  ...,  0.2444, -0.2035,  0.2558],
         ...,
         [ 0.8295, -0.3098, -0.2496,  ..., -0.3425, -0.0304,  0.2603],
         [ 0.2298,  0.3213,  0.4060,  ..., -0.3713, -0.0495,  0.4725],
         [-0.5595, -0.1565, -0.3251,  ..., -0.5645,  0.4300,  0.6169]],

        [[-0.2489,  0.3318, -0.5088,  ..., -0.2743, -0.3505,  0.4610],
         [-0.0141,  0.2046, -0.3064,  ..., -0.6348, -0.3773, -0.0186],
         [-0.4918, -0.0812,  0.0890,  ...,  0.0726,  0.2717,  0.0230],
         ...,
         [ 0.5321, -0.2691,  0.1183,  ..., -0.4126, -0.3916,  0.1003],
         [-0.0937,  0.6843,  0.6126,  ..., -0.1693, -0.3077,  0.4394],
         [-0.3022, -0.0349,  0.0524,  ..., -0.9029,  0.4028, -0.1226]]],

この最終層から[CLS]と[SEP]を取り除くと、単語ベクトルを得ることができます。

# 最終層から[CLS]と[SEP]のトークンを除いて表示
# BERTで作成した単語ベクトルのshape
print(last_hidden_state[0][1].shape)

# BERTで作成した単語ベクトル
print(last_hidden_state[0][1])

実行すると、以下のように出力されます。

torch.Size([768])
tensor([ 1.4451e-01,  1.4138e-01, -2.4285e-01, -4.5975e-01, -1.1265e-01,
         6.7216e-02,  1.1974e-01, -3.4235e-01, -6.2387e-02,  3.3364e-01,
        -3.1031e-01, -8.4823e-01,  1.2878e-02, -4.0749e-01,  5.6290e-01,
         7.8381e-02, -1.1832e-01,  1.8082e-01,  1.8435e-01, -5.0955e-01,
         1.2094e-01, -1.0211e+00,  7.6461e-02, -2.3979e-01, -2.5676e-01,
         1.2715e-01,  2.8308e-01, -1.4633e-01,  2.0107e-02, -6.9765e-02,
        -3.7051e-01,  1.2310e-01, -2.3917e-01, -5.9203e-01, -6.9198e-01,
        -6.4416e-02, -2.2473e-01,  2.2589e-01,  5.7423e-01, -6.7887e-01,
         3.4630e-01, -5.8065e-02, -2.9586e-01,  6.4134e-01,  5.2904e-01,
         8.1289e-01, -8.4706e-02,  3.6696e-01,  1.6657e-01,  1.7450e-01,
         4.2262e-01, -5.4678e-03,  5.4496e-01,  3.5917e-01, -8.6899e-02,
        -3.6932e-02, -1.6289e-01, -3.1452e-01,  1.4844e-01, -3.8578e-01,
         1.0130e-01,  5.6647e-03,  6.6274e-02, -7.5267e-02, -9.3391e-01,
        -5.7850e-01,  8.7778e-02, -5.7713e-01, -3.4379e-02,  1.0260e-01,
        -7.5210e-01,  3.2022e-01, -1.9498e-01, -1.3570e-01,  1.3469e-01,
        -7.7426e-01, -2.6365e-01, -2.9303e-01, -4.2807e-01, -3.4115e-01,
         3.1691e-01,  1.6166e-01,  4.0834e-01, -6.9600e-01, -3.0957e-01,
         2.8329e-01, -2.3042e-01,  2.6530e-01, -2.8333e-01,  1.0264e-01,
        -3.0689e-01, -7.3604e-02,  1.6900e-01,  6.9723e-01, -1.5057e-01,
         1.7735e-01,  2.7145e-02,  1.3420e-01, -3.6971e-02, -2.4895e-01,
         1.1866e-01, -3.5725e-01,  2.0201e-01, -1.4908e-01, -4.0753e-01,
        -2.2634e-02, -1.2005e-01, -8.0752e-01, -6.0205e-01,  4.6750e-02,
        -1.6181e-01,  2.9057e-01, -2.1352e-01, -7.8426e-02,  5.7340e-02,
        -9.6781e-02, -1.5343e-01, -4.9069e-01, -3.1009e-02,  1.8596e-02,
        -1.2099e-01, -4.1850e-01, -1.9230e-01,  3.5130e-01, -4.4821e-01,
         2.5470e-01,  2.8373e-01,  1.8774e-01, -3.7589e-01, -1.8130e-01,
         1.0670e-02,  4.3980e-01,  2.0733e-01, -2.2607e-01, -6.0236e-03,
         2.0308e-01, -2.5505e-01, -1.3493e-01, -7.6622e-01, -1.6122e-01,
         3.1110e-01, -7.0250e-01,  1.6210e-01,  4.4559e-01, -6.3578e-02,
        -4.0638e-01, -2.0018e-01,  7.9215e-02,  3.2638e-01, -2.7217e-01,
        -1.4518e-01, -4.5701e-01,  7.8500e-01,  6.9186e-03, -1.4304e-01,
        -8.2297e-01, -6.0059e-01,  5.4975e-01,  7.1029e-02, -1.6044e-01,
         3.7533e-01, -4.9596e-01,  1.9321e-01, -2.3004e-02,  7.2436e-01,
        -3.8982e-03, -3.8308e-01, -2.1298e-01,  5.7227e-01, -2.9059e-01,
        -3.6694e-01,  2.1090e-01,  5.4300e-02,  2.6925e-01,  4.8504e-01,
        -3.3862e-01,  6.0748e-02, -5.7854e-02, -5.9217e-01,  4.4300e-01,
         9.5029e-02,  1.8161e-01,  2.6650e-01,  6.0550e-01, -2.0564e-02,
         2.5669e-01, -5.3899e-01,  1.8120e-01,  2.8302e-01,  2.6461e-01,
         3.1452e-01,  4.0307e-01,  1.2081e-01, -4.4754e-02, -5.6783e-01,
        -2.7386e-01, -2.7264e-01, -7.9274e-02,  1.0331e-01,  1.4307e-01,
         3.2411e-02, -9.0872e-02,  2.1397e-01,  5.5911e-01, -1.2616e-01,
         7.4664e-01, -3.9308e-02,  3.1403e-01,  2.1716e-01, -1.2542e-01,
        -3.9923e-02, -2.7066e-01, -1.5374e-01,  7.4528e-01, -1.8526e-02,
         9.3740e-02,  1.4399e-01, -2.5653e-01, -1.6716e-01,  8.6998e-02,
        -5.7568e-02, -4.4752e-01, -1.2513e-01, -5.1140e-01,  3.2167e-01,
        -8.5383e-01, -3.8914e-01,  2.6121e-01,  4.7042e-01,  1.3815e-01,
         1.3692e-01,  2.9005e-01, -3.1835e-01, -3.2770e-01,  5.2344e-01,
        -9.4294e-02, -2.6733e-01,  4.6628e-02,  3.4508e-01, -5.5710e-01,
        -6.6834e-01,  1.8798e-01, -1.7342e-01, -3.6999e-01,  1.3351e-01,
         9.9377e-01, -6.2528e-01,  3.0140e-01, -1.4834e-01, -3.0960e-01,
         1.9443e-01,  6.9328e-01, -7.7440e-02, -7.6267e-01, -8.0779e-01,
         1.6347e-01,  5.0165e-02,  3.0094e-01,  3.9797e-01,  5.9790e-01,
        -3.0307e-01, -3.4906e-02, -5.9387e-02, -1.0329e-01, -8.2852e-01,
        -2.2473e-01,  2.6110e-01,  4.7506e-01, -3.9283e-02,  5.6872e-01,
        -1.8706e-01,  2.2601e-01,  4.0102e-01,  3.1131e-02,  2.8757e-01,
         2.8571e-01,  5.8021e-01, -4.2791e-01,  7.8269e-03,  4.0677e-01,
        -2.7664e-01,  2.2835e-01,  2.3626e-01, -6.3141e-02,  1.1153e-01,
        -4.7822e-01, -1.5301e-01, -1.8227e-01,  3.7414e-01,  7.8456e-01,
         2.2415e-01, -1.7285e-01,  6.6482e-01, -5.3127e-01,  2.0935e-01,
        -2.5288e-01,  1.0435e-01,  4.0029e-01, -4.3612e-01,  8.8531e-02,
         1.5028e-02,  5.2645e-01, -1.8120e-01, -1.7755e-01, -5.3257e-01,
         1.3296e-01,  4.0218e-01, -2.3729e-01,  2.8322e-02, -4.7788e-01,
        -8.5450e-01,  1.1363e+00,  6.2536e-02, -2.5385e-01,  3.5211e-01,
         5.1479e-01, -5.3689e-02, -1.2642e-01, -3.7007e-01,  5.5000e-01,
         4.3096e-02,  7.3816e-02, -3.2128e-02, -7.2316e-02,  2.2236e-01,
         3.5409e-01, -3.1336e-02, -5.6832e-01,  5.4012e-01,  1.4459e-01,
         5.5330e-01, -8.9394e-02,  2.4170e-01, -2.1423e-01, -3.7454e-01,
        -3.4503e-01, -1.0431e+00, -4.3838e-01, -9.7780e-02,  1.4912e-01,
        -3.3920e-01,  5.0595e-01,  1.9275e-01,  2.9721e-01, -2.8312e-01,
         4.9051e-02, -1.8098e-01,  2.1835e-01,  2.9131e-01, -3.0268e-01,
         8.0726e-01, -6.6839e-02,  1.5916e-01,  1.9950e-01,  6.9302e-01,
        -7.4939e-01,  2.2071e-01, -3.6446e-01,  1.5311e-01, -2.2051e-02,
         2.4524e-02, -1.4302e-01, -1.4640e-01,  4.4812e-01,  5.3696e-02,
         7.9950e-02, -1.5195e-01, -5.8127e-01,  1.4290e-01, -2.0773e-01,
        -2.0873e-01, -7.6064e-02,  3.9885e-01, -4.0054e-01, -5.7849e-02,
        -6.6745e-02, -4.2272e-01,  1.2046e-01, -1.8130e-01, -2.6026e-02,
         3.2674e-03,  8.5135e-02, -4.1215e-01,  2.8920e-01,  2.4040e-01,
        -2.0729e-01, -8.7971e-01,  3.0972e-01,  3.6099e-01,  4.8119e-02,
         5.6250e-01, -3.5250e-02,  3.1775e-01,  2.0204e-01, -4.9827e-01,
        -1.7355e-01,  1.4027e-01, -3.0070e-01,  4.1554e-01,  3.2312e-02,
         5.9653e-01, -1.0700e-02, -1.8588e-01,  8.2568e-02,  6.7162e-01,
        -8.9231e-02,  4.5655e-01,  3.9033e-01,  1.8702e-01,  2.4272e-01,
        -1.5443e-01,  8.7487e-02, -1.8133e-01,  1.7933e-01,  4.5695e-02,
        -4.2245e-02,  1.4209e-01, -1.2813e-02, -6.1114e-01, -8.4514e-01,
         6.8840e-01,  1.3905e-01,  3.2074e-01, -1.8220e-01, -2.3588e-01,
         1.5011e-01, -8.7499e-02,  3.2523e-01,  6.1131e-01, -2.2579e-01,
        -6.7687e-01,  3.0260e-01,  2.4035e-01, -4.7868e-01,  1.7862e-01,
         1.7601e-01, -3.6321e-02, -2.1918e-01, -7.0772e-01,  1.6870e-01,
         1.7704e-01,  2.8989e-01, -4.7155e-01,  3.1615e-01,  2.2111e-01,
         4.8321e-01, -4.0633e-01,  2.0490e-01, -7.2281e-01,  1.1786e-01,
         2.0533e-01, -2.1710e-01, -3.1785e-02, -2.3179e-01, -3.3842e-01,
         1.1463e-01,  2.5344e-02, -1.1851e+00,  4.3171e-01, -3.5364e-01,
        -5.3355e-01, -4.0461e-01,  8.5380e-02,  7.3903e-02, -3.5577e-03,
        -1.9990e-02,  1.4921e-01, -8.3527e-04, -1.9013e-02,  2.1174e-01,
        -2.5005e-01, -4.2794e-01,  1.7412e-01,  5.1189e-01, -8.4861e-02,
         8.6171e-02, -3.9187e-01,  9.0300e-01, -5.5342e-02, -1.0177e-01,
        -2.8523e-01, -3.5686e-01, -9.9278e-02, -2.5441e-01,  3.7523e-01,
         2.2632e-01, -5.3697e-01,  6.9617e-02,  8.3040e-02, -6.3463e-03,
        -1.6970e-01, -3.2942e-01, -9.6454e-01,  1.1130e-02,  7.3840e-02,
         1.5068e-01, -1.9424e-01,  6.2504e-01,  2.9781e-01, -6.8875e-02,
         5.3533e-01, -3.0823e-01, -3.4497e-01, -1.6733e-01,  3.3012e-01,
         3.8747e-01, -2.3218e-01, -3.3858e-01,  3.3284e-01,  2.0404e-01,
        -7.0364e-01, -1.9036e-01, -4.6530e-01, -2.2199e-01, -1.6202e-01,
         4.1600e-01, -7.8640e-03,  1.8722e-01,  2.8484e-01, -2.2681e-01,
         2.2418e-01,  3.3265e-01, -1.1120e-01,  3.4069e-02,  4.4062e-02,
         5.9258e-02, -4.0374e-01, -4.6401e-01,  3.6902e-01, -2.7568e-01,
         1.8221e-01,  6.8858e-01, -3.7736e-01,  2.0455e-01,  2.3942e-01,
         5.2718e-01, -4.9345e-01, -1.4777e-01, -2.8008e-02, -4.0523e-02,
         3.5261e-01,  1.9200e-01, -1.9465e-01, -1.2883e-01,  1.5956e-01,
         3.9144e-02, -7.8584e-02,  1.6635e-01,  6.7698e-03, -4.4800e-03,
        -4.0468e-01,  2.0733e-01, -2.1623e-01, -5.3754e-01, -4.0643e-01,
        -1.6648e-01,  9.4148e-02,  1.9645e-01,  7.0336e-02, -7.7823e-02,
         2.5997e-01,  2.7050e-01, -1.1511e-01, -8.8562e-02,  1.2231e-01,
         9.2092e-02,  3.5043e-01, -2.1222e-01,  1.1809e-01,  5.6370e-01,
        -1.6927e-01, -6.7250e-02, -2.6679e-01,  1.8438e-01, -1.0437e-01,
        -1.0876e-01, -3.3285e-02, -1.6644e-01,  1.1961e-01, -1.6794e-01,
        -5.9850e-02,  5.7083e-01, -1.1013e-01, -4.9922e-01, -3.6749e-02,
        -3.3413e-01, -2.1030e-02,  6.6338e-02,  3.3481e-01,  1.8012e-01,
         2.6591e-01,  6.5587e-01,  6.0494e-02, -2.7890e-02,  7.9517e-02,
         1.9820e-01,  3.9939e-02, -4.0796e-02,  3.7233e-01, -4.1430e-02,
         1.6885e-01,  1.5520e-01,  1.6389e-01, -1.3510e-02, -2.7415e-01,
         3.0349e-01, -3.4330e-01, -1.8475e-01,  1.2336e-01,  5.9527e-02,
         5.8737e-01, -1.2942e-02, -3.7399e-02, -1.6238e-01,  2.0556e-01,
        -5.3172e-01, -1.4219e-01,  9.9291e+00,  6.3678e-01,  6.1553e-02,
        -2.0100e-01,  4.7839e-01, -7.0796e-01, -4.5690e-01, -1.2833e-01,
        -3.3865e-01,  6.7796e-01, -1.9700e-01,  2.4783e-01, -4.1169e-01,
         2.1225e-01, -4.0976e-01,  3.8130e-02,  3.8248e-01,  1.1482e-01,
         5.8699e-01, -2.9433e-01, -4.0858e-01, -3.2896e-01,  2.1210e-01,
        -6.4544e-01,  1.0598e-01, -2.9047e-01, -5.4636e-01,  6.6889e-01,
        -1.8817e-01, -4.9800e-01,  3.8094e-01, -1.4745e-01,  6.5264e-02,
         1.2921e-01, -1.5971e-01, -1.3207e-01, -2.8490e-01, -3.9678e-02,
        -7.1823e-02,  5.3187e-01, -1.8192e-01, -1.4276e-01, -3.1255e-01,
        -6.8351e-02,  1.6442e-02,  1.8553e-01,  5.3105e-02,  2.9920e-01,
        -1.4007e-01,  2.2026e-01, -7.0759e-02, -2.7410e-01, -2.1115e-01,
        -2.8701e-01, -6.8378e-02,  1.6862e-01, -1.2811e-01,  6.0629e-01,
        -1.8978e-01,  4.8770e-01,  4.8323e-02,  7.0020e-02, -4.2045e-01,
         2.7926e-01,  2.3637e-01, -3.3293e-01, -7.2897e-01,  8.7707e-02,
         1.0672e-01,  2.3620e-01, -4.0640e-02, -4.6251e-01, -9.2714e-02,
        -1.0523e-01, -6.1449e-01,  1.2759e-01, -1.9725e-01,  1.2620e-01,
        -2.3726e-01,  2.3094e-01, -8.8973e-01,  3.0535e-01,  4.4807e-01,
        -1.2142e-01, -1.2354e-01,  5.7408e-02, -9.1067e-02,  1.3050e-01,
         4.2765e-01, -2.2066e-01, -1.2292e-01,  3.9081e-02,  1.0836e-01,
         2.7527e-01, -3.4867e-01,  1.6575e-01,  1.3522e-01, -6.6203e-01,
         7.3601e-01,  4.4133e-02,  5.4364e-01,  1.2520e-01, -6.7164e-01,
        -5.2348e-01, -4.2528e-02,  2.2360e-03, -3.7095e-01,  5.5230e-01,
        -4.2800e-02,  3.0577e-01, -5.9843e-02,  3.2949e-01,  6.5309e-01,
         1.0016e-01, -2.5682e-01,  1.5163e-01, -1.6559e-01,  5.9465e-01,
        -2.4192e-01,  2.8894e-01, -5.4235e-02,  1.4259e-01,  1.0014e+00,
        -1.5320e-01,  2.1810e-01, -3.8064e-02, -2.5395e-02,  3.9258e-01,
         6.8331e-01, -2.5370e-01,  2.0085e-01,  2.2965e-02, -1.0974e-01,
         3.8632e-01,  1.4395e-01,  4.1374e-01,  8.5952e-01, -3.3775e-01,
         5.7560e-01, -2.8207e-01,  9.2888e-02, -2.0208e-01,  1.6092e-01,
         1.9409e-01, -2.8578e-01, -2.2457e-01, -2.2900e-01, -2.1858e-01,
        -4.3997e-01, -1.1557e-01, -8.9590e-01]

BERTを使用して、単語ベクトルを得ることができました。

トークナイズする必要がある点と、最終層から特殊トークンを取り除く必要がある点を除けば、基本的な流れは前回のWord2Vecと同じとなります。

文ベクトルを作成する

次に文ベクトルを作成します。

Word2Vecでは形態素解析を行いましたが、BERTではトークナイザーにより行われるため、これらの操作が不要となります。

今回使用するモデルでは、形態素解析だけでなく、Word Poeceでサブワード化しています。

以下の文のベクトルを出力します。

text = "私はラーメンが大好きです。"

from transformers import BertJapaneseTokenizer
from transformers import BertModel
import torch
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')
bert_model = BertModel.from_pretrained('cl-tohoku/bert-base-japanese-whole-word-masking')

bert_model.to("cuda")

input = tokenizer(text, return_tensors="pt")

input["input_ids"] = input["input_ids"].to("cuda")
input["token_type_ids"] = input["token_type_ids"].to("cuda")
input["attention_mask"] = input["attention_mask"].to("cuda")

outputs = bert_model(**input)
last_hidden_states = outputs.last_hidden_state

attention_mask = input.attention_mask.unsqueeze(-1)
valid_token_num = attention_mask.sum(1)

sentence_vec = (last_hidden_states*attention_mask).sum(1) / valid_token_num
sentence_vec = sentence_vec.detach().cpu().numpy()[0]

# BERTで作成した文ベクトルのshape
print(sentence_vec.shape)

# BERTで作成した文ベクトル
print(sentence_vec)

実行すると、以下のような結果が出力されます。

(768,)
[ 1.16645016e-01 -2.69998480e-02  1.11215711e-01 -1.76774919e-01
 -9.13642347e-02  2.75651634e-01 -2.35511616e-01 -1.93148434e-01
 -7.93425813e-02 -2.88740601e-02 -2.96930283e-01 -2.21637666e-01
 -1.40309021e-01 -1.87762529e-01  1.58262357e-01  9.86073241e-02
 -2.56132364e-01  3.72660071e-01 -3.91580574e-02 -3.66479605e-01
 -1.04905486e-01 -4.29972559e-01  4.19758894e-02 -2.34562978e-01
 -2.32449770e-01 -6.13280535e-02  3.81922036e-01 -5.08529246e-01
 -2.35810187e-02 -1.35561928e-01  1.87966079e-01  1.48832232e-01
 -8.84959474e-03 -1.97303146e-01 -2.93785125e-01 -5.82409836e-02
  1.30212888e-01  2.51266271e-01  5.33977211e-01 -3.55691552e-01
  3.87109280e-01 -2.33514518e-01 -5.39802909e-01  4.94451523e-01
  3.90039161e-02  1.59935966e-01  1.42281771e-01  2.18498230e-01
  1.06563598e-01  3.10081482e-01  4.75322932e-01  4.04897332e-01
  2.65945271e-02  9.35317576e-02  1.15121126e-01  1.23663753e-01
 -2.02995524e-01  1.92660876e-02  2.90293425e-01  6.13033446e-03
  9.28239450e-02 -8.71070996e-02  1.84285641e-01 -3.56759936e-01
 -3.58550310e-01 -6.68543100e-01  4.71717000e-01 -3.91031712e-01
  2.12018386e-01  3.54178064e-02 -8.94974358e-03 -7.73456022e-02
 -1.86538436e-02 -3.33743483e-01 -7.46299699e-02 -3.70936185e-01
  4.34862264e-02 -1.47453949e-01 -2.90074468e-01 -5.15955538e-02
  8.46547540e-03  1.53690308e-01 -6.82750158e-03 -1.18857183e-01
 -1.24702536e-01 -1.26112700e-01 -1.14122592e-01  1.28197432e-01
 -1.76809579e-01  2.84796238e-01 -7.16832876e-01  1.37197748e-01
  4.62603211e-01  2.10855961e-01 -3.52002442e-01 -3.78438868e-02
 -1.49118667e-02  3.28039616e-01  2.39827052e-01 -5.37532903e-02
  2.22042069e-01  1.02910332e-01 -1.94547713e-01  2.78615683e-01
  1.25530854e-01 -7.89328385e-03 -1.39482051e-01 -3.06482315e-01
 -5.79784095e-01  1.60937473e-01 -1.25482127e-01  2.67637551e-01
 -9.12056118e-02 -2.00002491e-01 -4.89119776e-02 -1.08313352e-01
  1.46988695e-02 -3.53902936e-01  4.23455089e-02 -1.62165806e-01
  1.40364349e-01 -2.44031250e-01  2.35578954e-01 -1.52168825e-01
 -1.00826196e-01  3.04611593e-01  2.02752948e-01  4.76849526e-02
 -3.58862132e-01  3.40730220e-01  1.37777209e-01 -1.20538153e-01
  2.70074517e-01 -2.79610921e-02 -1.22452632e-01 -2.42657974e-01
 -4.03084755e-01  3.36922884e-01 -4.64601368e-01 -1.70622483e-01
 -7.23800138e-02 -3.57068181e-01 -7.71288648e-02  3.56483817e-01
 -2.92503387e-01 -1.53763831e-01  2.66874582e-01  5.06658778e-02
  3.78775865e-01 -1.34048507e-01 -2.79095769e-01 -1.88842013e-01
  4.48526174e-01  3.61919969e-01  3.42295691e-02 -6.75102592e-01
 -1.84596211e-01  5.55170476e-01  4.58499014e-01 -2.33504236e-01
  2.19830930e-01  2.96626482e-02 -1.46871373e-01  1.51041895e-01
  4.17462081e-01  1.59212530e-01 -1.29101232e-01  4.66660976e-01
  2.38560379e-01  6.91679567e-02 -1.12132221e-01  9.70969871e-02
  1.90168306e-01  3.56654614e-01  1.93869218e-01  4.12869044e-02
 -2.14106634e-01  2.77724773e-01 -8.33107904e-02  2.23763764e-01
 -1.71429247e-01 -7.35957921e-02  1.75299540e-01 -8.60217661e-02
 -2.08960935e-01  2.32195273e-01 -3.90085489e-01 -2.13124916e-01
 -4.23903205e-02  7.66417682e-02 -9.95024294e-02  2.90966332e-01
  3.05152386e-01  3.85043025e-01 -8.91811699e-02 -8.94122645e-02
 -3.95992398e-01  1.96703970e-01  2.71384157e-02  4.38183576e-01
  3.49755496e-01 -2.29956254e-01  1.53006494e-01  3.03208649e-01
  2.59880852e-02  4.88161206e-01 -1.89309016e-01 -1.53463259e-01
 -8.77967104e-03  6.86315596e-02 -1.07813098e-01 -1.94540635e-01
  5.67885637e-02  4.39971149e-01  1.09706633e-01  6.23155572e-02
 -1.80192426e-01 -1.99464291e-01  1.89636737e-01 -2.76034069e-03
 -5.08680880e-01 -3.55161518e-01 -1.03138849e-01 -3.42920184e-01
  2.25148991e-01 -4.01016504e-01  1.46824405e-01  3.02794218e-01
  2.88044930e-01  1.85519755e-01 -9.54145119e-02  4.23318118e-01
 -3.01135123e-01 -4.88857210e-01  4.00173925e-02  1.28351897e-01
 -1.45984784e-01  3.68875206e-01  2.68858731e-01 -5.33505559e-01
 -4.47764814e-01 -1.53470650e-01 -9.99481678e-02  1.48785517e-01
 -7.17079043e-02  6.41162157e-01 -4.55468595e-01  3.69701266e-01
 -1.26544788e-01 -3.30194473e-01  1.68681219e-01 -9.54824314e-02
  6.11844733e-02 -9.10399780e-02 -6.40117228e-02  2.14917600e-01
  4.06882465e-01  7.05398321e-01  9.72134899e-03 -9.47078050e-04
 -1.19935781e-01  3.28014232e-02 -5.42141050e-02 -2.49437734e-01
 -1.35806918e+00 -1.05776470e-02  3.25645469e-02 -6.86758459e-02
 -2.93965518e-01  1.88091651e-01  1.61775276e-01  9.44367126e-02
 -1.81450427e-01  1.07504874e-01 -4.20095474e-02  7.72337690e-02
  6.76210642e-01 -5.27345002e-01  1.42418846e-01 -3.77083085e-02
 -3.27575475e-01  2.64490753e-01 -3.70825790e-02 -3.02431196e-01
  3.07791471e-01 -2.45926693e-01  1.50560796e-01  8.37669522e-02
  2.28280514e-01  1.55672356e-01  1.62662104e-01 -6.42699838e-01
  1.74161196e-01 -3.18246782e-01 -1.19601198e-01  4.86125499e-02
  1.81635544e-01  2.11371601e-01 -4.72254902e-02  1.15457565e-01
  3.92386429e-02  6.49947003e-02 -2.26003334e-01 -3.04925978e-01
 -1.83560122e-02  4.39794958e-01 -2.63295531e-01  4.86067012e-02
  7.75686875e-02 -2.28289664e-01 -2.99886495e-01  2.69457012e-01
  1.14996113e-01 -4.33000401e-02  2.49553546e-01  1.12827748e-01
  1.06662594e-01  6.71390593e-02 -4.91715431e-01  4.22318816e-01
 -2.19085723e-01 -3.94940600e-02  2.27570057e-01  9.27821025e-02
  2.59276223e-03  4.65669215e-01  5.48544712e-02 -3.16387296e-01
 -1.36298627e-01 -4.64232802e-01  1.13320261e-01  2.14772642e-01
  3.79378676e-01  2.59406716e-01 -3.15208554e-01 -8.10364783e-02
 -5.81594706e-01  1.19628206e-01 -1.01030231e-01  2.08956867e-01
 -7.97904730e-02 -1.17656058e-02  3.93874973e-01 -3.38722110e-01
  1.33708820e-01  1.64811060e-01  1.88624322e-01  5.08403927e-02
  1.93324760e-01  1.71604380e-02  1.01251835e-02  4.37014140e-02
  1.54320776e-01 -3.44331592e-01  2.43144780e-01 -3.03128600e-01
  2.35194385e-01 -3.47699344e-01  5.47997691e-02 -1.98522270e-01
 -7.38824233e-02 -1.51117146e-01 -4.96245958e-02 -1.17577992e-01
  3.59460086e-01  1.07492909e-01 -7.82146025e-03 -1.57131985e-01
 -1.98000565e-01 -2.15404332e-01 -2.62553245e-01  7.79210627e-02
  4.70942140e-01  2.81424820e-03  2.70506572e-02 -5.72257256e-03
 -7.16906190e-02  2.57993013e-01 -3.42219211e-02 -2.59114623e-01
  8.24010521e-02  2.43669748e-01 -1.51530161e-01  3.19538742e-01
  3.65088701e-01 -1.25003457e-02 -8.96904647e-01 -1.39962863e-02
  6.72172844e-01 -3.02417845e-01  2.93942779e-01 -1.23946510e-01
  2.00408757e-01 -1.25391886e-01 -3.13640893e-01  4.55936529e-02
  2.93015838e-02 -2.83155084e-01  4.82502043e-01 -1.88057795e-01
  4.76656228e-01  7.30441064e-02  1.35413483e-01 -3.41264904e-03
  6.72634363e-01 -2.81704664e-01  7.62593327e-03  7.87863851e-01
 -2.45893821e-01  2.11643577e-01 -1.84731707e-02  9.80375037e-02
 -5.01222908e-01  2.39692748e-01  3.04066867e-01 -6.81792498e-02
  1.14222102e-01  9.54146758e-02  2.24780440e-02 -3.55425179e-01
  5.53421557e-01 -2.25330412e-01 -1.62246227e-01  1.24709994e-01
 -3.49792212e-01 -3.20938647e-01 -3.05978984e-01  1.37147397e-01
  3.28983188e-01 -3.13005507e-01 -5.79362452e-01  7.01204017e-02
  2.64860034e-01 -1.08766511e-01  1.85283959e-01  6.05280161e-01
  9.76110995e-02  5.24522290e-02 -3.01347792e-01  1.67150840e-01
  1.73343286e-01  1.32637158e-01 -5.30415535e-01 -1.76462010e-02
  2.12263033e-01 -2.02660590e-01 -8.82629305e-02 -7.55758137e-02
 -4.16431576e-01 -1.16700836e-01 -1.85223706e-02 -4.44327146e-02
  3.97494316e-01 -4.47231829e-01 -2.17226475e-01  2.21661150e-01
  5.72672300e-02 -5.83270788e-01 -8.96578208e-02  1.84886903e-01
  3.19760829e-01  2.92177945e-01  3.60193625e-02  2.13162795e-01
 -1.49096757e-01 -4.20999557e-01 -3.13207090e-01  1.38702512e-01
 -1.53461546e-01 -4.53067571e-01  2.01021165e-01  2.49826774e-01
  2.00154722e-01 -2.77289264e-02 -6.18494861e-03  5.15535064e-02
 -3.34012300e-01  1.95865974e-01 -4.59583662e-02 -3.24911982e-01
 -3.07103008e-01 -1.39998168e-01 -2.07749993e-01 -1.39559479e-03
  1.03621431e-01  1.92874506e-01 -4.59644794e-01  3.02932620e-01
  2.85163045e-01 -1.32697327e-02 -2.34646499e-01 -2.20062837e-01
 -7.30955243e-01 -7.56781399e-02 -6.09119594e-01  1.17783755e-01
  4.31181043e-02  3.92201036e-01  3.07408988e-01 -2.94495732e-01
  4.80726659e-01  5.56073859e-02  2.20415071e-02 -2.92141467e-01
 -6.71224967e-02 -1.63323641e-01 -7.38964528e-02 -2.20856488e-01
 -7.56525621e-02  4.60971780e-02 -3.94452989e-01 -3.73278596e-02
 -4.84836847e-01 -1.26130104e-01 -2.56763567e-04 -3.21802527e-01
  1.19481459e-01  3.57518613e-01 -7.74687529e-03  3.25354077e-02
  2.21136823e-01  2.20172933e-05  6.32953048e-02 -3.83611888e-01
  3.53758007e-01  2.09923655e-01 -1.96721107e-01 -4.14283395e-01
  1.14875399e-01 -2.09690124e-01  1.76390469e-01  2.99846888e-01
  2.26691887e-01 -5.16504049e-04 -1.39278248e-01  4.71806787e-02
 -3.81502479e-01 -2.51010835e-01 -1.77373275e-01 -4.22858119e-01
 -7.20922053e-02 -2.45308187e-02  2.18189552e-01 -3.65441352e-01
  3.33130211e-01  2.58652240e-01 -1.01695485e-01  7.31680021e-02
  2.49111041e-01 -2.33805582e-01  8.63792449e-02  1.46839544e-01
  1.25706494e-01  3.44739318e-01 -1.14470996e-01 -1.13352403e-01
 -3.61290336e-01  5.24032786e-02  1.80748031e-01 -4.08802480e-02
  2.47335225e-01 -2.15142265e-01 -1.17292844e-01  2.61807084e-01
 -1.83695585e-01 -1.62501168e-02  1.07439823e-01 -2.36829259e-02
 -3.20763707e-01  2.04202145e-01  2.67726388e-02  6.64610416e-02
  3.40080000e-02  1.37332857e-01  2.64618360e-02  5.79836443e-02
  3.59623022e-02 -1.23792015e-01 -1.07046694e-01  1.03229374e-01
 -1.86888874e-01  1.69912025e-01 -1.89462736e-01 -1.21841684e-01
 -7.31310924e-04 -4.24594820e-01 -2.33563930e-01 -3.78941162e-03
 -1.40226975e-01  3.98322523e-01  2.22769395e-01  1.39662206e-01
  1.66048229e-01  1.63753390e-01  3.25269282e-01  6.44879267e-02
  4.80579324e-02 -2.31674328e-01  3.73645008e-01  2.38386691e-01
  9.57276449e-02 -9.42394733e-02  2.11613812e-02  1.53915389e-02
 -1.82077438e-01  6.56224787e-01 -2.21424073e-01 -2.57806480e-02
  2.16008186e-01 -7.49451816e-02  5.80933616e-02 -1.50165483e-01
 -1.73770458e-01  2.85280682e-02  1.24789782e-01 -9.78413522e-02
 -1.50758579e-01  9.10742188e+00  2.54466146e-01 -9.90526974e-02
 -2.54162937e-01  2.77779043e-01 -6.36898160e-01 -1.76599920e+00
  2.71020532e-01 -2.25932539e-01  5.32595068e-02 -5.01945555e-01
  1.78606674e-01 -2.00601116e-01  2.05235630e-01 -1.53589398e-01
  9.35979187e-02  2.46491373e-01  5.14198020e-02  1.09973758e-01
 -2.30385229e-01 -5.45648992e-01 -3.18401277e-01 -2.48935387e-01
 -1.16927303e-01  4.56908107e-01 -5.19649349e-02  5.24568895e-04
  3.72924358e-01  1.34456843e-01 -4.43807095e-02  5.49804121e-02
 -4.21875954e-01  7.65960570e-03  2.58740515e-01  5.88470697e-02
  8.15215632e-02 -3.35645437e-01 -3.05762768e-01  5.43732103e-03
  1.44582435e-01 -7.03126788e-01 -1.59444749e-01  1.02969848e-01
 -3.82062554e-01  4.39467609e-01 -7.93936789e-01 -2.26255104e-01
  2.60962576e-01 -2.28940044e-02  3.63110495e-03  1.50639981e-01
 -2.46939898e-01  4.85761225e-01 -2.34103307e-01 -3.67841795e-02
 -6.33109137e-02 -1.35341091e-02  2.97785878e-01  5.06722510e-01
  1.75929368e-01  1.08785182e-02  7.92853162e-02 -1.44113936e-02
  1.17343426e-01 -2.46235356e-02 -1.84818536e-01 -5.25117069e-02
 -7.16090947e-02  3.59278888e-01 -7.79941902e-02 -3.15954387e-02
 -2.43737489e-01 -1.63554505e-01  8.99142697e-02  5.16833588e-02
 -5.06261289e-02  1.51773125e-01 -2.35778261e-02 -9.44307521e-02
  2.09402665e-02 -3.00347835e-01 -8.72449055e-02  2.81572014e-01
 -4.70810980e-02 -5.97097635e-01  2.51912549e-02  1.48506865e-01
  2.27088451e-01  2.53861010e-01 -1.46152182e-02  7.25240260e-02
  5.73300272e-02  6.98813051e-02 -7.72519410e-02 -3.90465319e-01
 -2.53465418e-02 -4.79814261e-01 -5.13032317e-01  1.99158341e-01
  1.36163188e-02  6.57034993e-01  1.83930695e-01 -3.12830448e-01
  1.05887121e-02 -1.99639261e-01 -1.29945979e-01 -2.20157802e-01
  4.66875285e-01 -2.86045015e-01  1.04514316e-01 -1.94038540e-01
 -5.52563220e-02  3.72509539e-01  9.64507759e-02  2.99613699e-02
  4.17727411e-01 -1.28201455e-01  3.00422430e-01  5.63987643e-02
  2.11416289e-01  1.17464233e-02  7.47889802e-02  8.27897787e-01
  1.53330058e-01  1.70737922e-01 -3.92578393e-01 -1.71654239e-01
  7.83131421e-02  2.52156854e-01 -5.63733995e-01 -2.73893476e-02
  1.20832272e-01 -1.11815281e-01  6.92681372e-02  9.43027511e-02
  1.23949900e-01  2.12450743e-01 -5.26038595e-02  2.57368356e-01
 -1.45163834e-01  1.76537365e-01 -1.89996943e-01  3.64017636e-02
  2.75990963e-01  5.15742898e-02  1.01273786e-02 -4.29876447e-01
  1.69882640e-01 -2.65017390e-01  7.23594800e-02  2.55487144e-01]

単語と同様に、文のベクトルを出力することができました。

一般には、複数の文で構成された「文章」を扱う場合が多いことが想定されるため、[CLS]と[SEP]トークンを取り除くことなく、そのまま利用しています。

文章の類似度計算

最後に2つ文章の類似度を算出してみます。

Word2Vecの際と同様に、以下の文章で類似度で計算してみます。

sentences = ["私はラーメンが好きです","チャーシューメンが好きです"]

input = tokenizer(sentences, return_tensors="pt",padding=True,truncation=True)

input["input_ids"] = input["input_ids"].to("cuda")
input["token_type_ids"] = input["token_type_ids"].to("cuda")
input["attention_mask"] = input["attention_mask"].to("cuda")

outputs = bert_model(**input)
last_hidden_states = outputs.last_hidden_state
attention_mask = input.attention_mask.unsqueeze(-1)
valid_token_num = attention_mask.sum(1)
sentence_vecs = (last_hidden_states*attention_mask).sum(1) / valid_token_num
sentence_vecs = sentence_vecs.detach().cpu().numpy()
from numpy import dot
from numpy.linalg import norm

similarity_with_bert = dot(sentence_vecs[0], sentence_vecs[1]) / \
                       (norm(sentence_vecs[0])*norm(sentence_vecs[1]))

print(similarity_with_bert)
0.9134689

2つの類似度を計算することができました。

まとめ

最後までご覧いただきありがとうございました。

この記事では単語の分散表現の理解のため、前回紹介したWord2Vecと比較しながら、BERTによる単語のベクトル化の流れを紹介しました。

このシリーズでは、自然言語処理全般に関するより詳細な実装や学習の方法を紹介しておりますので、是非ご覧ください。

【ChatGPT】自然言語処理まとめ【Huggingface Transformers】

自然言語処理に関するおすすめの書籍 ChatGPT ChatGPTを中心とした、GPT系の関連技術を紹介します。 ChatGPTの概要 ・ChatGPTとは・ChatGPTができること・ChatGPTの問題点…