建设网站的基本工作流程,网站建设人员求职信,做网站商城需要申请商标吗,网站数据抓取怎么做文章目录 1、准备用于训练的数据集2、处理数据集3、克隆代码4、运行代码5、将ckpt模型转为bin模型使其可在pytorch中运用 Bert官方仓库#xff1a;https://github.com/google-research/bert
1、准备用于训练的数据集
此处准备的是BBC news的数据集#xff0c;下载链接… 文章目录 1、准备用于训练的数据集2、处理数据集3、克隆代码4、运行代码5、将ckpt模型转为bin模型使其可在pytorch中运用 Bert官方仓库https://github.com/google-research/bert
1、准备用于训练的数据集
此处准备的是BBC news的数据集下载链接https://www.kaggle.com/datasets/gpreda/bbc-news 原数据集格式.csv
2、处理数据集
训练Bert时需要预处理数据将数据处理成https://github.com/google-research/bert/blob/master/sample_text.txt中所示格式如下所示 数据预处理代码参考
import pandas as pd# 读取BBC-news数据集
df pd.read_csv(../../bbc_news.csv)
# print(df[title])
l1 []
l2 []
cnt 0
for line in df[title]:l1.append(line)for line in df[description]:l2.append(line)
# cnt0
f open(test1.txt, w, encodingutf8)
for i in range(len(l1)):s l1[i] l2[i] \nf.write(s)# cnt1# if cnt10: break
f.close()
# print(l1)处理完后的BBC news数据集格式如下所示
3、克隆代码
使用git克隆仓库代码 http:
git clone https://github.com/google-research/bert.git或ssh:
git clone gitgithub.com:google-research/bert.git4、运行代码
先下载Bert模型BERT-Base, Uncased 该文件中有以下文件 运行代码 在Teminal中运行
python create_pretraining_data.py \--input_file./sample_text.txt数据集地址 \--output_file/tmp/tf_examples.tfrecord处理后数据集保存的位置 \--vocab_file$BERT_BASE_DIR/vocab.txtvocab.txt文件位置 \--do_lower_caseTrue \--max_seq_length128 \--max_predictions_per_seq20 \--masked_lm_prob0.15 \--random_seed12345 \--dupe_factor5训练模型
python run_pretraining.py \--input_file/tmp/tf_examples.tfrecord处理后数据集保存的位置 \--output_dir/tmp/pretraining_output训练后模型保存位置 \--do_trainTrue \--do_evalTrue \--bert_config_file$BERT_BASE_DIR/bert_config.jsonbert_config.json文件位置 \--init_checkpoint$BERT_BASE_DIR/bert_model.ckpt如果要从头开始的预训练则去掉这行 \--train_batch_size32 \--max_seq_length128 \--max_predictions_per_seq20 \--num_train_steps20 \--num_warmup_steps10 \--learning_rate2e-5训练完成后模型输出示例
***** Eval results *****global_step 20loss 0.0979674masked_lm_accuracy 0.985479masked_lm_loss 0.0979328next_sentence_accuracy 1.0next_sentence_loss 3.45724e-05要注意应该能够在至少具有 12GB RAM 的 GPU 上运行不然会报错显存不足。 使用未标注数据训练BERT
5、将ckpt模型转为bin模型使其可在pytorch中运用
上一步训练好后准备好训练出来的model.ckpt-20.index文件和Bert模型中的bert_config.json文件
创建python文件convert_bert_original_tf_checkpoint_to_pytorch.py
# codingutf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Convert BERT checkpoint.import argparseimport torchfrom transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logginglogging.set_verbosity_info()def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):# Initialise PyTorch modelconfig BertConfig.from_json_file(bert_config_file)print(Building PyTorch model from configuration: {}.format(str(config)))model BertForPreTraining(config)# Load weights from tf checkpointload_tf_weights_in_bert(model, config, tf_checkpoint_path)# Save pytorch-modelprint(Save PyTorch model to {}.format(pytorch_dump_path))torch.save(model.state_dict(), pytorch_dump_path)if __name__ __main__:parser argparse.ArgumentParser()# Required parametersparser.add_argument(--tf_checkpoint_path, defaultNone, typestr, requiredTrue, helpPath to the TensorFlow checkpoint path.)parser.add_argument(--bert_config_file,defaultNone,typestr,requiredTrue,helpThe config json file corresponding to the pre-trained BERT model. \nThis specifies the model architecture.,)parser.add_argument(--pytorch_dump_path, defaultNone, typestr, requiredTrue, helpPath to the output PyTorch model.)args parser.parse_args()convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
在Terminal中运行以下命令
python convert_bert_original_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path Models/chinese_L-12_H-768_A-12/bert_model.ckpt.index.ckpt.index文件位置 \
--bert_config_file Models/chinese_L-12_H-768_A-12/bert_config.jsonbert_config.json文件位置 \
--pytorch_dump_path Models/chinese_L-12_H-768_A-12/pytorch_model.bin输出的.bin模型文件位置
以上命令最好在一行中运行
python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt.index --bert_config_file bert_config.json --pytorch_dump_path pytorch_model.bin然后就可以得到bin文件了
【BERT for Tensorflow】本地ckpt文件的BERT使用