這篇文章給大家介紹Pytorch轉(zhuǎn)變Caffe再轉(zhuǎn)變om模型轉(zhuǎn)換流程是怎樣的,內(nèi)容非常詳細(xì),感興趣的小伙伴們可以參考借鑒,希望對大家能有所幫助。

公司主營業(yè)務(wù):成都網(wǎng)站設(shè)計(jì)、網(wǎng)站制作、外貿(mào)營銷網(wǎng)站建設(shè)、移動網(wǎng)站開發(fā)等業(yè)務(wù)。幫助企業(yè)客戶真正實(shí)現(xiàn)互聯(lián)網(wǎng)宣傳,提高企業(yè)的競爭能力。成都創(chuàng)新互聯(lián)公司是一支青春激揚(yáng)、勤奮敬業(yè)、活力青春激揚(yáng)、勤奮敬業(yè)、活力澎湃、和諧高效的團(tuán)隊(duì)。公司秉承以“開放、自由、嚴(yán)謹(jǐn)、自律”為核心的企業(yè)文化,感謝他們對我們的高要求,感謝他們從不同領(lǐng)域給我們帶來的挑戰(zhàn),讓我們激情的團(tuán)隊(duì)有機(jī)會用頭腦與智慧不斷的給客戶帶來驚喜。成都創(chuàng)新互聯(lián)公司推出貞豐免費(fèi)做網(wǎng)站回饋大家。
Baseline:PytorchToCaffe
主要功能代碼在:
PytorchToCaffe +-- Caffe | +-- caffe.proto | +-- layer_param.py +-- example | +-- resnet_pytorch_2_caffe.py +-- pytorch_to_caffe.py
直接使用可以參考resnet_pytorch_2_caffe.py,如果網(wǎng)絡(luò)中的操作Baseline中都已經(jīng)實(shí)現(xiàn),則可以直接轉(zhuǎn)換到Caffe模型。
如果遇到?jīng)]有實(shí)現(xiàn)的操作,則要分為兩種情況來考慮。
以arg_max為例分享一下添加操作的方式。
首先要查看Caffe中對應(yīng)層的參數(shù):caffe.proto為對應(yīng)版本caffe層與參數(shù)的定義,可以看到ArgMax定義了out_max_val、top_k、axis三個參數(shù):
message ArgMaxParameter {
// If true produce pairs (argmax, maxval)
optional bool out_max_val = 1 [default = false];
optional uint32 top_k = 2 [default = 1];
// The axis along which to maximise -- may be negative to index from the
// end (e.g., -1 for the last axis).
// By default ArgMaxLayer maximizes over the flattened trailing dimensions
// for each index of the first / num dimension.
optional int32 axis = 3;
}與Caffe算子邊界中的參數(shù)是一致的。
layer_param.py構(gòu)建了具體轉(zhuǎn)換時(shí)參數(shù)類的實(shí)例,實(shí)現(xiàn)了操作參數(shù)從Pytorch到Caffe的傳遞:
def argmax_param(self, out_max_val=None, top_k=None, dim=1): argmax_param = pb.ArgMaxParameter() if out_max_val is not None: argmax_param.out_max_val = out_max_val if top_k is not None: argmax_param.top_k = top_k if dim is not None: argmax_param.axis = dim self.param.argmax_param.CopyFrom(argmax_param)
pytorch_to_caffe.py中定義了Rp類,用來實(shí)現(xiàn)Pytorch操作到Caffe操作的變換:
class Rp(object):
def __init__(self, raw, replace, **kwargs):
self.obj = replace
self.raw = raw
def __call__(self, *args, **kwargs):
if not NET_INITTED:
return self.raw(*args, **kwargs)
for stack in traceback.walk_stack(None):
if 'self' in stack[0].f_locals:
layer = stack[0].f_locals['self']
if layer in layer_names:
log.pytorch_layer_name = layer_names[layer]
print('984', layer_names[layer])
break
out = self.obj(self.raw, *args, **kwargs)
return out在添加操作時(shí),要使用Rp類替換操作:
torch.argmax = Rp(torch.argmax, torch_argmax)
接下來,要具體實(shí)現(xiàn)該操作:
def torch_argmax(raw, input, dim=1): x = raw(input, dim=dim) layer_name = log.add_layer(name='argmax') top_blobs = log.add_blobs([x], name='argmax_blob'.format(type)) layer = caffe_net.Layer_param(name=layer_name, type='ArgMax', bottom=[log.blobs(input)], top=top_blobs) layer.argmax_param(dim=dim) log.cnet.add_layer(layer) return x
即實(shí)現(xiàn)了argmax操作Pytorch到Caffe的轉(zhuǎn)換。
如果要轉(zhuǎn)換的操作在Caffe中無直接對應(yīng)的層實(shí)現(xiàn),解決思路主要有兩個:
1)在Pytorch中將不支持的操作分解為支持的操作:
如nn.InstanceNorm2d,實(shí)例歸一化在轉(zhuǎn)換時(shí)是用BatchNorm做的,不支持 affine=True 或者track_running_stats=True,默認(rèn)use_global_stats:false,但om轉(zhuǎn)換時(shí)use_global_stats必須為true,所以可以轉(zhuǎn)到Caffe,但再轉(zhuǎn)om不友好。
InstanceNorm是在featuremap的每個Channel上進(jìn)行歸一化操作,因此,可以實(shí)現(xiàn)nn.InstanceNorm2d為:
class InstanceNormalization(nn.Module): def __init__(self, dim, eps=1e-5): super(InstanceNormalization, self).__init__() self.gamma = nn.Parameter(torch.FloatTensor(dim)) self.beta = nn.Parameter(torch.FloatTensor(dim)) self.eps = eps self._reset_parameters() def _reset_parameters(self): self.gamma.data.uniform_() self.beta.data.zero_() def __call__(self, x): n = x.size(2) * x.size(3) t = x.view(x.size(0), x.size(1), n) mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) gamma_broadcast = self.gamma.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x) beta_broadcast = self.beta.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x) out = (x - mean) / torch.sqrt(var + self.eps) out = out * gamma_broadcast + beta_broadcast return out
但在驗(yàn)證HiLens Caffe算子邊界中發(fā)現(xiàn), om模型轉(zhuǎn)換不支持Channle維度之外的求和或求均值操作,為了規(guī)避這個操作,我們可以通過支持的算子重新實(shí)現(xiàn)nn.InstanceNorm2d:
class InstanceNormalization(nn.Module): def __init__(self, dim, eps=1e-5): super(InstanceNormalization, self).__init__() self.gamma = torch.FloatTensor(dim) self.beta = torch.FloatTensor(dim) self.eps = eps self.adavg = nn.AdaptiveAvgPool2d(1) def forward(self, x): n, c, h, w = x.shape mean = nn.Upsample(scale_factor=h)(self.adavg(x)) var = nn.Upsample(scale_factor=h)(self.adavg((x - mean).pow(2))) gamma_broadcast = self.gamma.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x) beta_broadcast = self.beta.unsqueeze(1).unsqueeze(1).unsqueeze(0).expand_as(x) out = (x - mean) / torch.sqrt(var + self.eps) out = out * gamma_broadcast + beta_broadcast return out
經(jīng)過驗(yàn)證,與原操作等價(jià),可以轉(zhuǎn)為Caffe模型
2)在Caffe中通過利用現(xiàn)有操作實(shí)現(xiàn):
在Pytorch轉(zhuǎn)Caffe的過程中發(fā)現(xiàn),如果存在featuremap + 6這種涉及到常數(shù)的操作,轉(zhuǎn)換過程中會出現(xiàn)找不到blob的問題。我們首先查看pytorch_to_caffe.py中add操作的具體轉(zhuǎn)換方法:
def _add(input, *args): x = raw__add__(input, *args) if not NET_INITTED: return x layer_name = log.add_layer(name='add') top_blobs = log.add_blobs([x], name='add_blob') if log.blobs(args[0]) == None: log.add_blobs([args[0]], name='extra_blob') else: layer = caffe_net.Layer_param(name=layer_name, type='Eltwise', bottom=[log.blobs(input),log.blobs(args[0])], top=top_blobs) layer.param.eltwise_param.operation = 1 # sum is 1 log.cnet.add_layer(layer) return x
可以看到對于blob不存在的情況進(jìn)行了判斷,我們只需要在log.blobs(args[0]) == None條件下進(jìn)行修改,一個自然的想法是利用Scale層實(shí)現(xiàn)add操作:
def _add(input, *args): x = raw__add__(input, *args) if not NET_INITTED: return x layer_name = log.add_layer(name='add') top_blobs = log.add_blobs([x], name='add_blob') if log.blobs(args[0]) == None: layer = caffe_net.Layer_param(name=layer_name, type='Scale', bottom=[log.blobs(input)], top=top_blobs) layer.param.scale_param.bias_term = True weight = torch.ones((input.shape[1])) bias = torch.tensor(args[0]).squeeze().expand_as(weight) layer.add_data(weight.cpu().data.numpy(), bias.cpu().data.numpy()) log.cnet.add_layer(layer) else: layer = caffe_net.Layer_param(name=layer_name, type='Eltwise', bottom=[log.blobs(input), log.blobs(args[0])], top=top_blobs) layer.param.eltwise_param.operation = 1 # sum is 1 log.cnet.add_layer(layer) return x
類似的,featuremap * 6這種簡單乘法也可以通過同樣的方法實(shí)現(xiàn)。
Pooling:Pytorch默認(rèn) ceil_mode=false,Caffe默認(rèn) ceil_mode=true,可能會導(dǎo)致維度變化,如果出現(xiàn)尺寸不匹配的問題可以檢查一下Pooling參數(shù)是否正確。另外,雖然文檔上沒有看到,但是 kernel_size > 32 后模型雖然可以轉(zhuǎn)換,但推理會報(bào)錯,這時(shí)可以分兩層進(jìn)行Pooling操作。
Upsample :om邊界算子中的Upsample 層scale_factor參數(shù)必須是int,不能是size。如果已有模型參數(shù)為size也會正常跑完P(guān)ytorch轉(zhuǎn)Caffe的流程,但此時(shí)Upsample參數(shù)是空的。參數(shù)為size的情況可以考慮轉(zhuǎn)為scale_factor或用Deconvolution來實(shí)現(xiàn)。
Transpose2d:Pytorch中 output_padding 參數(shù)會加在輸出的大小上,但Caffe不會,輸出特征圖相對會變小,此時(shí)反卷積之后的featuremap會變大一點(diǎn),可以通過Crop層進(jìn)行裁剪,使其大小與Pytorch對應(yīng)層一致。另外,om中反卷積推理速度較慢,最好是不要使用,可以用Upsample+Convolution替代。
Pad:Pytorch中Pad操作很多樣,但Caffe中只能進(jìn)行H與W維度上的對稱pad,如果Pytorch網(wǎng)絡(luò)中有h = F.pad(x, (1, 2, 1, 2), "constant", 0)這種不對稱的pad操作,解決思路為:
如果不對稱pad的層不存在后續(xù)的維度不匹配的問題,可以先判斷一下pad對結(jié)果的影響,一些任務(wù)受pad的影響很小,那么就不需要修改。
如果存在維度不匹配的問題,可以考慮按照較大的參數(shù)充分pad之后進(jìn)行Crop,或是將前后兩個(0, 0, 1, 1)與(1, 1, 0, 0)的pad合為一個(1, 1, 1, 1),這要看具體的網(wǎng)絡(luò)結(jié)構(gòu)確定。
如果是Channel維度上的pad如F.pad(x, (0, 0, 0, 0, 0, channel_pad), "constant", 0),可以考慮零卷積后cat到featuremap上:
zero = nn.Conv2d(in_channels, self.channel_pad, kernel_size=3, padding=1, bias=False) nn.init.constant(self.zero.weight, 0) pad_tensor = zero(x) x = torch.cat([x, pad_tensor], dim=1)
一些操作可以轉(zhuǎn)到Caffe,但om并不支持標(biāo)準(zhǔn)Caffe的所有操作,如果要再轉(zhuǎn)到om要對照文檔確認(rèn)好邊界算子。
關(guān)于Pytorch轉(zhuǎn)變Caffe再轉(zhuǎn)變om模型轉(zhuǎn)換流程是怎樣的就分享到這里了,希望以上內(nèi)容可以對大家有一定的幫助,可以學(xué)到更多知識。如果覺得文章不錯,可以把它分享出去讓更多的人看到。
網(wǎng)頁題目:Pytorch轉(zhuǎn)變Caffe再轉(zhuǎn)變om模型轉(zhuǎn)換流程是怎樣的
網(wǎng)頁地址:http://chinadenli.net/article24/jijsje.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供網(wǎng)站內(nèi)鏈、網(wǎng)站營銷、網(wǎng)頁設(shè)計(jì)公司、品牌網(wǎng)站建設(shè)、手機(jī)網(wǎng)站建設(shè)、軟件開發(fā)
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請盡快告知,我們將會在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場,如需處理請聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來源: 創(chuàng)新互聯(lián)