FastBERT's key idea is easy to grasp: after every Transformer layer, try to predict the sample's label, and if the prediction is already confident enough, skip the remaining layers. The paper calls this the sample-wise adaptive mechanism: the amount of computation is adapted to each sample, so easy samples can be resolved after one or two layers, while hard samples go through the full stack.
So the question becomes: what do we use to make predictions at the intermediate layers? The authors' solution is to attach a classifier after each layer; after all, a classifier is far cheaper than a Transformer layer:
Note: FLOPs (floating point operations) is the count of floating-point operations a model performs, as reported for example by TensorFlow's profiling utilities.
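The PyTorch script at the end of this post measures FLOPs with the thop package instead; below is a minimal, self-contained sketch of that measurement, where the toy nn.Linear model is just a stand-in (the full script profiles FastBertClassifier the same way):

```python
import torch
import torch.nn as nn
from thop import profile  # pip install thop

# Toy stand-in model; replace with the real classifier to profile it.
model = nn.Linear(768, 2)
inputs = (torch.randn(1, 768),)

flops, params = profile(model, inputs, verbose=False)
print("FLOPs: {}, parameters: {}".format(flops, params))
```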
With that, the overall model structure follows naturally:
The authors call the original BERT model the backbone and each classifier a branch.
Note that the branch classifiers are all distilled from the final-layer classifier, which the authors call self-distillation: only the backbone parameters are updated during pre-training and fine-tuning; after fine-tuning, the backbone is frozen and each branch classifier (the student in the figure) is trained to match the probability distribution of the backbone classifier (the teacher in the figure).
It is called self-distillation because conventional distillation involves two models, one learning the other's knowledge, whereas FastBERT distills its own knowledge from itself: the branches learn from the backbone. Note again that the backbone must stay frozen during distillation, so that the knowledge acquired in pre-training and fine-tuning is untouched and only the branches are fitted to the teacher's distribution.
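A minimal sketch of that branch self-distillation loss, under the assumption that the per-layer hidden states and classifier modules are passed in as simple placeholders (the full script at the end of the post implements the same idea inside FastBertClassifier):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def self_distillation_loss(hidden_states, classifiers):
    """hidden_states: list of per-layer encoder outputs, computed without
    gradients so the frozen backbone is untouched; classifiers: one branch
    classifier per layer, with the last one acting as the teacher."""
    kl = nn.KLDivLoss(reduction="batchmean")
    with torch.no_grad():
        teacher_probs = F.softmax(classifiers[-1](hidden_states[-1]), dim=-1)
    loss = 0.0
    for hidden, clf in zip(hidden_states[:-1], classifiers[:-1]):
        student_log_probs = F.log_softmax(clf(hidden), dim=-1)
        loss = loss + kl(student_log_probs, teacher_probs)
    return loss
```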
Why not just train the branch classifiers directly on the labeled data? Because that simply does not work as well. Below are the authors' ablation results:
As the table shows, the non-distilled variant underperforms the distilled one. That seems reasonable to me, because the two setups optimize different objectives during fine-tuning. Without self-distillation, all classifiers are trained jointly at fine-tuning time, which changes the objective and forces the earlier encoder layers to extract more task-specific features. But much of BERT's strength comes from its depth, so making predictions too early is not necessarily reliable, and performance drops.
Self-distillation also brings another important benefit: it no longer depends on labeled data, so the distillation step can keep improving with an endless supply of unlabeled data.
Once the model structure is clear, training and inference follow naturally; the only extra step compared with a vanilla BERT model is self-distillation:
1. Pre-training: same as any BERT-style model; any of the many open-source checkpoints can be used directly.
2. Fine-tuning for Backbone: backbone fine-tuning, i.e., put a classifier on top of BERT's last layer and train it on task data. The branch classifiers are not involved at this stage, so the backbone can be optimized freely.
3. Self-distillation for branch: branch self-distillation; unlabeled task data is enough here. The probability distribution predicted by the backbone classifier is distilled into the branch classifiers. KL divergence is used to measure the distance between distributions, and the loss is the sum of the KL divergences between every branch classifier and the backbone classifier.
4. Adaptive inference: samples are filtered layer by layer based on the branch classifiers' outputs; easy samples get their result immediately, while difficult ones continue to the next layer. The authors define a new uncertainty measure based on the entropy of the predicted distribution, where higher entropy means higher uncertainty:
$$\text{Uncertainty} = \frac{\sum_{i=1}^{N} p_s(i)\log p_s(i)}{\log \frac{1}{N}} = \frac{-\sum_{i=1}^{N} p_s(i)\log p_s(i)}{\log N}$$

where $p_s(i)$ is the probability the branch classifier assigns to label $i$ and $N$ is the number of labels; this is simply the Shannon entropy normalized by its maximum value $\log N$ (cf. normal_shannon_entropy in the code below).

Results
For the per-layer classification results, the authors use "Speed" to denote the uncertainty threshold, which correlates positively with inference speed: the smaller the threshold, the lower the uncertainty required for an early exit, the fewer samples get filtered out early, and the slower the inference.
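To make the direction of that relationship concrete, here is a minimal single-sample sketch of the exit rule; the layer and classifier objects are placeholders, and the full script at the end of the post implements the batched version inside FastBertClassifier:

```python
import torch

def adaptive_inference(emb, layers, classifiers, speed):
    """Early-exit loop for one sample: stop as soon as the branch classifier
    at the current depth is confident enough (normalized entropy <= speed)."""
    hidden = emb
    for layer, classifier in zip(layers, classifiers):
        hidden = layer(hidden)
        probs = torch.softmax(classifier(hidden), dim=-1)
        # Normalized entropy in [0, 1]; a larger `speed` lets more samples exit early.
        num_labels = float(probs.size(-1))
        uncertainty = -(probs * probs.clamp_min(1e-12).log()).sum(-1) / torch.log(torch.tensor(num_labels))
        if uncertainty.item() <= speed:
            break
    return probs
```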
The final model performs well on 12 datasets (6 Chinese and 6 English):
At Speed=0.2, inference is 1-10x faster while the accuracy drop stays within 0.11 points across the board, and a few tasks even improve slightly. By contrast, HuggingFace's DistilBERT fluctuates much more: the 6-layer model is only 2x faster, yet accuracy can drop by as much as 7 points.
Model and code download link: https://pan.baidu.com/s/1uzAm-M6dRaR2X-jFQbknbg (extraction code: go67)
# -*- encoding:utf-8 -*-
"""
This script provides an example of the fine-tuning and self-distillation process of FastBERT.
"""
import os, sys
import torch
import json
import random
import argparse
import collections
import torch.nn as nn
from uer.utils.vocab import Vocab
from uer.utils.constants import *
from uer.utils.tokenizer import *
from uer.model_builder import build_model
from uer.utils.optimizers import *
from uer.utils.config import load_hyperparam
from uer.utils.seed import set_seed
from uer.model_saver import save_model
from uer.model_loader import load_model
from uer.layers.multi_headed_attn import MultiHeadedAttention
import numpy as np
import time
from thop import profile

torch.set_num_threads(1)


def normal_shannon_entropy(p, labels_num):
    # Entropy of the predicted distribution, normalized by its maximum value log(labels_num).
    entropy = torch.distributions.Categorical(probs=p).entropy()
    normal = -np.log(1.0 / labels_num)
    return entropy / normal


class Classifier(nn.Module):
    def __init__(self, args, input_size, labels_num):
        super(Classifier, self).__init__()
        self.input_size = input_size
        self.cla_hidden_size = 128
        self.cla_heads_num = 2
        self.labels_num = labels_num
        self.pooling = args.pooling
        self.output_layer_0 = nn.Linear(input_size, self.cla_hidden_size)
        self.self_atten = MultiHeadedAttention(self.cla_hidden_size, self.cla_heads_num, args.dropout)
        self.output_layer_1 = nn.Linear(self.cla_hidden_size, self.cla_hidden_size)
        self.output_layer_2 = nn.Linear(self.cla_hidden_size, labels_num)

    def forward(self, hidden, mask):
        hidden = torch.tanh(self.output_layer_0(hidden))
        hidden = self.self_atten(hidden, hidden, hidden, mask)

        if self.pooling == "mean":
            hidden = torch.mean(hidden, dim=1)
        elif self.pooling == "max":
            hidden = torch.max(hidden, dim=1)[0]
        elif self.pooling == "last":
            hidden = hidden[:, -1, :]
        else:
            hidden = hidden[:, 0, :]

        output_1 = torch.tanh(self.output_layer_1(hidden))
        logits = self.output_layer_2(output_1)
        return logits


class FastBertClassifier(nn.Module):
    def __init__(self, args, model):
        super(FastBertClassifier, self).__init__()
        self.embedding = model.embedding
        self.encoder = model.encoder
        self.labels_num = args.labels_num
        self.classifiers = nn.ModuleList([
            Classifier(args, args.hidden_size, self.labels_num)
            for i in range(self.encoder.layers_num)
        ])
        self.softmax = nn.LogSoftmax(dim=-1)
        self.criterion = nn.NLLLoss()
        self.soft_criterion = nn.KLDivLoss(reduction='batchmean')
        self.threshold = args.speed

    def forward(self, src, label, mask, fast=True):
        """
        Args:
            src: [batch_size x seq_length]
            label: [batch_size]
            mask: [batch_size x seq_length]
        """
        # Embedding.
        emb = self.embedding(src, mask)
        # Encoder.
        seq_length = emb.size(1)
        mask = (mask > 0). \
                unsqueeze(1). \
                repeat(1, seq_length, 1). \
                unsqueeze(1)
        mask = mask.float()
        mask = (1.0 - mask) * -10000.0

        if self.training:
            if label is not None:
                # Fine-tune the main part (backbone) of the model.
                hidden = emb
                for i in range(self.encoder.layers_num):
                    hidden = self.encoder.transformer[i](hidden, mask)
                logits = self.classifiers[-1](hidden, mask)
                loss = self.criterion(self.softmax(logits.view(-1, self.labels_num)), label.view(-1))
                return loss, logits
            else:
                # Distill the subclassifiers.
                loss, hidden, hidden_list = 0, emb, []
                with torch.no_grad():
                    for i in range(self.encoder.layers_num):
                        hidden = self.encoder.transformer[i](hidden, mask)
                        hidden_list.append(hidden)
                    teacher_logits = self.classifiers[-1](hidden_list[-1], mask).view(-1, self.labels_num)
                    teacher_probs = nn.functional.softmax(teacher_logits, dim=1)
                loss = 0
                for i in range(self.encoder.layers_num - 1):
                    student_logits = self.classifiers[i](hidden_list[i], mask).view(-1, self.labels_num)
                    loss += self.soft_criterion(self.softmax(student_logits), teacher_probs)
                return loss, teacher_logits
        else:
            # Inference.
            if fast:
                # Fast mode: early exit for easy samples.
                hidden = emb  # (batch_size, seq_len, emb_size)
                batch_size = hidden.size(0)
                logits = torch.zeros(batch_size, self.labels_num, dtype=hidden.dtype, device=hidden.device)
                abs_diff_idxs = torch.arange(0, batch_size, dtype=torch.long, device=hidden.device)
                for i in range(self.encoder.layers_num):
                    hidden = self.encoder.transformer[i](hidden, mask)
                    logits_this_layer = self.classifiers[i](hidden, mask)  # (batch_size, labels_num)
                    logits[abs_diff_idxs] = logits_this_layer
                    # Filter easy samples.
                    abs_diff_idxs, rel_diff_idxs = self._difficult_samples_idxs(abs_diff_idxs, logits_this_layer)
                    hidden = hidden[rel_diff_idxs, :, :]
                    mask = mask[rel_diff_idxs, :, :]
                    if len(abs_diff_idxs) == 0:
                        break
                return None, logits
            else:
                # Normal mode: run all layers and use the last classifier.
                hidden = emb
                for i in range(self.encoder.layers_num):
                    hidden = self.encoder.transformer[i](hidden, mask)
                logits = self.classifiers[-1](hidden, mask)
                return None, logits

    def _difficult_samples_idxs(self, idxs, logits):
        # logits: (batch_size, labels_num)
        probs = nn.Softmax(dim=1)(logits)
        entropys = normal_shannon_entropy(probs, self.labels_num)
        # torch.nonzero() is very time-consuming on GPU.
        # Please see https://github.com/pytorch/pytorch/issues/14848
        # If anyone can optimize this operation, please contact me, thank you!
        rel_diff_idxs = (entropys > self.threshold).nonzero().view(-1)
        abs_diff_idxs = torch.tensor([idxs[i] for i in rel_diff_idxs], device=logits.device)
        return abs_diff_idxs, rel_diff_idxs


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Path options.
    parser.add_argument("--pretrained_model_path", default="./models/Chinese_base_model.bin", type=str,
                        help="Path of the pretrained model.")
    parser.add_argument("--output_model_path", default="./models/fastbert.bin", type=str,
                        help="Path of the output model.")
    parser.add_argument("--vocab_path", type=str, required=False, default="./models/google_zh_vocab.txt",
                        help="Path of the vocabulary file.")
    parser.add_argument("--train_path", type=str, required=False, default="./datasets/douban_book_review/train.tsv",
                        help="Path of the trainset.")
    parser.add_argument("--dev_path", type=str, required=False, default="./datasets/douban_book_review/dev.tsv",
                        help="Path of the devset.")
    parser.add_argument("--test_path", type=str, default="./datasets/douban_book_review/test.tsv",
                        help="Path of the testset.")
    parser.add_argument("--config_path", default="./models/bert_base_config.json", type=str,
                        help="Path of the config file.")

    # Model options.
parser.add_argument("--batch_size", type=int, default=32, help="Batch size.") parser.add_argument("--seq_length", type=int, default=128, help="Sequence length.") parser.add_argument("--embedding", choices=["bert", "word"], default="bert", help="Emebdding type.") parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \ "cnn", "gatedcnn", "attn", \ "rcnn", "crnn", "gpt", "bilstm"], \ default="bert", help="Encoder type.") parser.add_argument("--bidirectional", action="store_true", help="Specific to recurrent model.") parser.add_argument("--pooling", choices=["mean", "max", "first", "last"], default="first", help="Pooling type.") # Subword options. parser.add_argument("--subword_type", choices=["none", "char"], default="none", help="Subword feature type.") parser.add_argument("--sub_vocab_path", type=str, default="models/sub_vocab.txt", help="Path of the subword vocabulary file.") parser.add_argument("--subencoder", choices=["avg", "lstm", "gru", "cnn"], default="avg", help="Subencoder type.") parser.add_argument("--sub_layers_num", type=int, default=2, help="The number of subencoder layers.") # Tokenizer options. parser.add_argument("--tokenizer", choices=["bert", "char", "space"], default="bert", help="Specify the tokenizer." "Original Google BERT uses bert tokenizer on Chinese corpus." "Char tokenizer segments sentences into characters." "Space tokenizer segments sentences into words according to space." ) # Optimizer options. parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate.") parser.add_argument("--warmup", type=float, default=0.1, help="Warm up value.") # Training options. parser.add_argument("--dropout", type=float, default=0.5, help="Dropout.") parser.add_argument("--epochs_num", type=int, default=3, help="Number of epochs.") parser.add_argument("--distill_epochs_num", type=int, default=5, help="Number of distillation epochs.") parser.add_argument("--report_steps", type=int, default=100, help="Specific steps to print prompt.") parser.add_argument("--seed", type=int, default=7, help="Random seed.") # Evaluation options. parser.add_argument("--mean_reciprocal_rank", action="store_true", help="Evaluation metrics for DBQA dataset.") parser.add_argument("--fast_mode", dest='fast_mode', action='store_true', help="Whether turn on fast mode") parser.add_argument("--speed", type=float, default=0.5, help="Threshold of Uncertainty, i.e., the Speed in paper.") args = parser.parse_args() # Load the hyperparameters from the config file. args = load_hyperparam(args) set_seed(args.seed) # Count the number of labels. labels_set = set() columns = {} with open(args.train_path, mode="r", encoding="utf-8") as f: for line_id, line in enumerate(f): try: line = line.strip().split("\t") if line_id == 0: for i, column_name in enumerate(line): columns[column_name] = i continue label = int(line[columns["label"]]) labels_set.add(label) except: pass args.labels_num = len(labels_set) # Load vocabulary. vocab = Vocab() vocab.load(args.vocab_path) args.vocab = vocab # Build bert model. # A pseudo target is added. args.target = "bert" model = build_model(args) # Load or initialize parameters. if args.pretrained_model_path is not None: # Initialize with pretrained model. model.load_state_dict(torch.load(args.pretrained_model_path), strict=False) else: # Initialize with normal distribution. for n, p in list(model.named_parameters()): if 'gamma' not in n and 'beta' not in n: p.data.normal_(0, 0.02) # Build classification model. 
    model = FastBertClassifier(args, model)

    # For simplicity, we use the DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)
    model = model.to(device)

    # Dataset loader.
    def batch_loader(batch_size, input_ids, label_ids, mask_ids):
        instances_num = input_ids.size()[0]
        for i in range(instances_num // batch_size):
            input_ids_batch = input_ids[i*batch_size: (i+1)*batch_size, :]
            label_ids_batch = label_ids[i*batch_size: (i+1)*batch_size]
            mask_ids_batch = mask_ids[i*batch_size: (i+1)*batch_size, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch
        if instances_num > instances_num // batch_size * batch_size:
            input_ids_batch = input_ids[instances_num//batch_size*batch_size:, :]
            label_ids_batch = label_ids[instances_num//batch_size*batch_size:]
            mask_ids_batch = mask_ids[instances_num//batch_size*batch_size:, :]
            yield input_ids_batch, label_ids_batch, mask_ids_batch

    # Build the tokenizer.
    tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)

    # Read the dataset.
    def read_dataset(path):
        dataset = []
        with open(path, mode="r", encoding="utf-8") as f:
            for line_id, line in enumerate(f):
                if line_id == 0:
                    continue
                try:
                    line = line.strip().split('\t')
                    if len(line) == 2:  # For single sentence input.
                        label = int(line[columns["label"]])
                        text = line[columns["text_a"]]
                        tokens = [vocab.get(t) for t in tokenizer.tokenize(text)]
                        tokens = [CLS_ID] + tokens
                        mask = [1] * len(tokens)
                        if len(tokens) > args.seq_length:
                            tokens = tokens[:args.seq_length]
                            mask = mask[:args.seq_length]
                        while len(tokens) < args.seq_length:
                            tokens.append(0)
                            mask.append(0)
                        dataset.append((tokens, label, mask))
                    elif len(line) == 3:  # For sentence pair input.
                        label = int(line[columns["label"]])
                        text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]
                        tokens_a = [vocab.get(t) for t in tokenizer.tokenize(text_a)]
                        tokens_a = [CLS_ID] + tokens_a + [SEP_ID]
                        tokens_b = [vocab.get(t) for t in tokenizer.tokenize(text_b)]
                        tokens_b = tokens_b + [SEP_ID]
                        tokens = tokens_a + tokens_b
                        mask = [1] * len(tokens_a) + [2] * len(tokens_b)
                        if len(tokens) > args.seq_length:
                            tokens = tokens[:args.seq_length]
                            mask = mask[:args.seq_length]
                        while len(tokens) < args.seq_length:
                            tokens.append(0)
                            mask.append(0)
                        dataset.append((tokens, label, mask))
                    elif len(line) == 4:  # For dbqa input.
                        qid = int(line[columns["qid"]])
                        label = int(line[columns["label"]])
                        text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]
                        tokens_a = [vocab.get(t) for t in tokenizer.tokenize(text_a)]
                        tokens_a = [CLS_ID] + tokens_a + [SEP_ID]
                        tokens_b = [vocab.get(t) for t in tokenizer.tokenize(text_b)]
                        tokens_b = tokens_b + [SEP_ID]
                        tokens = tokens_a + tokens_b
                        mask = [1] * len(tokens_a) + [2] * len(tokens_b)
                        if len(tokens) > args.seq_length:
                            tokens = tokens[:args.seq_length]
                            mask = mask[:args.seq_length]
                        while len(tokens) < args.seq_length:
                            tokens.append(0)
                            mask.append(0)
                        dataset.append((tokens, label, mask, qid))
                    else:
                        pass
                except:
                    pass
        return dataset

    # Evaluation function.
    def evaluate(args, is_test, fast_mode=False):
        if is_test:
            dataset = read_dataset(args.test_path)
        else:
            dataset = read_dataset(args.dev_path)

        input_ids = torch.LongTensor([sample[0] for sample in dataset])
        label_ids = torch.LongTensor([sample[1] for sample in dataset])
        mask_ids = torch.LongTensor([sample[2] for sample in dataset])

        batch_size = 1
        instances_num = input_ids.size()[0]

        print("The number of evaluation instances: ", instances_num)
        print("Fast mode: ", fast_mode)

        correct = 0
        # Confusion matrix.
        confusion = torch.zeros(args.labels_num, args.labels_num, dtype=torch.long)

        model.eval()

        if not args.mean_reciprocal_rank:
            total_flops, model_params_num = 0, 0
            for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids)):
                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                with torch.no_grad():
                    # Get FLOPs for this batch.
                    inputs = (input_ids_batch, label_ids_batch, mask_ids_batch, fast_mode)
                    flops, params = profile(model, inputs, verbose=False)
                    total_flops += flops
                    model_params_num = params
                    # Inference.
                    loss, logits = model(input_ids_batch, label_ids_batch, mask_ids_batch, fast=fast_mode)
                logits = nn.Softmax(dim=1)(logits)
                pred = torch.argmax(logits, dim=1)
                gold = label_ids_batch
                for j in range(pred.size()[0]):
                    confusion[pred[j], gold[j]] += 1
                correct += torch.sum(pred == gold).item()

            print("Number of model parameters: {}".format(model_params_num))
            print("FLOPs per sample on average: {}".format(total_flops / float(instances_num)))

            if is_test:
                print("Confusion matrix:")
                print(confusion)
                print("Report precision, recall, and f1:")
            for i in range(confusion.size()[0]):
                # p = confusion[i, i].item() / confusion[i, :].sum().item()
                r = confusion[i, i].item() / confusion[:, i].sum().item()
                # f1 = 2 * p * r / (p + r)
                if is_test:
                    print("Label {}: {:.3f}".format(i, r))
                    # print("Label {}: {:.3f}, {:.3f}, {:.3f}".format(i, p, r, f1))
            print("Acc. (Correct/Total): {:.4f} ({}/{}) ".format(correct / len(dataset), correct, len(dataset)))
            return correct / len(dataset)
        else:
            for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids)):
                input_ids_batch = input_ids_batch.to(device)
                label_ids_batch = label_ids_batch.to(device)
                mask_ids_batch = mask_ids_batch.to(device)
                with torch.no_grad():
                    loss, logits = model(input_ids_batch, label_ids_batch, mask_ids_batch)
                logits = nn.Softmax(dim=1)(logits)
                if i == 0:
                    logits_all = logits
                if i >= 1:
                    logits_all = torch.cat((logits_all, logits), 0)

            # Collect the positions of positive labels for each question.
            order = -1
            gold = []
            for i in range(len(dataset)):
                qid = dataset[i][3]
                label = dataset[i][1]
                if qid == order:
                    j += 1
                    if label == 1:
                        gold.append((qid, j))
                else:
                    order = qid
                    j = 0
                    if label == 1:
                        gold.append((qid, j))

            label_order = []
            order = -1
            for i in range(len(gold)):
                if gold[i][0] == order:
                    templist.append(gold[i][1])
                elif gold[i][0] != order:
                    order = gold[i][0]
                    if i > 0:
                        label_order.append(templist)
                    templist = []
                    templist.append(gold[i][1])
            label_order.append(templist)

            order = -1
            score_list = []
            for i in range(len(logits_all)):
                score = float(logits_all[i][1])
                qid = int(dataset[i][3])
                if qid == order:
                    templist.append(score)
                else:
                    order = qid
                    if i > 0:
                        score_list.append(templist)
                    templist = []
                    templist.append(score)
            score_list.append(templist)

            rank = []
            pred = []
            for i in range(len(score_list)):
                if len(label_order[i]) == 1:
                    if label_order[i][0] < len(score_list[i]):
                        true_score = score_list[i][label_order[i][0]]
                        score_list[i].sort(reverse=True)
                        for j in range(len(score_list[i])):
                            if score_list[i][j] == true_score:
                                rank.append(1 / (j + 1))
                    else:
                        rank.append(0)
                else:
                    true_rank = len(score_list[i])
                    for k in range(len(label_order[i])):
                        if label_order[i][k] < len(score_list[i]):
                            true_score = score_list[i][label_order[i][k]]
                            temp = sorted(score_list[i], reverse=True)
                            for j in range(len(temp)):
                                if temp[j] == true_score:
                                    if j < true_rank:
                                        true_rank = j
                    if true_rank < len(score_list[i]):
                        rank.append(1 / (true_rank + 1))
                    else:
                        rank.append(0)
            MRR = sum(rank) / len(rank)
            print("Mean Reciprocal Rank: {:.4f}".format(MRR))
            return MRR

    # Training phase.
    print("Start training.")
    trainset = read_dataset(args.train_path)
    random.shuffle(trainset)
    instances_num = len(trainset)
    batch_size = args.batch_size

    input_ids = torch.LongTensor([example[0] for example in trainset])
    label_ids = torch.LongTensor([example[1] for example in trainset])
    mask_ids = torch.LongTensor([example[2] for example in trainset])

    train_steps = int(instances_num * args.epochs_num / batch_size) + 1

    print("Batch size: ", batch_size)
    print("The number of training instances:", instances_num)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=train_steps*args.warmup, t_total=train_steps)

    # Train the main part (backbone) of the model.
    print("Start fine-tuning the backbone of the model.")
    total_loss = 0.
    result = 0.0
    best_result = 0.0

    for epoch in range(1, args.epochs_num+1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids)):
            model.zero_grad()
            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            loss, _ = model(input_ids_batch, label_ids_batch, mask_ids_batch)  # training
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, backbone fine-tuning steps: {}, Avg loss: {:.3f}".format(epoch, i+1, total_loss / args.report_steps))
                total_loss = 0.
            loss.backward()
            optimizer.step()
            scheduler.step()

        result = evaluate(args, False, False)
        if result > best_result:
            best_result = result
            save_model(model, args.output_model_path)
        else:
            continue

    # Evaluation phase.
    if args.test_path is not None:
        print("Test set evaluation after backbone fine-tuning.")
        model = load_model(model, args.output_model_path)
        print("Test on normal model")
        evaluate(args, True, False)
        if args.fast_mode:
            print("Test on fast mode")
            evaluate(args, True, args.fast_mode)

    # Distill the subclassifiers.
    print("Start self-distillation for student-classifiers.")
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate*10, correct_bias=False)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=train_steps*args.warmup, t_total=train_steps)
    model = load_model(model, args.output_model_path)

    total_loss = 0.
    result = 0.0
    best_result = 0.0

    for epoch in range(1, args.distill_epochs_num+1):
        model.train()
        for i, (input_ids_batch, label_ids_batch, mask_ids_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, mask_ids)):
            model.zero_grad()
            input_ids_batch = input_ids_batch.to(device)
            label_ids_batch = label_ids_batch.to(device)
            mask_ids_batch = mask_ids_batch.to(device)
            loss, _ = model(input_ids_batch, None, mask_ids_batch)  # distillation
            if torch.cuda.device_count() > 1:
                loss = torch.mean(loss)
            total_loss += loss.item()
            if (i + 1) % args.report_steps == 0:
                print("Epoch id: {}, self-distillation steps: {}, Avg loss: {:.3f}".format(epoch, i+1, total_loss / args.report_steps))
                total_loss = 0.
            loss.backward()
            optimizer.step()
            scheduler.step()

        result = evaluate(args, False, args.fast_mode)
        save_model(model, args.output_model_path)

    # Evaluation phase.
    if args.test_path is not None:
        print("Test set evaluation after self-distillation.")
        model = load_model(model, args.output_model_path)
        evaluate(args, True, args.fast_mode)


if __name__ == "__main__":
    main()
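For reference, a typical invocation of the script above might look like the following; the filename run_fastbert.py is only an assumed name for this script, and the paths simply echo the argparse defaults defined above:

```
python run_fastbert.py \
    --pretrained_model_path ./models/Chinese_base_model.bin \
    --vocab_path ./models/google_zh_vocab.txt \
    --config_path ./models/bert_base_config.json \
    --train_path ./datasets/douban_book_review/train.tsv \
    --dev_path ./datasets/douban_book_review/dev.tsv \
    --test_path ./datasets/douban_book_review/test.tsv \
    --epochs_num 3 --distill_epochs_num 5 \
    --fast_mode --speed 0.5
```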