本篇文章給大家分享的是有關(guān)使用python實(shí)現(xiàn)AdaBoost算法的方法,小編覺(jué)得挺實(shí)用的,因此分享給大家學(xué)習(xí),希望大家閱讀完這篇文章后可以有所收獲,話不多說(shuō),跟著小編一起來(lái)看看吧。
代碼
''' 數(shù)據(jù)集:Mnist 訓(xùn)練集數(shù)量:60000(實(shí)際使用:10000) 測(cè)試集數(shù)量:10000(實(shí)際使用:1000) 層數(shù):40 ------------------------------ 運(yùn)行結(jié)果: 正確率:97% 運(yùn)行時(shí)長(zhǎng):65m ''' import time import numpy as np def loadData(fileName): ''' 加載文件 :param fileName:要加載的文件路徑 :return: 數(shù)據(jù)集和標(biāo)簽集 ''' # 存放數(shù)據(jù)及標(biāo)記 dataArr = [] labelArr = [] # 讀取文件 fr = open(fileName) # 遍歷文件中的每一行 for line in fr.readlines(): # 獲取當(dāng)前行,并按“,”切割成字段放入列表中 # strip:去掉每行字符串首尾指定的字符(默認(rèn)空格或換行符) # split:按照指定的字符將字符串切割成每個(gè)字段,返回列表形式 curLine = line.strip().split(',') # 將每行中除標(biāo)記外的數(shù)據(jù)放入數(shù)據(jù)集中(curLine[0]為標(biāo)記信息) # 在放入的同時(shí)將原先字符串形式的數(shù)據(jù)轉(zhuǎn)換為整型 # 此外將數(shù)據(jù)進(jìn)行了二值化處理,大于128的轉(zhuǎn)換成1,小于的轉(zhuǎn)換成0,方便后續(xù)計(jì)算 dataArr.append([int(int(num) > 128) for num in curLine[1:]]) # 將標(biāo)記信息放入標(biāo)記集中 # 放入的同時(shí)將標(biāo)記轉(zhuǎn)換為整型 # 轉(zhuǎn)換成二分類(lèi)任務(wù) # 標(biāo)簽0設(shè)置為1,反之為-1 if int(curLine[0]) == 0: labelArr.append(1) else: labelArr.append(-1) # 返回?cái)?shù)據(jù)集和標(biāo)記 return dataArr, labelArr def calc_e_Gx(trainDataArr, trainLabelArr, n, div, rule, D): ''' 計(jì)算分類(lèi)錯(cuò)誤率 :param trainDataArr:訓(xùn)練數(shù)據(jù)集數(shù)字 :param trainLabelArr: 訓(xùn)練標(biāo)簽集數(shù)組 :param n: 要操作的特征 :param div:劃分點(diǎn) :param rule:正反例標(biāo)簽 :param D:權(quán)值分布D :return:預(yù)測(cè)結(jié)果, 分類(lèi)誤差率 ''' # 初始化分類(lèi)誤差率為0 e = 0 # 將訓(xùn)練數(shù)據(jù)矩陣中特征為n的那一列單獨(dú)剝出來(lái)做成數(shù)組。因?yàn)槠渌匚覀儾⒉恍枰? # 直接對(duì)龐大的訓(xùn)練集進(jìn)行操作的話會(huì)很慢 x = trainDataArr[:, n] # 同樣將標(biāo)簽也轉(zhuǎn)換成數(shù)組格式,x和y的轉(zhuǎn)換只是單純?yōu)榱颂岣哌\(yùn)行速度 # 測(cè)試過(guò)相對(duì)直接操作而言性能提升很大 y = trainLabelArr predict = [] # 依據(jù)小于和大于的標(biāo)簽依據(jù)實(shí)際情況會(huì)不同,在這里直接進(jìn)行設(shè)置 if rule == 'LisOne': L = 1 H = -1 else: L = -1 H = 1 # 遍歷所有樣本的特征m for i in range(trainDataArr.shape[0]): if x[i] < div: # 如果小于劃分點(diǎn),則預(yù)測(cè)為L(zhǎng) # 如果設(shè)置小于div為1,那么L就是1, # 如果設(shè)置小于div為-1,L就是-1 predict.append(L) # 如果預(yù)測(cè)錯(cuò)誤,分類(lèi)錯(cuò)誤率要加上該分錯(cuò)的樣本的權(quán)值(8.1式) if y[i] != L: e += D[i] elif x[i] >= div: # 與上面思想一樣 predict.append(H) if y[i] != H: e += D[i] # 返回預(yù)測(cè)結(jié)果和分類(lèi)錯(cuò)誤率e # 預(yù)測(cè)結(jié)果其實(shí)是為了后面做準(zhǔn)備的,在算法8.1第四步式8.4中exp內(nèi)部有個(gè)Gx,要用在那個(gè)地方 # 以此來(lái)更新新的D return np.array(predict), e def createSigleBoostingTree(trainDataArr, trainLabelArr, D): ''' 創(chuàng)建單層提升樹(shù) :param trainDataArr:訓(xùn)練數(shù)據(jù)集數(shù)組 :param trainLabelArr: 訓(xùn)練標(biāo)簽集數(shù)組 :param D: 算法8.1中的D :return: 創(chuàng)建的單層提升樹(shù) ''' # 獲得樣本數(shù)目及特征數(shù)量 m, n = np.shape(trainDataArr) # 單層樹(shù)的字典,用于存放當(dāng)前層提升樹(shù)的參數(shù) # 也可以認(rèn)為該字典代表了一層提升樹(shù) sigleBoostTree = {} # 初始化分類(lèi)誤差率,分類(lèi)誤差率在算法8.1步驟(2)(b)有提到 # 誤差率最高也只能100%,因此初始化為1 sigleBoostTree['e'] = 1 # 對(duì)每一個(gè)特征進(jìn)行遍歷,尋找用于劃分的最合適的特征 for i in range(n): # 因?yàn)樘卣饕呀?jīng)經(jīng)過(guò)二值化,只能為0和1,因此分切分時(shí)分為-0.5, 0.5, 1.5三擋進(jìn)行切割 for div in [-0.5, 0.5, 1.5]: # 在單個(gè)特征內(nèi)對(duì)正反例進(jìn)行劃分時(shí),有兩種情況: # 可能是小于某值的為1,大于某值得為-1,也可能小于某值得是-1,反之為1 # 因此在尋找最佳提升樹(shù)的同時(shí)對(duì)于兩種情況也需要遍歷運(yùn)行 # LisOne:Low is one:小于某值得是1 # HisOne:High is one:大于某值得是1 for rule in ['LisOne', 'HisOne']: # 按照第i個(gè)特征,以值div進(jìn)行切割,進(jìn)行當(dāng)前設(shè)置得到的預(yù)測(cè)和分類(lèi)錯(cuò)誤率 Gx, e = calc_e_Gx(trainDataArr, trainLabelArr, i, div, rule, D) # 如果分類(lèi)錯(cuò)誤率e小于當(dāng)前最小的e,那么將它作為最小的分類(lèi)錯(cuò)誤率保存 if e < sigleBoostTree['e']: sigleBoostTree['e'] = e # 同時(shí)也需要存儲(chǔ)最優(yōu)劃分點(diǎn)、劃分規(guī)則、預(yù)測(cè)結(jié)果、特征索引 # 以便進(jìn)行D更新和后續(xù)預(yù)測(cè)使用 sigleBoostTree['div'] = div sigleBoostTree['rule'] = rule sigleBoostTree['Gx'] = Gx sigleBoostTree['feature'] = i # 返回單層的提升樹(shù) return sigleBoostTree def createBosstingTree(trainDataList, trainLabelList, treeNum=50): ''' 創(chuàng)建提升樹(shù) 創(chuàng)建算法依據(jù)“8.1.2 AdaBoost算法” 算法8.1 :param trainDataList:訓(xùn)練數(shù)據(jù)集 :param trainLabelList: 訓(xùn)練測(cè)試集 :param treeNum: 樹(shù)的層數(shù) :return: 提升樹(shù) ''' # 將數(shù)據(jù)和標(biāo)簽轉(zhuǎn)化為數(shù)組形式 trainDataArr = np.array(trainDataList) trainLabelArr = np.array(trainLabelList) # 沒(méi)增加一層數(shù)后,當(dāng)前最終預(yù)測(cè)結(jié)果列表 finallpredict = [0] * len(trainLabelArr) # 獲得訓(xùn)練集數(shù)量以及特征個(gè)數(shù) m, n = np.shape(trainDataArr) # 依據(jù)算法8.1步驟(1)初始化D為1/N D = [1 / m] * m # 初始化提升樹(shù)列表,每個(gè)位置為一層 tree = [] # 循環(huán)創(chuàng)建提升樹(shù) for i in range(treeNum): # 得到當(dāng)前層的提升樹(shù) curTree = createSigleBoostingTree(trainDataArr, trainLabelArr, D) # 根據(jù)式8.2計(jì)算當(dāng)前層的alpha alpha = 1 / 2 * np.log((1 - curTree['e']) / curTree['e']) # 獲得當(dāng)前層的預(yù)測(cè)結(jié)果,用于下一步更新D Gx = curTree['Gx'] # 依據(jù)式8.4更新D # 考慮到該式每次只更新D中的一個(gè)w,要循環(huán)進(jìn)行更新知道所有w更新結(jié)束會(huì)很復(fù)雜(其實(shí) # 不是時(shí)間上的復(fù)雜,只是讓人感覺(jué)每次單獨(dú)更新一個(gè)很累),所以該式以向量相乘的形式, # 一個(gè)式子將所有w全部更新完。 # 該式需要線性代數(shù)基礎(chǔ),如果不太熟練建議補(bǔ)充相關(guān)知識(shí),當(dāng)然了,單獨(dú)更新w也一點(diǎn)問(wèn)題 # 沒(méi)有 # np.multiply(trainLabelArr, Gx):exp中的y*Gm(x),結(jié)果是一個(gè)行向量,內(nèi)部為yi*Gm(xi) # np.exp(-1 * alpha * np.multiply(trainLabelArr, Gx)):上面求出來(lái)的行向量?jī)?nèi)部全體 # 成員再乘以-αm,然后取對(duì)數(shù),和書(shū)上式子一樣,只不過(guò)書(shū)上式子內(nèi)是一個(gè)數(shù),這里是一個(gè)向量 # D是一個(gè)行向量,取代了式中的wmi,然后D求和為Zm # 書(shū)中的式子最后得出來(lái)一個(gè)數(shù)w,所有數(shù)w組合形成新的D # 這里是直接得到一個(gè)向量,向量?jī)?nèi)元素是所有的w # 本質(zhì)上結(jié)果是相同的 D = np.multiply(D, np.exp(-1 * alpha * np.multiply(trainLabelArr, Gx))) / sum(D) # 在當(dāng)前層參數(shù)中增加alpha參數(shù),預(yù)測(cè)的時(shí)候需要用到 curTree['alpha'] = alpha # 將當(dāng)前層添加到提升樹(shù)索引中。 tree.append(curTree) # -----以下代碼用來(lái)輔助,可以去掉--------------- # 根據(jù)8.6式將結(jié)果加上當(dāng)前層乘以α,得到目前的最終輸出預(yù)測(cè) finallpredict += alpha * Gx # 計(jì)算當(dāng)前最終預(yù)測(cè)輸出與實(shí)際標(biāo)簽之間的誤差 error = sum([1 for i in range(len(trainDataList)) if np.sign(finallpredict[i]) != trainLabelArr[i]]) # 計(jì)算當(dāng)前最終誤差率 finallError = error / len(trainDataList) # 如果誤差為0,提前退出即可,因?yàn)闆](méi)有必要再計(jì)算算了 if finallError == 0: return tree # 打印一些信息 print('iter:%d:%d, sigle error:%.4f, finall error:%.4f' % (i, treeNum, curTree['e'], finallError)) # 返回整個(gè)提升樹(shù) return tree def predict(x, div, rule, feature): ''' 輸出單獨(dú)層預(yù)測(cè)結(jié)果 :param x: 預(yù)測(cè)樣本 :param div: 劃分點(diǎn) :param rule: 劃分規(guī)則 :param feature: 進(jìn)行操作的特征 :return: ''' # 依據(jù)劃分規(guī)則定義小于及大于劃分點(diǎn)的標(biāo)簽 if rule == 'LisOne': L = 1 H = -1 else: L = -1 H = 1 # 判斷預(yù)測(cè)結(jié)果 if x[feature] < div: return L else: return H def test(testDataList, testLabelList, tree): ''' 測(cè)試 :param testDataList:測(cè)試數(shù)據(jù)集 :param testLabelList: 測(cè)試標(biāo)簽集 :param tree: 提升樹(shù) :return: 準(zhǔn)確率 ''' # 錯(cuò)誤率計(jì)數(shù)值 errorCnt = 0 # 遍歷每一個(gè)測(cè)試樣本 for i in range(len(testDataList)): # 預(yù)測(cè)結(jié)果值,初始為0 result = 0 # 依據(jù)算法8.1式8.6 # 預(yù)測(cè)式子是一個(gè)求和式,對(duì)于每一層的結(jié)果都要進(jìn)行一次累加 # 遍歷每層的樹(shù) for curTree in tree: # 獲取該層參數(shù) div = curTree['div'] rule = curTree['rule'] feature = curTree['feature'] alpha = curTree['alpha'] # 將當(dāng)前層結(jié)果加入預(yù)測(cè)中 result += alpha * predict(testDataList[i], div, rule, feature) # 預(yù)測(cè)結(jié)果取sign值,如果大于0 sign為1,反之為0 if np.sign(result) != testLabelList[i]: errorCnt += 1 # 返回準(zhǔn)確率 return 1 - errorCnt / len(testDataList) if __name__ == '__main__': # 開(kāi)始時(shí)間 start = time.time() # 獲取訓(xùn)練集 print('start read transSet') trainDataList, trainLabelList = loadData('../Mnist/mnist_train.csv') # 獲取測(cè)試集 print('start read testSet') testDataList, testLabelList = loadData('../Mnist/mnist_test.csv') # 創(chuàng)建提升樹(shù) print('start init train') tree = createBosstingTree(trainDataList[:10000], trainLabelList[:10000], 40) # 測(cè)試 print('start to test') accuracy = test(testDataList[:1000], testLabelList[:1000], tree) print('the accuracy is:%d' % (accuracy * 100), '%') # 結(jié)束時(shí)間 end = time.time() print('time span:', end - start)
新聞名稱(chēng):使用python實(shí)現(xiàn)AdaBoost算法的方法-創(chuàng)新互聯(lián)
當(dāng)前鏈接:http://chinadenli.net/article34/epcpe.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供軟件開(kāi)發(fā)、定制網(wǎng)站、品牌網(wǎng)站設(shè)計(jì)、微信公眾號(hào)、定制開(kāi)發(fā)、搜索引擎優(yōu)化
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶(hù)投稿、用戶(hù)轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容