欧美一区二区三区老妇人-欧美做爰猛烈大尺度电-99久久夜色精品国产亚洲a-亚洲福利视频一区二区

dl4j如何使用遺傳神經(jīng)網(wǎng)絡(luò)完成手寫(xiě)數(shù)字識(shí)別

今天就跟大家聊聊有關(guān)dl4j如何使用遺傳神經(jīng)網(wǎng)絡(luò)完成手寫(xiě)數(shù)字識(shí)別,可能很多人都不太了解,為了讓大家更加了解,小編給大家總結(jié)了以下內(nèi)容,希望大家根據(jù)這篇文章可以有所收獲。

企業(yè)建站必須是能夠以充分展現(xiàn)企業(yè)形象為主要目的,是企業(yè)文化與產(chǎn)品對(duì)外擴(kuò)展宣傳的重要窗口,一個(gè)合格的網(wǎng)站不僅僅能為公司帶來(lái)巨大的互聯(lián)網(wǎng)上的收集和信息發(fā)布平臺(tái),創(chuàng)新互聯(lián)面向各種領(lǐng)域:成都宣傳片制作成都網(wǎng)站設(shè)計(jì)全網(wǎng)營(yíng)銷(xiāo)推廣解決方案、網(wǎng)站設(shè)計(jì)等建站排名服務(wù)。


實(shí)現(xiàn)步驟

  • 1.隨機(jī)初始化若干個(gè)智能體(神經(jīng)網(wǎng)絡(luò)),并讓智能體識(shí)別訓(xùn)練數(shù)據(jù),并對(duì)識(shí)別結(jié)果進(jìn)行排序

  • 2.隨機(jī)在排序結(jié)果中選擇一個(gè)作為母本,并在比母本識(shí)別率更高的智能體中隨機(jī)選擇一個(gè)作為父本

  • 3.隨機(jī)選擇母本或父本同位的神經(jīng)網(wǎng)絡(luò)超參組成新的智能體

  • 4.按照母本的排序?qū)χ悄荏w進(jìn)行超參調(diào)整,排序越靠后調(diào)整幅度越大(1%~10%)之間

  • 5.讓新的智能體識(shí)別訓(xùn)練集并放入排行榜,并移除排行榜最后一位

  • 6.重復(fù)2~5過(guò)程,讓識(shí)別率越來(lái)越高

這個(gè)過(guò)程就類(lèi)似于自然界的優(yōu)勝劣汰,將神經(jīng)網(wǎng)絡(luò)超參看作dna,超參的調(diào)整看作dna的突變;當(dāng)然還可以把擁有不同隱藏層的神經(jīng)網(wǎng)絡(luò)看作不同的物種,讓競(jìng)爭(zhēng)過(guò)程更加多樣化.當(dāng)然我們這里只討論一種神經(jīng)網(wǎng)絡(luò)的情況

優(yōu)勢(shì): 可以解決很多沒(méi)有頭緒的問(wèn)題 劣勢(shì): 訓(xùn)練效率極低

gitee地址:

https://gitee.com/ichiva/gnn.git

實(shí)現(xiàn)步驟 1.進(jìn)化接口

public interface Evolution {

    /**
     * 遺傳
     * @param mDna
     * @param fDna
     * @return
     */
    INDArray inheritance(INDArray mDna,INDArray fDna);

    /**
     * 突變
     * @param dna
     * @param v
     * @param r 突變范圍
     * @return
     */
    INDArray mutation(INDArray dna,double v, double r);

    /**
     * 置換
     * @param dna
     * @param v
     * @return
     */
    INDArray substitution(INDArray dna,double v);

    /**
     * 外源
     * @param dna
     * @param v
     * @return
     */
    INDArray other(INDArray dna,double v);

    /**
     * DNA 是否同源
     * @param mDna
     * @param fDna
     * @return
     */
    boolean iSogeny(INDArray mDna, INDArray fDna);
}

一個(gè)比較通用的實(shí)現(xiàn)

public class MnistEvolution implements Evolution {

    private static final MnistEvolution instance = new MnistEvolution();

    public static MnistEvolution getInstance() {
        return instance;
    }

    @Override
    public INDArray inheritance(INDArray mDna, INDArray fDna) {
        if(mDna == fDna) return mDna;

        long[] mShape = mDna.shape();
        if(!iSogeny(mDna,fDna)){
            throw new RuntimeException("非同源dna");
        }

        INDArray nDna = Nd4j.create(mShape);
        NdIndexIterator it = new NdIndexIterator(mShape);
        while (it.hasNext()){
            long[] next = it.next();

            double val;
            if(Math.random() > 0.5){
                val = fDna.getDouble(next);
            }else {
                val = mDna.getDouble(next);
            }

            nDna.putScalar(next,val);
        }

        return nDna;
    }

