このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
train
浅層ニューラル ネットワークの学習
構文
説明
この関数は浅層ニューラル ネットワークに学習させます。畳み込みニューラル ネットワークまたは LSTM ニューラル ネットワークを使用した深層学習の場合、代わりに trainnet
または trainNetwork
を使用してください。
例
ネットワークの学習およびプロット
以下では、入力 x
およびターゲット t
によって、プロットできる簡単な関数が定義されます。
x = [0 1 2 3 4 5 6 7 8];
t = [0 0.84 0.91 0.14 -0.77 -0.96 -0.28 0.66 0.99];
plot(x,t,'o')
ここでは、feedforwardnet
によって、2 層フィードフォワード ネットワークが作成されます。このネットワークには、10 個のニューロンがある 1 つの隠れ層があります。
net = feedforwardnet(10); net = configure(net,x,t); y1 = net(x) plot(x,t,'o',x,y1,'x')
ネットワークに学習させ、再度シミュレーションを行います。
net = train(net,x,t); y2 = net(x) plot(x,t,'o',x,y1,'x',x,y2,'*')
NARX 時系列ネットワークの学習
この例では、制御電流 x
および磁石の垂直位置応答 t
によって定義される磁気浮上システムをモデル化するように、外部入力を伴う開ループの非線形自己回帰ネットワークに学習をさせた後、ネットワークのシミュレーションを行います。学習およびシミュレーションを行う前に、関数 preparets
によってデータを準備します。これによって、開ループ ネットワークの結合された入力 xo
が作成されます。これには、外部入力 x
と位置 t
の前の値の両方が含まれます。さらに、遅延状態 xi
も準備されます。
[x,t] = maglev_dataset; net = narxnet(10); [xo,xi,~,to] = preparets(net,x,{},t); net = train(net,xo,to,xi); y = net(xo,xi)
同じシステムのシミュレーションを、閉ループ形式で行うこともできます。
netc = closeloop(net); view(netc) [xc,xi,ai,tc] = preparets(netc,x,{},t); yc = netc(xc,xi,ai);
並列プールでのネットワークの並列学習
学習時に使用されていないハードウェア リソースが存在する場合、並列で学習させることで、より高速にネットワークに学習させ、そのようなリソースが存在しなければメモリに収まらないようなデータセットを使用できる可能性があります。ネットワークに並列学習させるには、Parallel Computing Toolbox™ が必要です。これは、逆伝播学習でのみサポートされており、自己組織化マップではサポートされていません。
以下では、学習およびシミュレーションは、並列 MATLAB ワーカー全体で行われます。
[X,T] = vinyl_dataset; net = feedforwardnet(10); net = train(net,X,T,'useParallel','yes','showResources','yes'); Y = net(X);
Composite 値を使用してデータを手動で分散させ、結果を Composite 値として取得します。データが分散された状態で読み込まれる際に、各データセットが RAM に収まらなければならない場合、データセット全体はすべてのワーカーの合計 RAM によってのみ制限されます。
[X,T] = vinyl_dataset; Q = size(X,2); Xc = Composite; Tc = Composite; numWorkers = numel(Xc); ind = [0 ceil((1:numWorkers)*(Q/numWorkers))]; for i=1:numWorkers indi = (ind(i)+1):ind(i+1); Xc{i} = X(:,indi); Tc{i} = T(:,indi); end net = feedforwardnet; net = configure(net,X,T); net = train(net,Xc,Tc); Yc = net(Xc);
上記の例では、関数 configure
を使用して、ネットワークの入力の次元および処理の設定を行っていることに注意してください。これは通常、train が呼び出されたときに自動的に行われますが、Composite データを指定する場合は、Composite 以外のデータを使用してこのステップを手動で行わなければなりません。
GPU でのネットワークの学習
Parallel Computing Toolbox によってサポートされている場合、現在の GPU デバイスを使用したネットワークの学習が可能です。GPU での学習は現在、逆伝播学習でのみサポートされており、自己組織化マップではサポートされていません。
[X,T] = vinyl_dataset; net = feedforwardnet(10); net = train(net,X,T,'useGPU','yes'); y = net(X);
データを GPU に手動で配置するには、以下のようにします。
[X,T] = vinyl_dataset; Xgpu = gpuArray(X); Tgpu = gpuArray(T); net = configure(net,X,T); net = train(net,Xgpu,Tgpu); Ygpu = net(Xgpu); Y = gather(Ygpu);
上記の例では、関数 configure を使用して、ネットワークの入力の次元および処理の設定を行っていることに注意してください。これは通常、train が呼び出されたときに自動的に行われますが、gpuArray データを指定する場合は、gpuArray 以外のデータを使用してこのステップを手動で行わなければなりません。
ワーカーがそれぞれ異なる固有の GPU に割り当てられており、その他のワーカーが CPU で実行されている状態で並列実行するには、以下のようにします。
net = train(net,X,T,'useParallel','yes','useGPU','yes'); y = net(X);
CPU ワーカーは同等の速度を実現できないため、固有の GPU を持つワーカーのみを使用する方が高速になります。
net = train(net,X,T,'useParallel','yes','useGPU','only'); Y = net(X);
チェックポイントの保存を使用したネットワークの学習
以下では、2 分間に 1 回を超えないペースで保存されるチェックポイントを使用して、ネットワークに学習させます。
[x,t] = vinyl_dataset; net = fitnet([60 30]); net = train(net,x,t,'CheckpointFile','MyCheckpoint','CheckpointDelay',120);
コンピューターの障害の発生後に、最新のネットワークを復元し、これを使用して障害発生時点から学習を継続できます。チェックポイント ファイルには、構造体変数 checkpoint
が含まれます。これには、ネットワーク、学習記録、ファイル名、時間、および数値が含まれます。
[x,t] = vinyl_dataset; load MyCheckpoint net = checkpoint.net; net = train(net,x,t,'CheckpointFile','MyCheckpoint');
入力引数
net
— 入力ネットワーク
network
オブジェクト
入力ネットワーク。network
オブジェクトとして指定します。network
オブジェクトを作成するには、feedforwardnet
、narxnet
などを使用します。
X
— ネットワークの入力
行列 | cell 配列 | Composite データ | gpuArray
ネットワーク入力。R
行 Q
列の行列または Ni
行 TS
列の cell 配列として指定します。ここで、
R
は入力サイズQ
はバッチ サイズNi = net.numInputs
TS
はタイム ステップ数
引数 train
には、行列 (静的な問題、および単一の入出力があるネットワーク) および cell 配列 (複数のタイム ステップ、および複数の入出力があるネットワーク) の 2 つの形式があります。
行列形式は、1 タイム ステップのみのシミュレーションが行われる (
TS = 1
) 場合に使用できます。これは入出力が 1 つしかないネットワークの場合に便利ですが、複数の入出力があるネットワークにも使用できます。ネットワークに複数の入力がある場合、行列のサイズは (Ri
の合計) 行Q
列になります。cell 配列形式はより一般的で、複数の入出力があるネットワークの場合、入力をシーケンスで与えることができてより便利です。各要素
X{i,ts}
は、Ri
行Q
列の行列です。ここでRi = net.inputs{i}.size
です。
Composite データが使用される場合、'useParallel'
が自動的に 'yes'
に設定されます。関数は Composite データを取り、Composite の結果を返します。
gpuArray データが使用される場合、'useGPU'
が自動的に 'yes'
に設定されます。関数は gpuArray データを取り、gpuArray の結果を返します。
メモ
X の列に 1 つ以上の NaN
が含まれる場合、train
はこの列を学習、テスト、または検証に使用しません。
T
— ネットワークのターゲット
ゼロ配列またはゼロ行列 (既定値) | 行列 | cell 配列 | Composite データ | gpuArray
ネットワークのターゲット。U
行 Q
列の行列または No
行 TS
列の cell 配列として指定します。ここで、
U
は出力サイズQ
はバッチ サイズNo = net.numOutputs
TS
はタイム ステップ数
引数 train
には、行列 (静的な問題、および単一の入出力があるネットワーク) および cell 配列 (複数のタイム ステップ、および複数の入出力があるネットワーク) の 2 つの形式があります。
行列形式は、1 タイム ステップのみのシミュレーションが行われる (
TS = 1
) 場合に使用できます。これは入出力が 1 つしかないネットワークの場合に便利ですが、複数の入出力があるネットワークにも使用できます。ネットワークに複数の入力がある場合、行列のサイズは (Ui
の合計) 行Q
列になります。cell 配列形式はより一般的で、複数の入出力があるネットワークの場合、入力をシーケンスで与えることができてより便利です。各要素
T{i,ts}
は、Ui
行Q
列の行列です。ここでUi = net.outputs{i}.size
です。
Composite データが使用される場合、'useParallel'
が自動的に 'yes'
に設定されます。関数は Composite データを取り、Composite の結果を返します。
gpuArray データが使用される場合、'useGPU'
が自動的に 'yes'
に設定されます。関数は gpuArray データを取り、gpuArray の結果を返します。
T
はオプションであり、ターゲットを必要とするネットワークにのみ使用する必要があることに注意してください。
メモ
ターゲット T
にある NaN
値はすべて、欠損データとして扱われます。T
の要素が NaN
である場合、その要素は学習、テスト、または検証に使用されません。
Xi
— 入力遅延の初期条件
ゼロ配列またはゼロ行列 (既定値) | cell 配列 | 行列
入力遅延の初期条件。Ni
行 ID
列の cell 配列または R
行 (ID*Q)
列の行列として指定します。ここで、
ID = net.numInputDelays
Ni = net.numInputs
R
は入力サイズQ
はバッチ サイズ
cell 配列入力の場合、Xi
の列は、最も古い遅延条件から最新の遅延条件まで順に並べられます。Xi{i,k}
は時間 ts = k - ID
での入力 i
です。
Xi
もオプションであり、入力または層の遅延があるネットワークにのみ使用する必要があります。
Ai
— 層遅延の初期条件
ゼロ配列またはゼロ行列 (既定値) | cell 配列 | 行列
層遅延の初期条件。Nl
行 LD
列の cell 配列または (Si
の合計) 行 LD*Q
列の行列として指定します。ここで、
Nl = net.numLayers
LD = net.numLayerDelays
Si = net.layers{i}.size
Q
はバッチ サイズ
cell 配列入力の場合、Ai
の列は、最も古い遅延条件から最新の遅延条件まで順に並べられます。Ai{i,k}
は時間 ts = k - LD
での層出力 i
です。
EW
— 誤差の重み
cell 配列
誤差の重み。No
行 TS
列の cell 配列または (Ui
の合計) 行 Q
列の行列として指定します。ここで、
No = net.numOutputs
TS
はタイム ステップ数Ui = net.outputs{i}.size
Q
はバッチ サイズ
cell 配列入力の場合、各要素 EW{i,ts}
は Ui
行 Q
列の行列です。ここで、
Ui = net.outputs{i}.size
Q
はバッチ サイズ
誤差の重み EW
は、No
、TS
、Ui
、または Q
のすべてまたはいずれかの代わりに、サイズ 1 になることもあります。この場合、EW
は、ターゲット T
と一致するように次元が自動的に拡張されます。これによって、(標本単位など) 任意の次元で簡単に重要度に重みを付けることができると同時に、(TS=1
の場合の時間など) 別の次元でも同じ重要度を使用できます。すべての次元が 1 の場合 (EW = {1}
の場合など)、すべてのターゲット値が同じ重要度で扱われます。これは、EW
の既定値です。
上述したように、誤差の重み EW
はターゲット T
と同じ次元にするか、一部の次元を 1 に設定することができます。たとえば、EW
が 1 行 Q
列の場合、ターゲット標本の重要度はそれぞれ異なりますが、標本内の各要素の重要度は同じになります。EW
が (Ui
の合計) 行 1 列の場合、各出力要素の重要度は異なりますが、すべての標本が同じ重要度で扱われます。
名前と値の引数
オプションの引数のペアを Name1=Value1,...,NameN=ValueN
として指定します。ここで、Name
は引数名で、Value
は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。
R2021a より前では、コンマを使用して名前と値の各ペアを区切り、Name
を引用符で囲みます。
例: 'useParallel','yes'
useParallel
— 並列計算を指定するオプション
'no'
(既定値) | 'yes'
並列計算を指定するオプション。'yes'
または 'no'
として指定します。
'no'
– 通常の MATLAB スレッドで計算が実行されます。これは、'useParallel'
の既定の設定です。'yes'
– 並列プールが開いている場合、計算は並列ワーカーで実行されます。開いている並列プールがない場合、既定のクラスター プロファイルを使用して 1 つのプールが起動されます。
useGPU
— GPU 計算を指定するオプション
'no'
(既定値) | 'yes'
| 'only'
GPU 計算を指定するオプション。'yes'
、'no'
、または 'only'
として指定します。
'no'
– CPU で計算が実行されます。これは、'useGPU'
の既定の設定です。'yes'
– 現在のgpuDevice
がサポートされている GPU の場合 (Parallel Computing Toolbox の GPU 要件を参照)、計算はこの gpuDevice で実行されます。現在のgpuDevice
がサポートされていない場合、計算は CPU で実行されます。'useParallel'
も'yes'
であり、並列プールが開いている場合、固有の GPU を持つ各ワーカーはその GPU を使用し、他のワーカーはそれぞれの CPU コアで計算を実行します。'only'
– どの並列プールも開いていない場合、この設定は'yes'
と同じになります。並列プールが開いている場合、固有の GPU を持つワーカーのみが使用されます。ただし、並列プールが開いていても、サポートされている GPU が利用できない場合、計算は全ワーカー CPU での実行に戻ります。
showResources
— リソースを表示するオプション
'no'
(既定値) | 'yes'
リソースを表示するオプション。'yes'
または 'no'
として指定します。
'no'
– 使用されるコンピューティング リソースをコマンド ラインに表示しません。これは既定の設定です。'yes'
– 実際に使用されるコンピューティング リソースの概要をコマンド ラインに表示します。並列計算または GPU コンピューティングが要求されても、並列プールが開いていないか、サポートされている GPU が利用できない場合、実際のリソースが要求されたリソースと異なる場合があります。並列ワーカーが使用される場合、使用されていないプール内のワーカーを含む、各ワーカーの計算モードが示されます。
reduction
— メモリ削減
1 (既定値) | 正の整数
メモリ削減。正の整数として指定します。
ほとんどのニューラル ネットワークでは、既定の CPU 学習計算モードはコンパイルされた MEX アルゴリズムになります。ただし、大規模ネットワークでは、MATLAB® 計算モードを使用して計算が行われる場合もあります。これは、'showResources'
を使用して確認できます。MATLAB が使用されており、メモリが問題である場合は、削減オプションの値 N を 1 より大きい値に設定すると、学習時間が長くなる代わりに、N の係数で学習に必要な一時ストレージの量が減ります。
CheckpointFile
— チェックポイント ファイル
''
(既定値) | 文字ベクトル
チェックポイント ファイル。文字ベクトルとして指定します。
'CheckpointFile'
の値は、現在の作業フォルダーに保存するファイル名または別のフォルダーのファイル パスに設定するか、または空の string に設定してチェックポイントの保存を無効にする (既定値) ことができます。
CheckpointDelay
— チェックポイント遅延
60 (既定値) | 非負の整数
チェックポイント遅延。非負の整数として指定します。
オプションのパラメーター 'CheckpointDelay'
は、保存頻度を制限します。チェックポイントの頻度を制限することで、チェックポイントの保存にかかる時間が計算にかかる時間より短く抑えられ、効率を向上できます。既定値は 60 です。これは、チェックポイントの保存が 1 分間に 1 回以上行われないことを意味します。チェックポイントの保存が各エポックで 1 回のみ行われるようにするには、'CheckpointDelay'
の値を 0 に設定します。
出力引数
trainedNet
— 学習済みネットワーク
network
オブジェクト
学習済みネットワーク。network
オブジェクトとして返されます。
tr
— 学習記録
構造体
学習記録 (epoch
および perf
)。フィールドがネットワーク学習関数 (net.NET.trainFcn
) によって異なる構造体として返されます。含まれるフィールドには以下のものがあります。
学習、データ分割、性能の関数およびパラメーター
学習セット、検証セット、およびテスト セットのデータ分割インデックス
学習セット、検証セット、およびテスト セットのデータ分割マスク
エポックの数 (
num_epochs
) および最適なエポック (best_epoch
)学習の状態名の一覧 (
states
)学習全体を通じて値を記録する各状態名のフィールド
各エポックにおいて評価されたネットワークの最高性能: 学習セットでの最高性能 (
best_perf
)、検証セットでの最高性能 (best_vperf
)、およびテスト セットでの最高性能 (best_tperf
)
アルゴリズム
train
は、net.trainParam
が示す学習パラメーター値を使用して、net.trainFcn
が示す関数を呼び出します。
通常、1 学習エポックとは、ネットワークへの全入力ベクトルの 1 回の提供と定義されます。その後、そういった提供すべての結果に従ってネットワークが更新されます。
学習は、エポックの回数が最大数に達するか、性能目標が達成されるか、または関数 net.trainFcn
のその他の停止条件が発生するまで行われます。
一部の学習関数には、この基準とは異なり、エポックごとに 1 つの入力ベクトル (またはシーケンス) のみを提供するものがあります。各エポックで同時入力ベクトル (またはシーケンス) から入力ベクトル (またはシーケンス) がランダムに選択されます。competlayer
は、trainru
(これを行う学習関数) を使用するネットワークを返します。
バージョン履歴
R2006a より前に導入
MATLAB コマンド
次の MATLAB コマンドに対応するリンクがクリックされました。
コマンドを MATLAB コマンド ウィンドウに入力して実行してください。Web ブラウザーは MATLAB コマンドをサポートしていません。
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list:
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)