Apprenez le paramétrage d'une variété
Apprenez à paramétrer une variété le long de laquelle se trouvent les données d'entrée en utilisant un auto-encodeur, un réseau doté d'une couche de « goulot d'étranglement » qui apprend à reconstruire l'entrée d'origine.
Échantillonnez les données d'apprentissage d'une partie d'une variété bidimensionnelle synthétique.

manifold =
Table[AngleVector[{x, 0.9 Pi x}] +
x/20*RandomVariate[NormalDistribution[], 2], {x, 0, 1, 0.001}];
plot = ListPlot[manifold, PlotStyle -> Orange]

Créez un réseau avec une couche de « goulot d'étranglement » dans le but d'apprendre le paramétrage de la variété.

net = NetChain[{25, Ramp, 1, 25, Ramp, 2}, "Input" -> 2]

Créez un réseau de perte qui calcule une perte sur la base d'une « erreur de reconstruction », une mesure du degré auquel le réseau peut produire une sortie qui est identique à son entrée.

lossNet =
NetGraph[{net, MeanSquaredLossLayer[]}, {1 -> 2,
NetPort["Input"] -> NetPort[2, "Target"]}]

Entraînez le réseau de perte sur la variété et extrayez le réseau d'origine à partir du réseau de perte.

lossNet =
NetTrain[lossNet, <|"Input" -> manifold|>, BatchSize -> 4096];
trained = NetExtract[lossNet, 1];
Visualisez la façon dont le réseau projette des points arbitraires sur la variété.

{{xmin, xmax}, {ymin, ymax}} = CoordinateBounds[manifold, .2];
Show[plot,
StreamPlot[
trained[{x, y}] - {x, y}, {x, xmin, xmax}, {y, ymin, ymax}]]

Divisez le réseau en un réseau « encodeur » et un réseau « décodeur » (l'encodeur paramétrise les points en utilisant une valeur scalaire unique, alors que le décodeur reconstruit le point à partir de ce paramétrage).

decoder = Drop[trained, 3]
encoder = Take[trained, 3]


Coloriez chaque point dans la variété d'origine par son paramétrage sous l'encodeur.

ListPlot[Style[#, Hue[First[0.3 + encoder[#]]/3]] & /@ manifold]

Obtenez la plage de paramétrage en appliquant l'encodeur sur la variété.

{min, max} = MinMax[encoder[manifold]]

Affichez la reconstruction sur cette plage avec la variété d'origine.

Show[plot, ListLinePlot[Table[decoder[x], {x, min, max, .01}]]]