    @Override
    public INDArray mutation(INDArray dna, double v, double r) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            if(Math.random() < v){
                dna.putScalar(next,dna.getDouble(next) + ((Math.random() - 0.5) * r * 2));
            }else {
                nDna.putScalar(next,dna.getDouble(next));
            }
        }

        return nDna;
    }

    @Override
    public INDArray substitution(INDArray dna, double v) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            if(Math.random() > v){
                long[] tag = new long[shape.length];
                for (int i = 0; i < shape.length; i++) {
                    tag[i] = (long) (Math.random() * shape[i]);
                }

                nDna.putScalar(next,dna.getDouble(tag));
            }else {
                nDna.putScalar(next,dna.getDouble(next));
            }
        }

        return nDna;
    }

    @Override
    public INDArray other(INDArray dna, double v) {
        long[] shape = dna.shape();
        INDArray nDna = Nd4j.create(shape);
        NdIndexIterator it = new NdIndexIterator(shape);
        while (it.hasNext()) {
            long[] next = it.next();
            if(Math.random() > v){
                nDna.putScalar(next,Math.random());
            }else {
                nDna.putScalar(next,dna.getDouble(next));
            }
        }

        return nDna;
    }

    @Override
    public boolean iSogeny(INDArray mDna, INDArray fDna) {
        long[] mShape = mDna.shape();
        long[] fShape = fDna.shape();

        if (mShape.length == fShape.length) {
            for (int i = 0; i < mShape.length; i++) {
                if (mShape[i] != fShape[i]) {
                    return false;
                }
            }

            return true;
        }

        return false;
    }
}

定義智能體配置接口

public interface AgentConfig {
    /**
     * 輸入量
     * @return
     */
    int getInput();

    /**
     * 輸出量
     * @return
     */
    int getOutput();

    /**
     * 神經(jīng)網(wǎng)絡(luò)配置
     * @return
     */
    MultiLayerConfiguration getMultiLayerConfiguration();
}

按手寫(xiě)數(shù)字識(shí)別進(jìn)行配置實(shí)現(xiàn)

public class MnistConfig implements AgentConfig {
    @Override
    public int getInput() {
        return 28 * 28;
    }

    @Override
    public int getOutput() {
        return 10;
    }

    @Override
    public MultiLayerConfiguration getMultiLayerConfiguration() {
        return new NeuralNetConfiguration.Builder()
                .seed((long) (Math.random() * Long.MAX_VALUE))
                .updater(new Nesterovs(0.006, 0.9))
                .l2(1e-4)
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(getInput())
                        .nOut(1000)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer
                        .nIn(1000)
                        .nOut(getOutput())
                        .activation(Activation.SOFTMAX)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .pretrain(false).backprop(true)
                .build();
    }
}

智能體基類(lèi)

@Getter
public class Agent {

    private final AgentConfig config;
    private final INDArray dna;
    private final MultiLayerNetwork multiLayerNetwork;

    /**
     * 采用默認(rèn)方法初始化參數(shù)
     * @param config
     */
    public Agent(AgentConfig config){
        this(config,null);
    }


    /**
     *
     * @param config
     * @param dna
     */
    public Agent(AgentConfig config, INDArray dna){
        if(dna == null){
            this.config = config;
            MultiLayerConfiguration conf = config.getMultiLayerConfiguration();
            this.multiLayerNetwork = new MultiLayerNetwork(conf);
            multiLayerNetwork.init();
            this.dna = multiLayerNetwork.params();
        }else {
            this.config = config;
            MultiLayerConfiguration conf = config.getMultiLayerConfiguration();
            this.multiLayerNetwork = new MultiLayerNetwork(conf);
            multiLayerNetwork.init(dna,true);
            this.dna = dna;
        }
    }
}

手寫(xiě)數(shù)字智能體實(shí)現(xiàn)類(lèi)

@Getter
@Setter
public class MnistAgent extends Agent {

    private static final AtomicInteger index = new AtomicInteger(0);

    private String name;

    /**
     * 環(huán)境適應(yīng)分?jǐn)?shù)
     */
    private double score;

    /**
     * 驗(yàn)證分?jǐn)?shù)
     */
    private double validScore;

    public MnistAgent(AgentConfig config) {
        this(config,null);
    }

    public MnistAgent(AgentConfig config, INDArray dna) {
        super(config, dna);
        name = "agent-" + index.incrementAndGet();
    }

    public static MnistConfig mnistConfig = new MnistConfig();

