libcity.model.traffic_speed_prediction.TemplateTSP¶
-
class
libcity.model.traffic_speed_prediction.TemplateTSP.
TemplateTSP
(config, data_feature)[source]¶ Bases:
libcity.model.abstract_traffic_state_model.AbstractTrafficStateModel
-
calculate_loss
(batch)[source]¶ 输入一个batch的数据,返回训练过程这个batch数据的loss,也就是需要定义一个loss函数。 :param batch: 输入数据,类字典,可以按字典的方法取数据 :return: training loss (tensor)
-
forward
(batch)[source]¶ 调用模型计算这个batch输入对应的输出,nn.Module必须实现的接口 :param batch: 输入数据,类字典,可以按字典的方法取数据 :return:
-
predict
(batch)[source]¶ 输入一个batch的数据,返回对应的预测值,一般应该是**多步预测**的结果 一般会调用上边定义的forward()方法 :param batch: 输入数据,类字典,可以按字典的方法取数据 :return: predict result of this batch (tensor)
-
training
: bool¶
-