Source code for libcity.model.abstract_traffic_state_model

from libcity.model.abstract_model import AbstractModel


[docs]class AbstractTrafficStateModel(AbstractModel): def __init__(self, config, data_feature): self.data_feature = data_feature super().__init__(config, data_feature)
[docs] def predict(self, batch): """ 输入一个batch的数据,返回对应的预测值,一般应该是**多步预测**的结果,一般会调用nn.Moudle的forward()方法 Args: batch (Batch): a batch of input Returns: torch.tensor: predict result of this batch """
[docs] def calculate_loss(self, batch): """ 输入一个batch的数据,返回训练过程的loss,也就是需要定义一个loss函数 Args: batch (Batch): a batch of input Returns: torch.tensor: return training loss """