libcity.evaluator.traj_loc_pred_evaluator¶
-
class
libcity.evaluator.traj_loc_pred_evaluator.
TrajLocPredEvaluator
(config)[source]¶ Bases:
libcity.evaluator.abstract_evaluator.AbstractEvaluator
-
collect
(batch)[source]¶ - Parameters
batch (dict) – contains three keys: uid, loc_true, and loc_pred.
uid (list) – 来自于 batch 中的 uid,通过索引可以确定 loc_true 与 loc_pred 中每一行(元素)是哪个用户的一次输入。
loc_true (list) – 期望地点(target),来自于 batch 中的 target。 对于负样本评估,loc_pred 中第一个点是 target 的置信度,后面的都是负样本的
loc_pred (matrix) – 实际上模型的输出,batch_size * output_dim.
-