# code2.training.py: add at line 422
# Tail of what is presumably logMetrics (TODO confirm against the full file).
# Returns the 'pr/recall' entry of metrics_dict (built earlier in the method,
# not shown here) as a single score. The training loop compares this score
# across validation epochs to decide which checkpoint is saved as the best one.
score = metrics_dict['pr/recall']
return score
# code2.training.py: add the saveModel method at line 457
def saveModel(self, type_str, epoch_ndx, isBest=False):
    """Write a training checkpoint to disk and log its SHA1 digest.

    The checkpoint is written to
    ``data-unversioned/part/models/<tb_prefix>/<type_str>_<time_str>_<comment>.<totalTrainingSamples_count>.state``.
    It holds the model and optimizer state dicts plus bookkeeping:
    the command line, wall-clock time, model/optimizer class names, the
    epoch and the training-sample count.

    Args:
        type_str: Filename prefix identifying the model kind (e.g. ``'cls'``).
        epoch_ndx: Epoch index stored in the checkpoint under ``'epoch'``.
        isBest: When true, also copy the checkpoint to
            ``<type_str>_<time_str>_<comment>.best.state`` in the same
            models directory.
    """
    # Built once so the regular and "best" checkpoints cannot drift apart.
    model_dir = os.path.join(
        'data-unversioned',
        'part',
        'models',
        self.cli_args.tb_prefix,
    )
    file_path = os.path.join(
        model_dir,
        '{}_{}_{}.{}.state'.format(
            type_str,
            self.time_str,
            self.cli_args.comment,
            self.totalTrainingSamples_count,
        ),
    )
    os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)

    # Save the bare module's weights so the checkpoint loads into an
    # unwrapped model; DataParallel and DistributedDataParallel both keep
    # the wrapped model in ``.module``.
    model = self.model
    if isinstance(model, (torch.nn.DataParallel,
                          torch.nn.parallel.DistributedDataParallel)):
        model = model.module

    state = {
        'sys_argv': sys.argv,
        'time': str(datetime.datetime.now()),
        'model_state': model.state_dict(),
        'model_name': type(model).__name__,
        'optimizer_state': self.optimizer.state_dict(),
        'optimizer_name': type(self.optimizer).__name__,
        'epoch': epoch_ndx,
        'totalTrainingSamples_count': self.totalTrainingSamples_count,
    }
    torch.save(state, file_path)
    log.info("Saved model params to {}".format(file_path))

    if isBest:
        best_path = os.path.join(
            model_dir,
            f'{type_str}_{self.time_str}_{self.cli_args.comment}.best.state',
        )
        shutil.copyfile(file_path, best_path)
        log.info("Saved model params to {}".format(best_path))

    # Hash in 1 MiB chunks so a large checkpoint is never read into memory
    # all at once.
    sha1 = hashlib.sha1()
    with open(file_path, 'rb') as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b''):
            sha1.update(chunk)
    log.info("SHA1: " + sha1.hexdigest())
# code2.training.py: add at line 212
# Training-loop fragment: run validation on epoch 1 and on every 5th epoch
# thereafter, then checkpoint the model.
if epoch_ndx == 1 or epoch_ndx % 5 == 0:
# if validation is wanted
valMetrics_t = self.doValidation(epoch_ndx, val_dl)
# logMetrics returns a scalar score used for model selection.
score = self.logMetrics(epoch_ndx, 'val', valMetrics_t)
# NOTE(review): best_score must be initialized before the loop (not visible
# here) — presumably to 0.0; confirm.
best_score = max(score, best_score)
# After the max() above, equality means this epoch matched or beat every
# earlier score, so the checkpoint is also saved as the "best" copy.
self.saveModel('cls', epoch_ndx, score == best_score)