# Source code for libcity.model.abstract_model
import torch.nn as nn
class AbstractModel(nn.Module):
    """Base ``nn.Module`` that declares the ``predict`` and ``calculate_loss`` hooks.

    Both hooks have docstring-only bodies in this class, so calling them here
    returns ``None``. Their docstrings describe the tensors that an overriding
    implementation is meant to return.
    """

    def __init__(self, config, data_feature):
        """Initialise the underlying ``nn.Module`` state.

        Args:
            config: part of the constructor signature; not read by this class.
            data_feature: part of the constructor signature; not read by this
                class.
        """
        # Explicit base-class call (not super()) is kept exactly as written so
        # the initialisation order seen by existing subclasses is unchanged.
        nn.Module.__init__(self)

    def predict(self, batch):
        """Produce the prediction for one batch (no implementation here).

        Args:
            batch (Batch): a batch of input

        Returns:
            torch.tensor: predict result of this batch
        """

    def calculate_loss(self, batch):
        """Compute the training loss for one batch (no implementation here).

        Args:
            batch (Batch): a batch of input

        Returns:
            torch.tensor: return training loss
        """