    public static MnistAgent newInstance(){
        return new MnistAgent(mnistConfig);
    }

    public static MnistAgent create(INDArray dna){
        return new MnistAgent(mnistConfig,dna);
    }
}

手寫(xiě)數(shù)字識(shí)別環(huán)境構(gòu)建

@Slf4j
public class MnistEnv {
    /**
     * 環(huán)境數(shù)據(jù)
     */
    private static final ThreadLocal<MnistDataSetIterator> tLocal = ThreadLocal.withInitial(() -> {
        try {
            return new MnistDataSetIterator(128, true, 0);
        } catch (IOException e) {
            throw new RuntimeException("mnist 文件讀取失敗");
        }
    });

    private static final ThreadLocal<MnistDataSetIterator> testLocal = ThreadLocal.withInitial(() -> {
        try {
            return new MnistDataSetIterator(128, false, 0);
        } catch (IOException e) {
            throw new RuntimeException("mnist 文件讀取失敗");
        }
    });

    private static final MnistEvolution evolution = MnistEvolution.getInstance();

    /**
     * 環(huán)境承載上限
     *
     * 超過(guò)上限AI會(huì)進(jìn)行激烈競(jìng)爭(zhēng)
     */
    private final int max;
    private Double maxScore,minScore;

    /**
     * 環(huán)境中的生命體
     *
     * 新生代與歷史代共同排序,選出最適應(yīng)環(huán)境的個(gè)體
     */
    //2個(gè)變量,一個(gè)隊(duì)列保存KEY的順序,一個(gè)MAP保存KEY對(duì)應(yīng)的具體對(duì)象的數(shù)據(jù)  線程安全map
    private final TreeMap<Double,MnistAgent> lives = new TreeMap<>();

    /**
     * 初始化環(huán)境
     *
     * 1.向環(huán)境中初始化ai
     * 2.將初始化ai進(jìn)行環(huán)境適應(yīng)性測(cè)試,并排序
     * @param max
     */
    public MnistEnv(int max){
        this.max = max;

        for (int i = 0; i < max; i++) {
            MnistAgent agent = MnistAgent.newInstance();
            test(agent);
            synchronized (lives) {
                lives.put(agent.getScore(),agent);
            }
            log.info("初始化智能體 name = {} , score = {}",i,agent.getScore());
        }



        synchronized (lives) {
            minScore = lives.firstKey();
            maxScore = lives.lastKey();
        }
    }

    /**
     * 環(huán)境適應(yīng)性評(píng)估
     * @param ai
     */
    public void test(MnistAgent ai){
        MultiLayerNetwork network = ai.getMultiLayerNetwork();
        MnistDataSetIterator dataIterator = tLocal.get();
        Evaluation eval = new Evaluation(ai.getConfig().getOutput());
        try {
            while (dataIterator.hasNext()) {
                DataSet data = dataIterator.next();
                INDArray output = network.output(data.getFeatures(), false);
                eval.eval(data.getLabels(),output);
            }
        }finally {
            dataIterator.reset();
        }
        ai.setScore(eval.accuracy());
    }

    /**
     * 遷移評(píng)估
     *
     * @param ai
     */
    public void validation(MnistAgent ai){
        MultiLayerNetwork network = ai.getMultiLayerNetwork();
        MnistDataSetIterator dataIterator = testLocal.get();
        Evaluation eval = new Evaluation(ai.getConfig().getOutput());
        try {
            while (dataIterator.hasNext()) {
                DataSet data = dataIterator.next();
                INDArray output = network.output(data.getFeatures(), false);
                eval.eval(data.getLabels(),output);
            }
        }finally {
            dataIterator.reset();
        }
        ai.setValidScore(eval.accuracy());
    }


    /**
     * 進(jìn)化
     *
     * 每輪隨機(jī)創(chuàng)建ai并放入環(huán)境中進(jìn)行優(yōu)勝劣汰
     * @param n 進(jìn)化次數(shù)
     */
    public void evolution(int n){
        BlockThreadPool blockThreadPool=new BlockThreadPool(2);
        for (int i = 0; i < n; i++) {
            blockThreadPool.execute(() -> contend(newLive()));
        }

//        for (int i = 0; i < n; i++) {
//            contend(newLive());
//        }
    }

