# Source code for libcity.model.abstract_model

import torch.nn as nn


class AbstractModel(nn.Module):
    """Base class that model implementations derive from.

    Subclasses are expected to override :meth:`predict` and
    :meth:`calculate_loss`. The base implementations raise
    ``NotImplementedError`` so that a missing override fails loudly at call
    time instead of silently returning ``None``.
    """

    def __init__(self, config, data_feature):
        """Initialise the underlying ``nn.Module``.

        Args:
            config: model configuration. Not read by this base class; accepted
                so every subclass shares the same constructor signature.
            data_feature: dataset feature description. Also not read by this
                base class (presumably consumed by subclasses -- confirm).
        """
        # Call nn.Module directly (not super()) to preserve the original
        # initialisation order for subclasses that use multiple inheritance.
        nn.Module.__init__(self)

    def predict(self, batch):
        """Produce predictions for one batch.

        Args:
            batch (Batch): a batch of input

        Returns:
            torch.tensor: predict result of this batch

        Raises:
            NotImplementedError: if a subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement predict()"
        )

    def calculate_loss(self, batch):
        """Compute the training loss for one batch.

        Args:
            batch (Batch): a batch of input

        Returns:
            torch.tensor: return training loss

        Raises:
            NotImplementedError: if a subclass does not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement calculate_loss()"
        )