前言:

Keras是一個(gè)由Python編寫的開源人工神經(jīng)網(wǎng)絡(luò)庫,可以作為Tensorflow、Microsoft-CNTK和Theano的高階應(yīng)用程序接口,進(jìn)行深度學(xué)習(xí)模型的設(shè)計(jì)、調(diào)試、評(píng)估、應(yīng)用和可視化。
Theano于2008年誕生于蒙特利爾理工學(xué)院,其派生出了大量的深度學(xué)習(xí)Python軟件包,最著名的包括Blocks和Keras。Theano的核心是一個(gè)數(shù)學(xué)表達(dá)式的編譯器,它知道如何獲取你的結(jié)構(gòu),并使之成為一個(gè)使用numpy、高效本地庫的高效代碼,如BLAS和本地代碼(C++)在CPU或GPU上盡可能快地運(yùn)行。它是為深度學(xué)習(xí)中處理大型神經(jīng)網(wǎng)絡(luò)算法所需的計(jì)算而專門設(shè)計(jì),是這類庫的首創(chuàng)之一(發(fā)展始于2007年),被認(rèn)為是深度學(xué)習(xí)研究和開發(fā)的行業(yè)標(biāo)準(zhǔn)。
TensorFlow是一個(gè)基于數(shù)據(jù)流編程(dataflow programming)的符號(hào)數(shù)學(xué)系統(tǒng),被廣泛應(yīng)用于各類機(jī)器學(xué)習(xí)(machine learning)算法的編程實(shí)現(xiàn),其前身是谷歌的神經(jīng)網(wǎng)絡(luò)算法庫DistBelief。
Tensorflow擁有多層級(jí)結(jié)構(gòu),可部署于各類服務(wù)器、PC終端和網(wǎng)頁并支持GPU和TPU高性能數(shù)值計(jì)算,被廣泛應(yīng)用于谷歌內(nèi)部的產(chǎn)品開發(fā)和各領(lǐng)域的科學(xué)研究。
代碼:
</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">
# coding:utf-8
"""
If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D),
be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe),
and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano).
"""
from keras import backend as K
from keras.utils.np_utils import convert_kernel
from text_classifier import keras_text_classifier
import sys
def th3tf( model):
import tensorflow as tf
ops = []
for layer in model.layers:
if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
original_w = K.get_value(layer.W)
converted_w = convert_kernel(original_w)
ops.append(tf.assign(layer.W, converted_w).op)
K.get_session().run(ops)
return model
def tf2th(model):
for layer in model.layers:
if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
original_w = K.get_value(layer.W)
converted_w = convert_kernel(original_w)
K.set_value(layer.W, converted_w)
return model
def conv_layer_converted(tf_weights, th_weights, m = 0):
"""
:param tf_weights:
:param th_weights:
:param m: 0-tf2th, 1-th3tf
:return:
"""
if m == 0: # tf2th
tc = keras_text_classifier(weights_path=tf_weights)
model = tc.loadmodel()
model = tf2th(model)
model.save_weights(th_weights)
elif m == 1: # th3tf
tc = keras_text_classifier(weights_path=th_weights)
model = tc.loadmodel()
model = th3tf(model)
model.save_weights(tf_weights)
else:
print("0-tf2th, 1-th3tf")
return
if __name__ == '__main__':
if len(sys.argv) < 4:
print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow")
sys.exit(0)
tf_weights = sys.argv[1]
th_weights = sys.argv[2]
m = int(sys.argv[3])
conv_layer_converted(tf_weights, th_weights, m)
名稱欄目:keras實(shí)現(xiàn)tensorflow與theano相互轉(zhuǎn)換的方法-創(chuàng)新互聯(lián)
網(wǎng)址分享:http://chinadenli.net/article14/hhoge.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供響應(yīng)式網(wǎng)站、用戶體驗(yàn)、商城網(wǎng)站、面包屑導(dǎo)航、App設(shè)計(jì)、網(wǎng)站建設(shè)
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場,如需處理請聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來源: 創(chuàng)新互聯(lián)