    /**
     * 競(jìng)爭(zhēng)
     * @param ai
     */
    public void contend(MnistAgent ai){
        test(ai);
        quality(ai);
        double score = ai.getScore();
        if(score <= minScore){
            UI.put("無(wú)法生存",String.format("name = %s,  score = %s", ai.getName(),ai.getScore()));
            return;
        }

        Map.Entry<Double, MnistAgent> lastEntry;
        synchronized (lives) {
            lives.put(score,ai);
            if (lives.size() > max) {
                MnistAgent lastAI = lives.remove(lives.firstKey());
                UI.put("淘 汰 ",String.format("name = %s, score = %s", lastAI.getName(),lastAI.getScore()));
            }
            lastEntry = lives.lastEntry();
            minScore = lives.firstKey();
        }

        Double lastScore = lastEntry.getKey();
        if(lastScore > maxScore){
            maxScore = lastScore;
            MnistAgent agent = lastEntry.getValue();
            validation(agent);
            UI.put("max驗(yàn)證",String.format("score = %s,validScore = %s",lastScore,agent.getValidScore()));

            try {
                Warehouse.write(agent);
            } catch (IOException ex) {
                log.error("保存對(duì)象失敗",ex);
            }
        }
    }

    ArrayList<Double> scoreList = new ArrayList<>(100);
    ArrayList<Integer> avgList = new ArrayList<>();
    private void quality(MnistAgent ai) {
        synchronized (scoreList) {
            scoreList.add(ai.getScore());
            if (scoreList.size() >= 100) {
                double avg = scoreList.stream().mapToDouble(e -> e)
                        .average().getAsDouble();
                avgList.add((int) (avg * 1000));
                StringBuffer buffer = new StringBuffer();
                avgList.forEach(e -> buffer.append(e).append('\t'));
                UI.put("平均得分",String.format("aix100 avg = %s",buffer.toString()));
                scoreList.clear();
            }
        }
    }

    /**
     * 隨機(jī)生成新智能體
     *
     * 完全隨機(jī)產(chǎn)生母本
     * 隨機(jī)從比目標(biāo)相同或更高評(píng)分中選擇父本
     *
     * 基因進(jìn)化在1%~10%之間進(jìn)行,評(píng)分越高基于越穩(wěn)定
     */
    public MnistAgent newLive(){
        double r = Math.random();
        //基因突變率
        double v = r / 11 + 0.01;
        //母本
        MnistAgent mAgent = getMother(r);
        //父本
        MnistAgent fAgent = getFather(r);

        int i = (int) (Math.random() * 3);
        INDArray newDNA = evolution.inheritance(mAgent.getDna(), fAgent.getDna());
        switch (i){
            case 0:
                newDNA = evolution.other(newDNA,v);
                break;
            case 1:
                newDNA = evolution.mutation(newDNA,v,0.1);
                break;
            case 2:
                newDNA = evolution.substitution(newDNA,v);
                break;
        }
        return MnistAgent.create(newDNA);
    }

    /**
     * 父本只選擇比母本評(píng)分高的樣本
     * @param r
     * @return
     */
    private MnistAgent getFather(double r) {
        r += (Math.random() * (1-r));
        return getMother(r);
    }

    private MnistAgent getMother(double r) {
        int index = (int) (r * max);
        return getMnistAgent(index);
    }

    private MnistAgent getMnistAgent(int index) {
        synchronized (lives) {
            Iterator<Map.Entry<Double, MnistAgent>> it = lives.entrySet().iterator();
            for (int i = 0; i < index; i++) {
                it.next();
            }

            return it.next().getValue();
        }
    }
}

主函數(shù)

@Slf4j
public class Program {

    public static void main(String[] args) {
        UI.put("開(kāi)始時(shí)間",new Date().toLocaleString());
        MnistEnv env = new MnistEnv(128);
        env.evolution(Integer.MAX_VALUE);
    }
}

運(yùn)行截圖

  dl4j如何使用遺傳神經(jīng)網(wǎng)絡(luò)完成手寫(xiě)數(shù)字識(shí)別

看完上述內(nèi)容,你們對(duì)dl4j如何使用遺傳神經(jīng)網(wǎng)絡(luò)完成手寫(xiě)數(shù)字識(shí)別有進(jìn)一步的了解嗎?如果還想了解更多知識(shí)或者相關(guān)內(nèi)容,請(qǐng)關(guān)注創(chuàng)新互聯(lián)行業(yè)資訊頻道,感謝大家的支持。

標(biāo)題名稱(chēng):dl4j如何使用遺傳神經(jīng)網(wǎng)絡(luò)完成手寫(xiě)數(shù)字識(shí)別
分享網(wǎng)址:http://chinadenli.net/article8/jpsoip.html

成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供營(yíng)銷(xiāo)型網(wǎng)站建設(shè)軟件開(kāi)發(fā)GoogleChatGPT品牌網(wǎng)站制作

廣告

聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶(hù)投稿、用戶(hù)轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)

成都網(wǎng)站建設(shè)公司