libcity.model.trajectory_loc_prediction.TemplateTLP¶
-
class
libcity.model.trajectory_loc_prediction.TemplateTLP.
TemplateTLP
(config, data_feature)[source]¶ Bases:
libcity.model.abstract_model.AbstractModel
请参考开源模型代码,完成本文件的编写。请务必重写 __init__, predict, calculate_loss 三个方法。
-
calculate_loss
(batch)[source]¶ - 参数说明:
batch (libcity.data.batch): 类 dict 文件,其中包含的键值参见任务说明文件。
- 返回值:
- loss (pytorch.tensor): 可以调用 pytorch 实现的 loss 函数与 batch[‘target’]
目标值进行 loss 计算,并将计算结果返回。如模型有自己独特的 loss 计算方式则自行参考实现。
-
predict
(batch)[source]¶ - 参数说明:
batch (libcity.data.batch): 类 dict 文件,其中包含的键值参见任务说明文件。
- 返回值:
- score (pytorch.tensor): 对应张量 shape 应为 batch_size *
loc_size。这里返回的是模型对于输入当前轨迹的下一跳位置的预测值。
-
training
: bool¶
-