同学你好,input_fn 应该是一个不带参数、调用后返回 dataset 的函数(例如下面的 lambda),而不是直接传入 dataset 本身:
# Ask the estimator for predictions on the evaluation data. `input_fn` is a
# zero-argument callable; it builds its input via make_dataset(), passing
# epochs=1 and shuffle=False.
predicted_value = linear_estimator.predict(
    input_fn=lambda: make_dataset(eval_df, y_eval, epochs=1, shuffle=False)
)

# Print only the leading predictions: the loop stops right after the 11th
# item has been printed (position 11 is the first one greater than 10).
for position, prediction in enumerate(predicted_value, start=1):
    print(prediction)
    if position > 10:
        break
打印出的结果如下(前 11 条预测,每条是一个字典):
{'logits': array([0.68644077], dtype=float32), 'logistic': array([0.66517466], dtype=float32), 'probabilities': array([0.3348253 , 0.66517466], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([-0.33140332], dtype=float32), 'logistic': array([0.41789922], dtype=float32), 'probabilities': array([0.5821008 , 0.41789922], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.76469857], dtype=float32), 'logistic': array([0.682373], dtype=float32), 'probabilities': array([0.31762704, 0.682373 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([-1.6867031], dtype=float32), 'logistic': array([0.1562099], dtype=float32), 'probabilities': array([0.84379005, 0.1562099 ], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([1.8905865], dtype=float32), 'logistic': array([0.8688224], dtype=float32), 'probabilities': array([0.1311776, 0.8688224], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([2.1999667], dtype=float32), 'logistic': array([0.90024656], dtype=float32), 'probabilities': array([0.09975349, 0.90024656], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.98448634], dtype=float32), 'logistic': array([0.7279975], dtype=float32), 'probabilities': array([0.2720025, 0.7279975], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([-2.7839499], dtype=float32), 'logistic': array([0.05819768], dtype=float32), 'probabilities': array([0.9418023 , 0.05819768], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([1.3115588], dtype=float32), 'logistic': array([0.7877739], dtype=float32), 'probabilities': array([0.21222611, 0.7877739 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.8499193], dtype=float32), 'logistic': array([0.7005502], dtype=float32), 'probabilities': array([0.29944977, 0.7005502 ], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
{'logits': array([0.8553443], dtype=float32), 'logistic': array([0.70168704], dtype=float32), 'probabilities': array([0.298313 , 0.70168704], dtype=float32), 'class_ids': array([1]), 'classes': array([b'1'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}