用 BERT 训练模型并转换为 pb 格式

    技术2022-07-12  75

    完整代码见 GitHub: https://github.com/danan0755/Bert_Classifier/blob/master/Bert_Train.py

    def serving_input_fn(max_seq_length=128):
        """Build the serving input receiver used when exporting a SavedModel.

        The serving inputs are raw feature tensors, created through
        ``build_raw_serving_input_receiver_fn``. The alternative,
        ``build_parsing_serving_input_receiver_fn``, would instead take
        serialized ``tf.Example`` protos as input.

        The training data at ``FLAGS.data_dir`` is read as a header-less,
        tab-separated file with columns ``labels`` and ``text``; the number of
        distinct labels sets the width of the ``label_ids`` placeholder.

        Args:
            max_seq_length: Width of the ``input_ids``, ``input_mask`` and
                ``segment_ids`` placeholders. Defaults to 128, the value that
                was previously hard-coded. NOTE(review): presumably this has
                to equal the sequence length the model was trained with --
                confirm against the training flags.

        Returns:
            The result of calling the function built by
            ``tf.estimator.export.build_raw_serving_input_receiver_fn`` for the
            four placeholders below.
        """
        df = pd.read_csv(
            FLAGS.data_dir,
            delimiter="\t",
            names=['labels', 'text'],
            header=None,
        )
        # One column per distinct label value found in the data file.
        dense_units = len(df['labels'].unique())

        # The batch dimension is left open (None) on every placeholder.
        features = {
            'label_ids': tf.placeholder(
                tf.int32, [None, dense_units], name='label_ids'),
            'input_ids': tf.placeholder(
                tf.int32, [None, max_seq_length], name='input_ids'),
            'input_mask': tf.placeholder(
                tf.int32, [None, max_seq_length], name='input_mask'),
            'segment_ids': tf.placeholder(
                tf.int32, [None, max_seq_length], name='segment_ids'),
        }

        # build_raw_serving_input_receiver_fn returns a function; it is called
        # immediately so that this function itself yields the receiver.
        receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
            features)
        return receiver_fn()
    Processed: 0.011, SQL: 9