這篇文章主要為大家展示了Pytorch如何轉(zhuǎn)tflite,內(nèi)容簡(jiǎn)而易懂,希望大家可以學(xué)習(xí)一下,學(xué)習(xí)完之后肯定會(huì)有收獲的,下面讓小編帶大家一起來(lái)看看吧。
目標(biāo)是想把在服務(wù)器上用pytorch訓(xùn)練好的模型轉(zhuǎn)換為可以在移動(dòng)端運(yùn)行的tflite模型。
最直接的思路是想把pytorch模型轉(zhuǎn)換為tensorflow的模型,然后轉(zhuǎn)換為tflite。但是這個(gè)轉(zhuǎn)換目前沒(méi)有發(fā)現(xiàn)比較靠譜的方法。
經(jīng)過(guò)調(diào)研發(fā)現(xiàn)最新的tflite已經(jīng)支持直接從keras模型的轉(zhuǎn)換,所以可以采用keras作為中間轉(zhuǎn)換的橋梁,這樣就能充分利用keras高層API的便利性。
轉(zhuǎn)換的基本思想就是用pytorch中的各層網(wǎng)絡(luò)的權(quán)重取出來(lái)后直接賦值給keras網(wǎng)絡(luò)中的對(duì)應(yīng)layer層的權(quán)重。
轉(zhuǎn)換為Keras模型后,再通過(guò)tf.contrib.lite.TocoConverter把模型直接轉(zhuǎn)為tflite.
下面是一個(gè)例子,假設(shè)轉(zhuǎn)換的是一個(gè)兩層的CNN網(wǎng)絡(luò)。
import tensorflow as tf from tensorflow import keras import numpy as np import torch from torchvision import models import torch.nn as nn # import torch.nn.functional as F from torch.autograd import Variable class PytorchNet(nn.Module): def __init__(self): super(PytorchNet, self).__init__() conv1 = nn.Sequential( nn.Conv2d(3, 32, 3, 2), nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)) conv2 = nn.Sequential( nn.Conv2d(32, 64, 3, 1, groups=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)) self.feature = nn.Sequential(conv1, conv2) self.init_weights() def forward(self, x): return self.feature(x) def init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_( m.weight.data, mode='fan_out', nonlinearity='relu') if m.bias is not None: m.bias.data.zero_() if isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() def KerasNet(input_shape=(224, 224, 3)): image_input = keras.layers.Input(shape=input_shape) # conv1 network = keras.layers.Conv2D( 32, (3, 3), strides=(2, 2), padding="valid")(image_input) network = keras.layers.BatchNormalization( trainable=False, fused=False)(network) network = keras.layers.Activation("relu")(network) network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network) # conv2 network = keras.layers.Conv2D( 64, (3, 3), strides=(1, 1), padding="valid")(network) network = keras.layers.BatchNormalization( trainable=False, fused=True)(network) network = keras.layers.Activation("relu")(network) network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network) model = keras.Model(inputs=image_input, outputs=network) return model class PytorchToKeras(object): def __init__(self, pModel, kModel): super(PytorchToKeras, self) self.__source_layers = [] self.__target_layers = [] self.pModel = pModel self.kModel = kModel tf.keras.backend.set_learning_phase(0) def __retrieve_k_layers(self): for i, layer in enumerate(self.kModel.layers): if len(layer.weights) > 0: self.__target_layers.append(i) def __retrieve_p_layers(self, input_size): input = torch.randn(input_size) input = Variable(input.unsqueeze(0)) hooks = [] def add_hooks(module): def hook(module, input, output): if hasattr(module, "weight"): # print(module) self.__source_layers.append(module) if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential) and module != self.pModel: hooks.append(module.register_forward_hook(hook)) self.pModel.apply(add_hooks) self.pModel(input) for hook in hooks: hook.remove() def convert(self, input_size): self.__retrieve_k_layers() self.__retrieve_p_layers(input_size) for i, (source_layer, target_layer) in enumerate(zip(self.__source_layers, self.__target_layers)): print(source_layer) weight_size = len(source_layer.weight.data.size()) transpose_dims = [] for i in range(weight_size): transpose_dims.append(weight_size - i - 1) if isinstance(source_layer, nn.Conv2d): transpose_dims = [2,3,1,0] self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy( ).transpose(transpose_dims), source_layer.bias.data.numpy()]) elif isinstance(source_layer, nn.BatchNorm2d): self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(), source_layer.bias.data.numpy(), source_layer.running_mean.data.numpy(), source_layer.running_var.data.numpy()]) def save_model(self, output_file): self.kModel.save(output_file) def save_weights(self, output_file): self.kModel.save_weights(output_file, save_format='h6') pytorch_model = PytorchNet() keras_model = KerasNet(input_shape=(224, 224, 3)) torch.save(pytorch_model, 'test.pth') #Load the pretrained model pytorch_model = torch.load('test.pth') # #Time to transfer weights converter = PytorchToKeras(pytorch_model, keras_model) converter.convert((3, 224, 224)) # #Save the converted keras model for later use # converter.save_weights("keras.h6") converter.save_model("keras_model.h6") # convert keras model to tflite model converter = tf.contrib.lite.TocoConverter.from_keras_model_file( "keras_model.h6") tflite_model = converter.convert() open("convert_model.tflite", "wb").write(tflite_model)
網(wǎng)站標(biāo)題:Pytorch如何轉(zhuǎn)tflite-創(chuàng)新互聯(lián)
網(wǎng)站鏈接:http://chinadenli.net/article8/dcjgip.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供營(yíng)銷(xiāo)型網(wǎng)站建設(shè)、自適應(yīng)網(wǎng)站、定制網(wǎng)站、ChatGPT、軟件開(kāi)發(fā)、動(dòng)態(tài)網(wǎng)站
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶(hù)投稿、用戶(hù)轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話(huà):028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容