概要
今回はPyTorchにおけるTensor(テンソル)の概要について紹介します。本記事を通して、TensorとはどのようなものかTensorの概要と基本的なTensorの扱い方をお伝えします。
Tensorとは?
Tensorとは、スカラー、ベクトル、行列の総称です。例えば、自然言語処理では単語をベクトルにとして扱い、画像処理ではRGB画像を行列として扱います。それぞれ順番に確認していきましょう。
スカラー
スカラーは1や2のように単なる数値を意味します。(0階層のTensor)
x = 1
ベクトル
ベクトルはスカラーの集合であり、プログラムにおける配列と同じです。(1階層のTensor)
x = [1, 2, 3]
行列
行列はベクトルの集合であり、多次元の配列となります。(2階層のTensor)
[
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
]
Tensorの生成
まずはtorch.tensor()
を利用してpythonのリストからTensorを生成する方法を確認します。PyTorchをimportするにはimport torch
とします。
import torch
data = [1, 2, 3]
data = torch.tensor(data)
print(data)
Tensorの使い方はNumPyのndarrayと非常に似ています。また、Tensorをndarrayに変換したり、逆にndarrayをTensorに変換することも可能です。以下はndarrayからTensorに変換する方法です。
np_x = np.arange(0, 10)
print(type(np_x))
ts_x = torch.tensor(np_x)
print(ts_x)
# <class 'numpy.ndarray'>
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Tensorの操作
サイズの確認
Tensorのサイズを確認するにはsize()
を利用します。以下の例は2行3列の行列となります。
data = [[1, 2, 3],[4, 5, 6]]
data = torch.tensor(data)
print(data.size())
# torch.Size([2, 3])
単一の値を取得するにはitem()
を利用します。
print(ts_data[0][0].item())
Tensorの四則演算
続いて基本的なテンソルの演算方法です。
data1 = torch.tensor([[10, 20, 30],[40, 50, 60]])
data2 = torch.tensor([[1, 2, 3],[4, 5, 6]])
# テンソル同士の足し算
print(data1 + data2)
# テンソル同士の引き算
print(data1 - data2)
# テンソル同士の掛け算
print(data1 * data2)
# テンソル同士の割り算
print(data1 // data2)
# tensor([[11, 22, 33],
# [44, 55, 66]])
# tensor([[ 9, 18, 27],
# [36, 45, 54]])
# tensor([[ 10, 40, 90],
# [160, 250, 360]])
# tensor([[10, 10, 10],
# [10, 10, 10]])
テンソルに対しての数学関数の適用
テンソルに対して数学関数を適用するには、以下のように実行します。数学関数を適用するためには値がfloat
型である必要があるので’dtype’にfloat
を指定します。
data = torch.tensor([[1, 2, 3],[4, 5, 6]], dtype=float)
#平均値
print(torch.mean(data))
# 絶対値
print(torch.abs(data))
# 標準偏差
print(torch.std(data))
# tensor(3.5000, dtype=torch.float64)
# tensor([[1., 2., 3.],
[4., 5., 6.]], dtype=torch.float64)
# tensor(1.8708, dtype=torch.float64)
Tensorの変形
画像処理や言語解析の際に、テンソルの次元を変更することがよくあります。 次元の変更には.view()
を用い、引数に変更後の次元を指定します。以下の例では1×9から3×3へ変形しています。
data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
data = torch.tensor(data)
print(data.view(3, 3))
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
まとめ
今回はTensorの基本的な説明と操作方法についてご紹介しました。PyTorchを扱う上でTensorの理解と操作は必須となりますので、繰り返し確認するようにしてください。