今天就跟大家聊聊有關(guān)dl4j如何使用遺傳神經(jīng)網(wǎng)絡(luò)完成手寫(xiě)數(shù)字識(shí)別,可能很多人都不太了解,為了讓大家更加了解,小編給大家總結(jié)了以下內(nèi)容,希望大家根據(jù)這篇文章可以有所收獲。
企業(yè)建站必須是能夠以充分展現(xiàn)企業(yè)形象為主要目的,是企業(yè)文化與產(chǎn)品對(duì)外擴(kuò)展宣傳的重要窗口,一個(gè)合格的網(wǎng)站不僅僅能為公司帶來(lái)巨大的互聯(lián)網(wǎng)上的收集和信息發(fā)布平臺(tái),創(chuàng)新互聯(lián)面向各種領(lǐng)域:成都宣傳片制作等成都網(wǎng)站設(shè)計(jì)、全網(wǎng)營(yíng)銷(xiāo)推廣解決方案、網(wǎng)站設(shè)計(jì)等建站排名服務(wù)。
實(shí)現(xiàn)步驟
1.隨機(jī)初始化若干個(gè)智能體(神經(jīng)網(wǎng)絡(luò)),并讓智能體識(shí)別訓(xùn)練數(shù)據(jù),并對(duì)識(shí)別結(jié)果進(jìn)行排序
2.隨機(jī)在排序結(jié)果中選擇一個(gè)作為母本,并在比母本識(shí)別率更高的智能體中隨機(jī)選擇一個(gè)作為父本
3.隨機(jī)選擇母本或父本同位的神經(jīng)網(wǎng)絡(luò)超參組成新的智能體
4.按照母本的排序?qū)χ悄荏w進(jìn)行超參調(diào)整,排序越靠后調(diào)整幅度越大(1%~10%)之間
5.讓新的智能體識(shí)別訓(xùn)練集并放入排行榜,并移除排行榜最后一位
6.重復(fù)2~5過(guò)程,讓識(shí)別率越來(lái)越高
這個(gè)過(guò)程就類(lèi)似于自然界的優(yōu)勝劣汰,將神經(jīng)網(wǎng)絡(luò)超參看作dna,超參的調(diào)整看作dna的突變;當(dāng)然還可以把擁有不同隱藏層的神經(jīng)網(wǎng)絡(luò)看作不同的物種,讓競(jìng)爭(zhēng)過(guò)程更加多樣化.當(dāng)然我們這里只討論一種神經(jīng)網(wǎng)絡(luò)的情況
優(yōu)勢(shì): 可以解決很多沒(méi)有頭緒的問(wèn)題 劣勢(shì): 訓(xùn)練效率極低
gitee地址:
https://gitee.com/ichiva/gnn.git
實(shí)現(xiàn)步驟 1.進(jìn)化接口
public interface Evolution {
/**
* 遺傳
* @param mDna
* @param fDna
* @return
*/
INDArray inheritance(INDArray mDna,INDArray fDna);
/**
* 突變
* @param dna
* @param v
* @param r 突變范圍
* @return
*/
INDArray mutation(INDArray dna,double v, double r);
/**
* 置換
* @param dna
* @param v
* @return
*/
INDArray substitution(INDArray dna,double v);
/**
* 外源
* @param dna
* @param v
* @return
*/
INDArray other(INDArray dna,double v);
/**
* DNA 是否同源
* @param mDna
* @param fDna
* @return
*/
boolean iSogeny(INDArray mDna, INDArray fDna);
}一個(gè)比較通用的實(shí)現(xiàn)
public class MnistEvolution implements Evolution {
private static final MnistEvolution instance = new MnistEvolution();
public static MnistEvolution getInstance() {
return instance;
}
@Override
public INDArray inheritance(INDArray mDna, INDArray fDna) {
if(mDna == fDna) return mDna;
long[] mShape = mDna.shape();
if(!iSogeny(mDna,fDna)){
throw new RuntimeException("非同源dna");
}
INDArray nDna = Nd4j.create(mShape);
NdIndexIterator it = new NdIndexIterator(mShape);
while (it.hasNext()){
long[] next = it.next();
double val;
if(Math.random() > 0.5){
val = fDna.getDouble(next);
}else {
val = mDna.getDouble(next);
}
nDna.putScalar(next,val);
}
return nDna;
}
@Override
public INDArray mutation(INDArray dna, double v, double r) {
long[] shape = dna.shape();
INDArray nDna = Nd4j.create(shape);
NdIndexIterator it = new NdIndexIterator(shape);
while (it.hasNext()) {
long[] next = it.next();
if(Math.random() < v){
dna.putScalar(next,dna.getDouble(next) + ((Math.random() - 0.5) * r * 2));
}else {
nDna.putScalar(next,dna.getDouble(next));
}
}
return nDna;
}
@Override
public INDArray substitution(INDArray dna, double v) {
long[] shape = dna.shape();
INDArray nDna = Nd4j.create(shape);
NdIndexIterator it = new NdIndexIterator(shape);
while (it.hasNext()) {
long[] next = it.next();
if(Math.random() > v){
long[] tag = new long[shape.length];
for (int i = 0; i < shape.length; i++) {
tag[i] = (long) (Math.random() * shape[i]);
}
nDna.putScalar(next,dna.getDouble(tag));
}else {
nDna.putScalar(next,dna.getDouble(next));
}
}
return nDna;
}
@Override
public INDArray other(INDArray dna, double v) {
long[] shape = dna.shape();
INDArray nDna = Nd4j.create(shape);
NdIndexIterator it = new NdIndexIterator(shape);
while (it.hasNext()) {
long[] next = it.next();
if(Math.random() > v){
nDna.putScalar(next,Math.random());
}else {
nDna.putScalar(next,dna.getDouble(next));
}
}
return nDna;
}
@Override
public boolean iSogeny(INDArray mDna, INDArray fDna) {
long[] mShape = mDna.shape();
long[] fShape = fDna.shape();
if (mShape.length == fShape.length) {
for (int i = 0; i < mShape.length; i++) {
if (mShape[i] != fShape[i]) {
return false;
}
}
return true;
}
return false;
}
}定義智能體配置接口
public interface AgentConfig {
/**
* 輸入量
* @return
*/
int getInput();
/**
* 輸出量
* @return
*/
int getOutput();
/**
* 神經(jīng)網(wǎng)絡(luò)配置
* @return
*/
MultiLayerConfiguration getMultiLayerConfiguration();
}按手寫(xiě)數(shù)字識(shí)別進(jìn)行配置實(shí)現(xiàn)
public class MnistConfig implements AgentConfig {
@Override
public int getInput() {
return 28 * 28;
}
@Override
public int getOutput() {
return 10;
}
@Override
public MultiLayerConfiguration getMultiLayerConfiguration() {
return new NeuralNetConfiguration.Builder()
.seed((long) (Math.random() * Long.MAX_VALUE))
.updater(new Nesterovs(0.006, 0.9))
.l2(1e-4)
.list()
.layer(0, new DenseLayer.Builder()
.nIn(getInput())
.nOut(1000)
.activation(Activation.RELU)
.weightInit(WeightInit.XAVIER)
.build())
.layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer
.nIn(1000)
.nOut(getOutput())
.activation(Activation.SOFTMAX)
.weightInit(WeightInit.XAVIER)
.build())
.pretrain(false).backprop(true)
.build();
}
}智能體基類(lèi)
@Getter
public class Agent {
private final AgentConfig config;
private final INDArray dna;
private final MultiLayerNetwork multiLayerNetwork;
/**
* 采用默認(rèn)方法初始化參數(shù)
* @param config
*/
public Agent(AgentConfig config){
this(config,null);
}
/**
*
* @param config
* @param dna
*/
public Agent(AgentConfig config, INDArray dna){
if(dna == null){
this.config = config;
MultiLayerConfiguration conf = config.getMultiLayerConfiguration();
this.multiLayerNetwork = new MultiLayerNetwork(conf);
multiLayerNetwork.init();
this.dna = multiLayerNetwork.params();
}else {
this.config = config;
MultiLayerConfiguration conf = config.getMultiLayerConfiguration();
this.multiLayerNetwork = new MultiLayerNetwork(conf);
multiLayerNetwork.init(dna,true);
this.dna = dna;
}
}
}手寫(xiě)數(shù)字智能體實(shí)現(xiàn)類(lèi)
@Getter
@Setter
public class MnistAgent extends Agent {
private static final AtomicInteger index = new AtomicInteger(0);
private String name;
/**
* 環(huán)境適應(yīng)分?jǐn)?shù)
*/
private double score;
/**
* 驗(yàn)證分?jǐn)?shù)
*/
private double validScore;
public MnistAgent(AgentConfig config) {
this(config,null);
}
public MnistAgent(AgentConfig config, INDArray dna) {
super(config, dna);
name = "agent-" + index.incrementAndGet();
}
public static MnistConfig mnistConfig = new MnistConfig();
public static MnistAgent newInstance(){
return new MnistAgent(mnistConfig);
}
public static MnistAgent create(INDArray dna){
return new MnistAgent(mnistConfig,dna);
}
}手寫(xiě)數(shù)字識(shí)別環(huán)境構(gòu)建
@Slf4j
public class MnistEnv {
/**
* 環(huán)境數(shù)據(jù)
*/
private static final ThreadLocal<MnistDataSetIterator> tLocal = ThreadLocal.withInitial(() -> {
try {
return new MnistDataSetIterator(128, true, 0);
} catch (IOException e) {
throw new RuntimeException("mnist 文件讀取失敗");
}
});
private static final ThreadLocal<MnistDataSetIterator> testLocal = ThreadLocal.withInitial(() -> {
try {
return new MnistDataSetIterator(128, false, 0);
} catch (IOException e) {
throw new RuntimeException("mnist 文件讀取失敗");
}
});
private static final MnistEvolution evolution = MnistEvolution.getInstance();
/**
* 環(huán)境承載上限
*
* 超過(guò)上限AI會(huì)進(jìn)行激烈競(jìng)爭(zhēng)
*/
private final int max;
private Double maxScore,minScore;
/**
* 環(huán)境中的生命體
*
* 新生代與歷史代共同排序,選出最適應(yīng)環(huán)境的個(gè)體
*/
//2個(gè)變量,一個(gè)隊(duì)列保存KEY的順序,一個(gè)MAP保存KEY對(duì)應(yīng)的具體對(duì)象的數(shù)據(jù) 線程安全map
private final TreeMap<Double,MnistAgent> lives = new TreeMap<>();
/**
* 初始化環(huán)境
*
* 1.向環(huán)境中初始化ai
* 2.將初始化ai進(jìn)行環(huán)境適應(yīng)性測(cè)試,并排序
* @param max
*/
public MnistEnv(int max){
this.max = max;
for (int i = 0; i < max; i++) {
MnistAgent agent = MnistAgent.newInstance();
test(agent);
synchronized (lives) {
lives.put(agent.getScore(),agent);
}
log.info("初始化智能體 name = {} , score = {}",i,agent.getScore());
}
synchronized (lives) {
minScore = lives.firstKey();
maxScore = lives.lastKey();
}
}
/**
* 環(huán)境適應(yīng)性評(píng)估
* @param ai
*/
public void test(MnistAgent ai){
MultiLayerNetwork network = ai.getMultiLayerNetwork();
MnistDataSetIterator dataIterator = tLocal.get();
Evaluation eval = new Evaluation(ai.getConfig().getOutput());
try {
while (dataIterator.hasNext()) {
DataSet data = dataIterator.next();
INDArray output = network.output(data.getFeatures(), false);
eval.eval(data.getLabels(),output);
}
}finally {
dataIterator.reset();
}
ai.setScore(eval.accuracy());
}
/**
* 遷移評(píng)估
*
* @param ai
*/
public void validation(MnistAgent ai){
MultiLayerNetwork network = ai.getMultiLayerNetwork();
MnistDataSetIterator dataIterator = testLocal.get();
Evaluation eval = new Evaluation(ai.getConfig().getOutput());
try {
while (dataIterator.hasNext()) {
DataSet data = dataIterator.next();
INDArray output = network.output(data.getFeatures(), false);
eval.eval(data.getLabels(),output);
}
}finally {
dataIterator.reset();
}
ai.setValidScore(eval.accuracy());
}
/**
* 進(jìn)化
*
* 每輪隨機(jī)創(chuàng)建ai并放入環(huán)境中進(jìn)行優(yōu)勝劣汰
* @param n 進(jìn)化次數(shù)
*/
public void evolution(int n){
BlockThreadPool blockThreadPool=new BlockThreadPool(2);
for (int i = 0; i < n; i++) {
blockThreadPool.execute(() -> contend(newLive()));
}
// for (int i = 0; i < n; i++) {
// contend(newLive());
// }
}
/**
* 競(jìng)爭(zhēng)
* @param ai
*/
public void contend(MnistAgent ai){
test(ai);
quality(ai);
double score = ai.getScore();
if(score <= minScore){
UI.put("無(wú)法生存",String.format("name = %s, score = %s", ai.getName(),ai.getScore()));
return;
}
Map.Entry<Double, MnistAgent> lastEntry;
synchronized (lives) {
lives.put(score,ai);
if (lives.size() > max) {
MnistAgent lastAI = lives.remove(lives.firstKey());
UI.put("淘 汰 ",String.format("name = %s, score = %s", lastAI.getName(),lastAI.getScore()));
}
lastEntry = lives.lastEntry();
minScore = lives.firstKey();
}
Double lastScore = lastEntry.getKey();
if(lastScore > maxScore){
maxScore = lastScore;
MnistAgent agent = lastEntry.getValue();
validation(agent);
UI.put("max驗(yàn)證",String.format("score = %s,validScore = %s",lastScore,agent.getValidScore()));
try {
Warehouse.write(agent);
} catch (IOException ex) {
log.error("保存對(duì)象失敗",ex);
}
}
}
ArrayList<Double> scoreList = new ArrayList<>(100);
ArrayList<Integer> avgList = new ArrayList<>();
private void quality(MnistAgent ai) {
synchronized (scoreList) {
scoreList.add(ai.getScore());
if (scoreList.size() >= 100) {
double avg = scoreList.stream().mapToDouble(e -> e)
.average().getAsDouble();
avgList.add((int) (avg * 1000));
StringBuffer buffer = new StringBuffer();
avgList.forEach(e -> buffer.append(e).append('\t'));
UI.put("平均得分",String.format("aix100 avg = %s",buffer.toString()));
scoreList.clear();
}
}
}
/**
* 隨機(jī)生成新智能體
*
* 完全隨機(jī)產(chǎn)生母本
* 隨機(jī)從比目標(biāo)相同或更高評(píng)分中選擇父本
*
* 基因進(jìn)化在1%~10%之間進(jìn)行,評(píng)分越高基于越穩(wěn)定
*/
public MnistAgent newLive(){
double r = Math.random();
//基因突變率
double v = r / 11 + 0.01;
//母本
MnistAgent mAgent = getMother(r);
//父本
MnistAgent fAgent = getFather(r);
int i = (int) (Math.random() * 3);
INDArray newDNA = evolution.inheritance(mAgent.getDna(), fAgent.getDna());
switch (i){
case 0:
newDNA = evolution.other(newDNA,v);
break;
case 1:
newDNA = evolution.mutation(newDNA,v,0.1);
break;
case 2:
newDNA = evolution.substitution(newDNA,v);
break;
}
return MnistAgent.create(newDNA);
}
/**
* 父本只選擇比母本評(píng)分高的樣本
* @param r
* @return
*/
private MnistAgent getFather(double r) {
r += (Math.random() * (1-r));
return getMother(r);
}
private MnistAgent getMother(double r) {
int index = (int) (r * max);
return getMnistAgent(index);
}
private MnistAgent getMnistAgent(int index) {
synchronized (lives) {
Iterator<Map.Entry<Double, MnistAgent>> it = lives.entrySet().iterator();
for (int i = 0; i < index; i++) {
it.next();
}
return it.next().getValue();
}
}
}主函數(shù)
@Slf4j
public class Program {
public static void main(String[] args) {
UI.put("開(kāi)始時(shí)間",new Date().toLocaleString());
MnistEnv env = new MnistEnv(128);
env.evolution(Integer.MAX_VALUE);
}
}運(yùn)行截圖

看完上述內(nèi)容,你們對(duì)dl4j如何使用遺傳神經(jīng)網(wǎng)絡(luò)完成手寫(xiě)數(shù)字識(shí)別有進(jìn)一步的了解嗎?如果還想了解更多知識(shí)或者相關(guān)內(nèi)容,請(qǐng)關(guān)注創(chuàng)新互聯(lián)行業(yè)資訊頻道,感謝大家的支持。
標(biāo)題名稱(chēng):dl4j如何使用遺傳神經(jīng)網(wǎng)絡(luò)完成手寫(xiě)數(shù)字識(shí)別
分享網(wǎng)址:http://chinadenli.net/article8/jpsoip.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供營(yíng)銷(xiāo)型網(wǎng)站建設(shè)、軟件開(kāi)發(fā)、Google、、ChatGPT、品牌網(wǎng)站制作
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶(hù)投稿、用戶(hù)轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)