1. 環(huán)境配置
本實(shí)驗(yàn)使用操作系統(tǒng):Ubuntu 18.04.3 LTS 4.15.0-29-generic GNU/Linux操作系統(tǒng)。
1.1 查看CUDA版本
cat /usr/local/cuda/version.txt
輸出:
CUDA Version 10.0.130*
1.2 查看 cudnn版本
cat /usr/local/cuda/include/cudnn.h | grep CUDNN_MAJOR -A 2
輸出:
#define CUDNN_MINOR 6
#define CUDNN_PATCHLEVEL 3
--
#define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL)
如果沒有安裝 cuda 和 cudnn,到官網(wǎng)根自己的 GPU 型號(hào)版本安裝即可
1.3 安裝tensorflow-gpu
通過Anaconda創(chuàng)建虛擬環(huán)境來安裝tensorflow-gpu(Anaconda安裝步驟就不說了)
創(chuàng)建虛擬環(huán)境
虛擬環(huán)境名為:tensorflow
conda create -n tensorflow python=3.7.1
進(jìn)入虛擬環(huán)境
下次使用也可以通過此命令進(jìn)入虛擬環(huán)境
source activate tensorflow
安裝tensorflow-gpu
不推薦直接pip install tensorflow-gpu 因?yàn)樗俣缺容^慢??梢詮亩拱甑溺R像中下載,速度還是很快的。https://pypi.doubanio.com/simple/tensorflow-gpu/
找到自己適用的版本(cp37表示python版本為3.7)
然后通過pip install 安裝
pip install https://pypi.doubanio.com/packages/15/21/17f941058556b67ce6d1e3f0e0932c9c2deaf457e3d45eecd93f2c20827d/tensorflow_gpu-1.14.0rc1-cp37-cp37m-manylinux1_x86_64.whl
我選擇了1.14.0的tensorflow-gpu linux版本,python版本為3.7。使用BERT的話,tensorflow-gpu版本必須大于1.11.0。同時(shí),不建議選擇2.0版本,2.0版本好像修改了一些方法,還需要自己手動(dòng)修改代碼
環(huán)境測(cè)試
在tensorflow虛擬環(huán)境中,python命令進(jìn)入Python環(huán)境中,輸入import tensorflow,看是否能成功導(dǎo)入
2. 準(zhǔn)備工作
2.1 預(yù)訓(xùn)練模型下載
Bert-base Chinese
BERT-wwm :由哈工大和訊飛聯(lián)合實(shí)驗(yàn)室發(fā)布的,效果比Bert-base Chinese要好一些(鏈接地址為訊飛云,密碼:mva8。無奈當(dāng)時(shí)用wwm訓(xùn)練完提交結(jié)果時(shí),提交通道已經(jīng)關(guān)閉了,嗚嗚)
bert_model.ckpt:負(fù)責(zé)模型變量載入
vocab.txt:訓(xùn)練時(shí)中文文本采用的字典
bert_config.json:BERT在訓(xùn)練時(shí),可選調(diào)整的一些參數(shù)
2.2 數(shù)據(jù)準(zhǔn)備
1)將自己的數(shù)據(jù)集格式改成如下格式:第一列是標(biāo)簽,第二列是文本數(shù)據(jù),中間用tab隔開(若測(cè)試集沒有標(biāo)簽,只保留一列樣本數(shù)據(jù))。 分別將訓(xùn)練集、驗(yàn)證集、測(cè)試集文件名改為train.tsv、val.tsv、test.tsv。文件格式為UTF-8(無BOM)
2)新建data文件夾,存放這三個(gè)文件。
3)預(yù)訓(xùn)練模型解壓,存放到新建文件夾chinese中
2.3 代碼修改
我們需要對(duì)bert源碼中run_classifier.py進(jìn)行兩處修改
1)在run_classifier.py中添加我們的任務(wù)類
可以參照其他Processor類,添加自己的任務(wù)類
# 自定義Processor類
class MyProcessor(DataProcessor):
def __init__(self):
self.labels = ['Addictive Behavior',
'Address',
'Age',
'Alcohol Consumer',
'Allergy Intolerance',
'Bedtime',
'Blood Donation',
'Capacity',
'Compliance with Protocol',
'Consent',
'Data Accessible',
'Device',
'Diagnostic',
'Diet',
'Disabilities',
'Disease',
'Education',
'Encounter',
'Enrollment in other studies',
'Ethical Audit',
'Ethnicity',
'Exercise',
'Gender',
'Healthy',
'Laboratory Examinations',
'Life Expectancy',
'Literacy',
'Multiple',
'Neoplasm Status',
'Non-Neoplasm Disease Stage',
'Nursing',
'Oral related',
'Organ or Tissue Status',
'Pharmaceutical Substance or Drug',
'Pregnancy-related Activity',
'Receptor Status',
'Researcher Decision',
'Risk Assessment',
'Sexual related',
'Sign',
'Smoking Status',
'Special Patient Characteristic',
'Symptom',
'Therapy or Surgery']
def get_train_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "val.tsv")), "val")
def get_test_examples(self, data_dir):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
return self.labels
def _create_examples(self, lines, set_type):
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
if set_type == "test":
"""
因?yàn)槲业臏y(cè)試集中沒有標(biāo)簽,所以對(duì)test進(jìn)行單獨(dú)處理,
test的label值設(shè)為任意一標(biāo)簽(一定是存在的類標(biāo)簽,
不然predict時(shí)會(huì)keyError),如果測(cè)試集中有標(biāo)簽,就
不需要if了,統(tǒng)一處理即可。
"""
text_a = tokenization.convert_to_unicode(line[0])
label = "Address"
else:
text_a = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
2)修改processor字典
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"mytask": MyProcessor, # 將自己的Processor添加到字典
}
3 開工
3.1 配置訓(xùn)練腳本
創(chuàng)建并運(yùn)行run.sh這個(gè)文件
python run_classifier.py \
--data_dir=data \
--task_name=mytask \
--do_train=true \
--do_eval=true \
--vocab_file=chinese/vocab.txt \
--bert_config_file=chinese/bert_config.json \
--init_checkpoint=chinese/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=8 \
--learning_rate=2e-5 \
--num_train_epochs=3.0
--output_dir=out \
fine-tune需要一定的時(shí)間,我的訓(xùn)練集有兩萬條,驗(yàn)證集有八千條,GPU為2080Ti,需要20分鐘左右。如果顯存不夠大,記得適當(dāng)調(diào)整max_seq_length 和 train_batch_size
3.2 預(yù)測(cè)
創(chuàng)建并運(yùn)行test.sh(注:init_checkpoint為自己之前輸出模型地址)
python run_classifier.py \
--task_name=mytask \
--do_predict=true \
--data_dir=data \
--vocab_file=chinese/vocab.txt \
--bert_config_file=chinese/bert_config.json \
--init_checkpoint=out \
--max_seq_length=128 \
--output_dir=out
預(yù)測(cè)完會(huì)在out目錄下生成test_results.tsv。生成文件中,每一行對(duì)應(yīng)你訓(xùn)練集中的每一個(gè)樣本,每一列對(duì)應(yīng)的是每一類的概率(對(duì)應(yīng)之前自定義的label列表)。如第5行第8列表示第5個(gè)樣本是第8類的概率。
3.3 預(yù)測(cè)結(jié)果處理鄭州婦科醫(yī)院 http://www.zykdfkyy.com/
因?yàn)轭A(yù)測(cè)結(jié)果是概率,我們需要對(duì)其處理,選取每一行中的大值最為預(yù)測(cè)值,并轉(zhuǎn)換成對(duì)應(yīng)的真實(shí)標(biāo)簽。
data_dir = "C:\\test_results.tsv"
lable = ['Addictive Behavior',
'Address',
'Age',
'Alcohol Consumer',
'Allergy Intolerance',
'Bedtime',
'Blood Donation',
'Capacity',
'Compliance with Protocol',
'Consent',
'Data Accessible',
'Device',
'Diagnostic',
'Diet',
'Disabilities',
'Disease',
'Education',
'Encounter',
'Enrollment in other studies',
'Ethical Audit',
'Ethnicity',
'Exercise',
'Gender',
'Healthy',
'Laboratory Examinations',
'Life Expectancy',
'Literacy',
'Multiple',
'Neoplasm Status',
'Non-Neoplasm Disease Stage',
'Nursing',
'Oral related',
'Organ or Tissue Status',
'Pharmaceutical Substance or Drug',
'Pregnancy-related Activity',
'Receptor Status',
'Researcher Decision',
'Risk Assessment',
'Sexual related',
'Sign',
'Smoking Status',
'Special Patient Characteristic',
'Symptom',
'Therapy or Surgery']
# 用pandas讀取test_result.tsv,將標(biāo)簽設(shè)置為列名
data_df = pd.read_table(data_dir, sep="\t", names=lable, encoding="utf-8")
label_test = []
for i in range(data_df.shape[0]):
# 獲取一行中大值對(duì)應(yīng)的列名,追加到列表
label_test.append(data_df.loc[i, :].idxmax())
另外有需要云服務(wù)器可以了解下創(chuàng)新互聯(lián)cdcxhl.cn,海內(nèi)外云服務(wù)器15元起步,三天無理由+7*72小時(shí)售后在線,公司持有idc許可證,提供“云服務(wù)器、裸金屬服務(wù)器、高防服務(wù)器、香港服務(wù)器、美國(guó)服務(wù)器、虛擬主機(jī)、免備案服務(wù)器”等云主機(jī)租用服務(wù)以及企業(yè)上云的綜合解決方案,具有“安全穩(wěn)定、簡(jiǎn)單易用、服務(wù)可用性高、性價(jià)比高”等特點(diǎn)與優(yōu)勢(shì),專為企業(yè)上云打造定制,能夠滿足用戶豐富、多元化的應(yīng)用場(chǎng)景需求。
文章名稱:用BERT進(jìn)行中文短文本分類-創(chuàng)新互聯(lián)
本文URL:http://chinadenli.net/article28/dhppcp.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供小程序開發(fā)、微信小程序、用戶體驗(yàn)、App設(shè)計(jì)、動(dòng)態(tài)網(wǎng)站、營(yíng)銷型網(wǎng)站建設(shè)
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容