NLP: How to Batch-Load Data


 

Problem background: suppose you have used a pretrained model such as BERT, ALBERT, RoBERTa, or Tencent word vectors to fine-tune (or feature-integrate) some NLP downstream task, and exported the result as a pb model. How do you run predictions in bulk to improve throughput and cut latency? With batch prediction, of course! For convenience, I wrapped Su Jianlin's (苏神's) code into a simple class. The code below is short, but readers should first be familiar with the following background material:

 

Resources:

Effect:

  • Compared with predicting one text at a time, each batch runs roughly k * batch_size times faster (k = 1 or 2)

Why:

  • Batch prediction amounts to computing only N / batch_size matrix multiplications (or dot-product passes) in total, whereas predicting texts one at a time incurs two costs per text: T(preprocessing) + T(vector multiplication)
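To make that concrete, here is a minimal numpy sketch (toy sizes, purely illustrative, not the real BERT forward pass): one batched matrix multiply produces the same result as N separate vector products, but in a single call that the BLAS/GPU kernel can parallelize.

```python
import numpy as np

# Toy illustration of why batching is faster: one (N, d) x (d, k) matmul
# gives the same result as N separate vector-matrix products.
N, d, k = 8, 16, 4                      # toy sizes
rng = np.random.default_rng(0)
X = rng.standard_normal((N, d))         # a "batch" of N input vectors
W = rng.standard_normal((d, k))         # a shared weight matrix

batched = X @ W                         # 1 matmul for the whole batch
single = np.stack([x @ W for x in X])   # N separate products

assert np.allclose(batched, single)
```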

 

Code example (just copy it over and run it; it speaks for itself):

import os

# Best-effort install of the two dependencies (via the Tsinghua PyPI mirror)
try:
    os.system("pip install -i https://pypi.tuna.tsinghua.edu.cn/simple loguru")
    os.system("pip install -i https://pypi.tuna.tsinghua.edu.cn/simple bert4keras")
except Exception:
    pass

import numpy as np
from loguru import logger
from bert4keras.backend import set_gelu
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open

set_gelu('tanh')  # switch to the tanh gelu approximation



def load_data(filename):
    """Load tab-separated rows of the form "text1\ttext2\tlabel"."""
    D = []
    with open(filename, encoding='utf-8') as f:
        for l in f:
            text1, text2, label = l.strip().split('\t')
            D.append((text1, text2, int(label)))
    return D
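As a quick sanity check, load_data can be exercised on a throwaway file; the two sample rows below are taken from the LCQMC snippet shown further down.

```python
import os
import tempfile

def load_data(filename):  # same logic as the function above
    D = []
    with open(filename, encoding='utf-8') as f:
        for l in f:
            text1, text2, label = l.strip().split('\t')
            D.append((text1, text2, int(label)))
    return D

# two rows in LCQMC's "text1<TAB>text2<TAB>label" format
sample = "谁有狂三这张高清的\t这张高清图,谁有\t0\n英雄联盟什么英雄最好\t英雄联盟最好英雄是什么\t1\n"
tmp = tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8")
tmp.write(sample)
tmp.close()

D = load_data(tmp.name)
os.remove(tmp.name)
assert D[0] == ("谁有狂三这张高清的", "这张高清图,谁有", 0)
assert D[1][2] == 1
```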


# Batch data generator
class data_generator(DataGenerator):
    def __init__(self, maxlen, dict_path, data, batch_size, buffer_size):
        # initialize the DataGenerator parent with the data and batch settings
        super(data_generator, self).__init__(data=data, batch_size=batch_size, buffer_size=buffer_size)
        self.maxlen = maxlen
        self.tokenizer = Tokenizer(dict_path, do_lower_case=True)

    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids, batch_labels = [], [], []
        for is_end, (text1, text2, label) in self.sample(random):
            token_ids, segment_ids = self.tokenizer.encode(text1, text2, maxlen=self.maxlen)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
            batch_labels.append([label])
            if len(batch_token_ids) == self.batch_size or is_end:
                # pad all sequences in the batch to a common length, then emit
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_labels = sequence_padding(batch_labels)
                yield [batch_token_ids, batch_segment_ids], batch_labels
                batch_token_ids, batch_segment_ids, batch_labels = [], [], []

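The batching-and-padding step inside __iter__ can be mimicked without bert4keras; the sketch below assumes sequence_padding's default behavior (pad with 0 at the end, up to the longest sequence in the batch).

```python
import numpy as np

def pad_batch(seqs, value=0):
    # pad every sequence with `value` at the end, to the batch's max length
    maxlen = max(len(s) for s in seqs)
    return np.array([list(s) + [value] * (maxlen - len(s)) for s in seqs])

token_id_lists = [[101, 8, 9, 102], [101, 8, 102]]  # toy ids, not real BERT output
padded = pad_batch(token_id_lists)
assert padded.shape == (2, 4)
assert padded[1].tolist() == [101, 8, 102, 0]
```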

if __name__ == "__main__":
    # use a distinct variable name so the class is not shadowed by its instance
    generator = data_generator(maxlen=128,
                               dict_path="./bert/chinese_L-12_H-768_A-12/vocab.txt",
                               data=load_data("./test.txt"),
                               batch_size=32,
                               buffer_size=None)

    for x_true, y_true in generator:
        # print the contents of each batch; note that x_true[1] holds the
        # segment (token type) ids, despite the "mask" label below
        for idx in range(len(x_true[0])):
            logger.info("text #{0}:".format(idx))
            logger.info("word2id: {0}".format(str(list(x_true[0][idx]))))
            logger.info("mask: {0}".format(str(list(x_true[1][idx]))))
            logger.info("label: {0}\n\n".format(str(list(y_true[idx]))))

Run output:

2020-10-14 22:28:57.416 | INFO     | __main__:<module>:73 - text #0:
2020-10-14 22:28:57.417 | INFO     | __main__:<module>:74 - word2id: [101, 6443, 3300, 4312, 676, 6821, 2476, 7770, 3926, 4638, 102, 6821, 2476, 7770, 3926, 1745, 8024, 6443, 3300, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.417 | INFO     | __main__:<module>:75 - mask: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.417 | INFO     | __main__:<module>:76 - label: [0]

2020-10-14 22:28:57.417 | INFO     | __main__:<module>:73 - text #1:
2020-10-14 22:28:57.417 | INFO     | __main__:<module>:74 - word2id: [101, 5739, 7413, 5468, 4673, 784, 720, 5739, 7413, 3297, 1962, 102, 5739, 7413, 5468, 4673, 3297, 1962, 5739, 7413, 3221, 784, 720, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.417 | INFO     | __main__:<module>:75 - mask: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.417 | INFO     | __main__:<module>:76 - label: [1]

2020-10-14 22:28:57.418 | INFO     | __main__:<module>:73 - text #2:
2020-10-14 22:28:57.418 | INFO     | __main__:<module>:74 - word2id: [101, 6821, 3221, 784, 720, 2692, 2590, 8024, 6158, 6701, 5381, 1408, 102, 2769, 738, 3221, 7004, 749, 8024, 6821, 3221, 784, 720, 2692, 2590, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.418 | INFO     | __main__:<module>:75 - mask: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.418 | INFO     | __main__:<module>:76 - label: [0]

2020-10-14 22:28:57.418 | INFO     | __main__:<module>:73 - text #3:
2020-10-14 22:28:57.419 | INFO     | __main__:<module>:74 - word2id: [101, 4385, 1762, 3300, 784, 720, 1220, 4514, 4275, 1962, 4692, 1450, 8043, 102, 4385, 1762, 3300, 784, 720, 1962, 4692, 4638, 1220, 4514, 4275, 1408, 8043, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.419 | INFO     | __main__:<module>:75 - mask: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.419 | INFO     | __main__:<module>:76 - label: [1]

2020-10-14 22:28:57.419 | INFO     | __main__:<module>:73 - text #4:
2020-10-14 22:28:57.419 | INFO     | __main__:<module>:74 - word2id: [101, 6435, 7309, 3253, 6809, 4510, 2094, 1322, 4385, 1762, 4638, 2339, 6598, 2521, 6878, 2582, 720, 3416, 6206, 3724, 3300, 1525, 763, 102, 676, 3215, 4510, 2094, 1322, 2339, 6598, 2521, 6878, 2582, 720, 3416, 1557, 102]
2020-10-14 22:28:57.419 | INFO     | __main__:<module>:75 - mask: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
2020-10-14 22:28:57.419 | INFO     | __main__:<module>:76 - label: [0]

2020-10-14 22:28:57.419 | INFO     | __main__:<module>:73 - text #5:
2020-10-14 22:28:57.420 | INFO     | __main__:<module>:74 - word2id: [101, 3152, 4995, 4696, 4638, 4263, 2001, 5013, 1408, 102, 2001, 5013, 4696, 4638, 6158, 3152, 4995, 2397, 749, 1408, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.420 | INFO     | __main__:<module>:75 - mask: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.420 | INFO     | __main__:<module>:76 - label: [0]

2020-10-14 22:28:57.420 | INFO     | __main__:<module>:73 - text #6:
2020-10-14 22:28:57.420 | INFO     | __main__:<module>:74 - word2id: [101, 6843, 5632, 2346, 976, 4638, 7318, 6057, 784, 720, 4495, 3189, 4851, 4289, 1962, 102, 6843, 7318, 6057, 784, 720, 4495, 3189, 4851, 4289, 1962, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.420 | INFO     | __main__:<module>:75 - mask: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.420 | INFO     | __main__:<module>:76 - label: [1]

2020-10-14 22:28:57.420 | INFO     | __main__:<module>:73 - text #7:
2020-10-14 22:28:57.420 | INFO     | __main__:<module>:74 - word2id: [101, 6818, 3309, 677, 3216, 4638, 4510, 2512, 102, 6818, 3309, 677, 3216, 4638, 4510, 2512, 3300, 1525, 763, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.421 | INFO     | __main__:<module>:75 - mask: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.421 | INFO     | __main__:<module>:76 - label: [1]

2020-10-14 22:28:57.421 | INFO     | __main__:<module>:73 - text #8:
2020-10-14 22:28:57.421 | INFO     | __main__:<module>:74 - word2id: [101, 3724, 5739, 7413, 5468, 4673, 1920, 4868, 2372, 8043, 102, 5739, 7413, 5468, 4673, 8024, 3724, 1920, 4868, 2372, 172, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.421 | INFO     | __main__:<module>:75 - mask: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.421 | INFO     | __main__:<module>:76 - label: [1]

2020-10-14 22:28:57.422 | INFO     | __main__:<module>:73 - text #9:
2020-10-14 22:28:57.422 | INFO     | __main__:<module>:74 - word2id: [101, 1963, 1217, 677, 784, 720, 6956, 7674, 102, 5314, 691, 1217, 677, 6956, 7674, 3221, 784, 720, 2099, 8043, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.422 | INFO     | __main__:<module>:75 - mask: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
2020-10-14 22:28:57.422 | INFO     | __main__:<module>:76 - label: [0]

Notes:

  • The "vocab.txt" in the code is the vocabulary file that ships with the pretrained Chinese BERT checkpoint (anyone doing NLP will recognize it)
  • The "test.txt" file consists of the first 10 rows of the LCQMC test set, listed below (for demo purposes only):
谁有狂三这张高清的	这张高清图,谁有	0
英雄联盟什么英雄最好	英雄联盟最好英雄是什么	1
这是什么意思,被蹭网吗	我也是醉了,这是什么意思	0
现在有什么动画片好看呢?	现在有什么好看的动画片吗?	1
请问晶达电子厂现在的工资待遇怎么样要求有哪些	三星电子厂工资待遇怎么样啊	0
文章真的爱姚笛吗	姚笛真的被文章干了吗	0
送自己做的闺蜜什么生日礼物好	送闺蜜什么生日礼物好	1
近期上映的电影	近期上映的电影有哪些	1
求英雄联盟大神带?	英雄联盟,求大神带~	1
如加上什么部首	给东加上部首是什么字?	0

If your task is plain classification rather than sentence-pair matching (format: text\tlabel), change two lines inside the data_generator class, as follows:

for is_end, (text1, label) in self.sample(random):
    token_ids, segment_ids = self.tokenizer.encode(text1, maxlen=self.maxlen)
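The corresponding load_data for the single-text format splits only two fields per line; here is a minimal standalone sketch (function name and sample rows are made up for illustration):

```python
def load_cls_data_lines(lines):
    # parse "text<TAB>label" rows, mirroring load_data above
    D = []
    for l in lines:
        text, label = l.strip().split('\t')
        D.append((text, int(label)))
    return D

rows = load_cls_data_lines(["这部动漫好看吗\t1\n", "今天天气怎么样\t0\n"])
assert rows == [("这部动漫好看吗", 1), ("今天天气怎么样", 0)]
```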

Other tasks can be handled the same way. That's a wrap 🎉🎉🎉. To close, a recommendation for a well-reviewed animated series: 《灵笼》 (Ling Cage).

 
