from logging import getLogger
import torch
import numpy as np
from libcity.model import loss
from libcity.model.abstract_traffic_state_model import AbstractTrafficStateModel
from torch.nn import functional as F
import torch.nn as nn
def count_parameters(model):
    """Count the scalar elements held by the trainable parameters of ``model``.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        int: sum of ``numel()`` over every parameter whose ``requires_grad`` is true.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def cosine_similarity_torch(x1, x2=None, eps=1e-8):
    """Pairwise cosine similarity between the rows of two 2-D tensors.

    Args:
        x1: tensor whose rows are the first set of vectors.
        x2: tensor whose rows are the second set of vectors; defaults to ``x1``.
        eps: lower bound applied to the product of norms to avoid division by zero.

    Returns:
        Tensor of shape (rows of x1, rows of x2) holding the similarities.
    """
    if x2 is None:
        x2 = x1
    norm1 = x1.norm(p=2, dim=1, keepdim=True)
    if x2 is x1:
        # Same tensor on both sides: reuse the norms instead of recomputing them.
        norm2 = norm1
    else:
        norm2 = x2.norm(p=2, dim=1, keepdim=True)
    dot_products = torch.mm(x1, x2.t())
    norm_products = (norm1 * norm2.t()).clamp(min=eps)
    return dot_products / norm_products
def sample_gumbel(device, shape, eps=1e-20):
    """Draw noise of the given shape as ``-log(-log(U + eps) + eps)`` with ``U ~ Uniform[0, 1)``.

    Args:
        device: device the returned tensor is moved to.
        shape: shape of the noise tensor.
        eps: small constant added inside both logarithms for numerical safety.

    Returns:
        Tensor of ``shape`` on ``device``; it does not require gradients.
    """
    # Drawn with the default generator and then moved, so seeded runs keep the
    # same random stream as before.
    uniform = torch.rand(shape).to(device)
    # The former ``torch.autograd.Variable`` wrapper is deprecated and was a
    # no-op on tensors, so the plain tensor is returned directly.
    return -torch.log(-torch.log(uniform + eps) + eps)
def gumbel_softmax_sample(device, logits, temperature, eps=1e-10):
    """Perturb ``logits`` with noise from ``sample_gumbel`` and apply a tempered softmax.

    Args:
        device: device on which the noise is placed.
        logits: unnormalised scores; the softmax runs over the last dimension.
        temperature: divisor applied to the perturbed logits before the softmax.
        eps: numerical-safety constant forwarded to ``sample_gumbel``.

    Returns:
        Tensor with the same shape as ``logits`` whose last dimension sums to 1.
    """
    noise = sample_gumbel(device, logits.size(), eps=eps)
    perturbed = logits + noise
    return F.softmax(perturbed / temperature, dim=-1)
def gumbel_softmax(device, logits, temperature, hard=False, eps=1e-10):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
        device: device on which the Gumbel noise is placed
        logits: [batch_size, n_class] unnormalized log-probs
        temperature: non-negative scalar
        hard: if True, take argmax, but differentiate w.r.t. soft sample y
        eps: numerical-safety constant forwarded to the noise sampler

    Returns:
        [batch_size, n_class] sample from the Gumbel-Softmax distribution.
        If hard=True, then the returned sample will be one-hot, otherwise it will
        be a probability distribution that sums to 1 across classes
    """
    y_soft = gumbel_softmax_sample(device, logits, temperature=temperature, eps=eps)
    if not hard:
        return y_soft

    # Index of the largest probability along the class dimension; detached so
    # the selection itself is not tracked by autograd.
    _, top_index = y_soft.detach().max(-1, keepdim=True)
    # One-hot tensor allocated directly with y_soft's dtype and device.
    y_hard = torch.zeros_like(y_soft).scatter_(-1, top_index, 1.0)
    # Straight-through estimator: the forward value is the one-hot sample while
    # gradients flow through y_soft only (the difference is a detached constant).
    return (y_hard - y_soft).detach() + y_soft
class GCONV(nn.Module):
    """Graph convolution over the concatenation of an input and a hidden state.

    The concatenated features are propagated through ``adj_mx`` with the
    recurrence ``T_k = 2 * A @ T_{k-1} - T_{k-2}`` for ``max_diffusion_step``
    steps; all terms are then mixed by one weight matrix and a bias is added.
    """

    def __init__(self, num_nodes, max_diffusion_step, device, input_dim, hid_dim, output_dim, adj_mx,
                 bias_start=0.0):
        """
        Args:
            num_nodes: number of graph nodes.
            max_diffusion_step: highest order of the recurrence (0 disables propagation).
            device: device on which the parameters are created.
            input_dim: per-node feature size of the input.
            hid_dim: per-node feature size of the state.
            output_dim: per-node feature size of the result.
            adj_mx: (num_nodes, num_nodes) matrix handed to ``torch.sparse.mm``.
            bias_start: constant used to initialise the bias.
        """
        super().__init__()
        self._num_nodes = num_nodes
        self._max_diffusion_step = max_diffusion_step
        self._device = device
        # One term per diffusion order, plus the identity term T0.
        self._num_matrices = max_diffusion_step + 1
        self._output_dim = output_dim
        weight_rows = (input_dim + hid_dim) * self._num_matrices
        self.weight = torch.nn.Parameter(torch.empty(weight_rows, output_dim, device=device))
        self.biases = torch.nn.Parameter(torch.empty(output_dim, device=device))
        self.adj_mx = adj_mx
        torch.nn.init.xavier_normal_(self.weight)
        torch.nn.init.constant_(self.biases, bias_start)

    @staticmethod
    def _concat(x, x_):
        """Append ``x_`` as a new slice along dimension 0 of ``x``."""
        return torch.cat([x, x_.unsqueeze(0)], dim=0)

    def forward(self, inputs, state):
        """Apply the graph convolution to X(t) and H(t-1) and add the bias.

        Args:
            inputs: tensor reshaped to (batch_size, num_nodes, input_dim).
            state: tensor reshaped to (batch_size, num_nodes, state_dim).

        Returns:
            Tensor of shape (batch_size, num_nodes * output_dim).
        """
        batch_size = inputs.shape[0]
        node_shape = (batch_size, self._num_nodes, -1)
        # (batch_size, num_nodes, input_dim + state_dim)
        features = torch.cat([inputs.reshape(node_shape), state.reshape(node_shape)], dim=2)
        input_size = features.size(2)

        # T0 = I, so the first term is the features themselves, laid out as
        # (num_nodes, input_size * batch_size) for the matrix products below.
        term_prev = features.permute(1, 2, 0).reshape(self._num_nodes, input_size * batch_size)
        terms = term_prev.unsqueeze(0)  # (1, num_nodes, input_size * batch_size)

        if self._max_diffusion_step != 0:
            # T1 = A @ T0
            term_curr = torch.sparse.mm(self.adj_mx, term_prev)
            terms = self._concat(terms, term_curr)
            for _ in range(2, self._max_diffusion_step + 1):
                # T_k = 2 * A @ T_{k-1} - T_{k-2}
                term_next = 2 * torch.sparse.mm(self.adj_mx, term_curr) - term_prev
                terms = self._concat(terms, term_next)
                term_prev, term_curr = term_curr, term_next

        # (num_matrices, num_nodes, input_size, batch_size)
        terms = terms.reshape(self._num_matrices, self._num_nodes, input_size, batch_size)
        # -> (batch_size, num_nodes, input_size, num_matrices)
        terms = terms.permute(3, 1, 2, 0)
        flat = terms.reshape(batch_size * self._num_nodes, input_size * self._num_matrices)
        out = torch.matmul(flat, self.weight)  # (batch_size * num_nodes, output_dim)
        out += self.biases
        # (batch_size * num_nodes, output_dim) -> (batch_size, num_nodes * output_dim)
        return out.reshape(batch_size, self._num_nodes * self._output_dim)
class FC(nn.Module):
    """Node-wise fully connected layer over the concatenation of an input and a state."""

    def __init__(self, num_nodes, device, input_dim, hid_dim, output_dim, bias_start=0.0):
        """
        Args:
            num_nodes: number of graph nodes.
            device: device on which the parameters are created.
            input_dim: per-node feature size of the input.
            hid_dim: per-node feature size of the state.
            output_dim: per-node feature size of the result.
            bias_start: constant used to initialise the bias.
        """
        super().__init__()
        self._num_nodes = num_nodes
        self._device = device
        self._output_dim = output_dim
        input_size = input_dim + hid_dim
        shape = (input_size, self._output_dim)
        self.weight = torch.nn.Parameter(torch.empty(*shape, device=self._device))
        self.biases = torch.nn.Parameter(torch.empty(self._output_dim, device=self._device))
        torch.nn.init.xavier_normal_(self.weight)
        torch.nn.init.constant_(self.biases, bias_start)

    def forward(self, inputs, state):
        """Compute ``sigmoid([inputs, state] @ weight) + biases`` for every node.

        Args:
            inputs: tensor reshaped to (batch_size * num_nodes, input_dim).
            state: tensor reshaped to (batch_size * num_nodes, state_dim).

        Returns:
            Tensor of shape (batch_size, num_nodes * output_dim).
        """
        batch_size = inputs.shape[0]
        # Reshape input and state to (batch_size * num_nodes, input_dim/state_dim)
        inputs = torch.reshape(inputs, (batch_size * self._num_nodes, -1))
        state = torch.reshape(state, (batch_size * self._num_nodes, -1))
        # (batch_size * num_nodes, input_dim + state_dim)
        inputs_and_state = torch.cat([inputs, state], dim=-1)
        # NOTE: the bias is added after the sigmoid; that ordering is kept as-is
        # so numerical results are unchanged.
        value = torch.sigmoid(torch.matmul(inputs_and_state, self.weight))
        # Out-of-place add: sigmoid keeps its output for the backward pass, so
        # modifying it in place makes backward() raise a RuntimeError.
        value = value + self.biases
        # (batch_size * num_nodes, output_dim) -> (batch_size, num_nodes * output_dim)
        return torch.reshape(value, [batch_size, self._num_nodes * self._output_dim])
class DCGRUCell(nn.Module):
    """GRU cell whose gate and candidate transforms are graph convolutions."""

    def __init__(self, input_dim, num_units, adj_mx, max_diffusion_step, num_nodes, device, nonlinearity='tanh',
                 filter_type="laplacian", use_gc_for_ru=True):
        """
        Args:
            input_dim: per-node feature size of the cell input.
            num_units: per-node size of the hidden state.
            adj_mx: adjacency matrix; it is turned into a transposed random-walk matrix here.
            max_diffusion_step: diffusion order passed to the graph convolutions.
            num_nodes: number of graph nodes.
            device: device on which parameters and helper tensors are created.
            nonlinearity: 'tanh' selects ``torch.tanh``; any other value selects ``torch.relu``.
            filter_type: "laplacian", "random_walk", "dual_random_walk" (accepted but not read here).
            use_gc_for_ru: whether to use graph convolution to calculate the reset and update gates.
        """
        super().__init__()
        if nonlinearity == 'tanh':
            self._activation = torch.tanh
        else:
            self._activation = torch.relu
        self._num_nodes = num_nodes
        self._num_units = num_units
        self._device = device
        self._adj_mx = self._calculate_random_walk_matrix(adj_mx).t()
        self._max_diffusion_step = max_diffusion_step
        self._use_gc_for_ru = use_gc_for_ru
        self.device = device

        # Reset and update gates are produced together, hence 2 * num_units outputs.
        gate_dim = 2 * self._num_units
        if self._use_gc_for_ru:
            self._fn = GCONV(self._num_nodes, self._max_diffusion_step, self._device,
                             input_dim=input_dim, hid_dim=self._num_units, output_dim=gate_dim,
                             adj_mx=self._adj_mx, bias_start=1.0)
        else:
            self._fn = FC(self._num_nodes, self._device, input_dim=input_dim,
                          hid_dim=self._num_units, output_dim=gate_dim, bias_start=1.0)
        # Candidate state always comes from a graph convolution.
        self._gconv = GCONV(self._num_nodes, self._max_diffusion_step, self._device,
                            input_dim=input_dim, hid_dim=self._num_units, output_dim=self._num_units,
                            adj_mx=self._adj_mx, bias_start=0.0)

    @staticmethod
    def _build_sparse_matrix(lap, device):
        """Convert an object exposing ``tocoo()`` into a torch sparse COO tensor on ``device``."""
        coo = lap.tocoo()
        pairs = np.column_stack((coo.row, coo.col))
        # Reorder the (row, col) index pairs; np.lexsort uses its last key
        # (the column index) as the primary sort key.
        order = np.lexsort((pairs[:, 0], pairs[:, 1]))
        pairs = pairs[order]
        return torch.sparse_coo_tensor(pairs.T, coo.data, coo.shape, device=device)

    def _calculate_random_walk_matrix(self, adj_mx):
        """Return ``D^-1 (A + I)`` where ``D`` holds the row sums of ``A + I``."""
        size = int(adj_mx.shape[0])
        with_self_loops = adj_mx + torch.eye(size).to(self._device)
        degree = torch.sum(with_self_loops, 1)
        inv_degree = 1. / degree
        # Rows with zero degree would give an infinite inverse; use 0 instead.
        zeros = torch.zeros(inv_degree.shape).to(self._device)
        inv_degree = torch.where(torch.isinf(inv_degree), zeros, inv_degree)
        return torch.mm(torch.diag(inv_degree), with_self_loops)

    def forward(self, inputs, hx):
        """
        Gated recurrent unit (GRU) with Graph Convolution.

        Args:
            inputs: (B, num_nodes * input_dim)
            hx: (B, num_nodes * rnn_units)

        Returns:
            torch.tensor: shape (B, num_nodes * rnn_units)
        """
        gate_dim = 2 * self._num_units
        gates = torch.sigmoid(self._fn(inputs, hx))  # (batch_size, num_nodes * gate_dim)
        gates = gates.reshape(-1, self._num_nodes, gate_dim)
        reset, update = torch.split(tensor=gates, split_size_or_sections=self._num_units, dim=-1)
        flat_shape = (-1, self._num_nodes * self._num_units)
        reset = reset.reshape(flat_shape)    # (batch_size, num_nodes * num_units)
        update = update.reshape(flat_shape)  # (batch_size, num_nodes * num_units)

        candidate = self._gconv(inputs, reset * hx)  # (batch_size, num_nodes * num_units)
        if self._activation is not None:
            candidate = self._activation(candidate)

        # (batch_size, num_nodes * num_units)
        return update * hx + (1.0 - update) * candidate
class Seq2SeqAttrs:
    """Mixin that copies the shared sequence-to-sequence hyper-parameters onto ``self``."""

    def __init__(self, config, data_feature):
        """
        Args:
            config: mapping with model settings; 'rnn_units' has no default here.
            data_feature: mapping with dataset properties; 'feature_dim' has no default here.
        """
        max_diffusion_step = config.get('max_diffusion_step', 2)
        self.max_diffusion_step = int(max_diffusion_step)
        cl_decay_steps = config.get('cl_decay_steps', 1000)
        self.cl_decay_steps = int(cl_decay_steps)
        self.filter_type = config.get('filter_type', 'laplacian')
        num_nodes = data_feature.get('num_nodes', 1)
        self.num_nodes = int(num_nodes)
        num_rnn_layers = config.get('num_rnn_layers', 1)
        self.num_rnn_layers = int(num_rnn_layers)
        rnn_units = config.get('rnn_units')
        self.rnn_units = int(rnn_units)
        # Flattened per-sample state size: one rnn_units-sized vector per node.
        self.hidden_state_size = self.num_nodes * self.rnn_units
        feature_dim = data_feature.get('feature_dim')
        self.input_dim = int(feature_dim)
        self.device = config.get('device', torch.device('cpu'))
class EncoderModel(nn.Module, Seq2SeqAttrs):
    """Stack of DCGRU cells that consumes one time step of the input sequence per call."""

    def __init__(self, config, data_feature, adj_mx, device):
        """
        Args:
            config: model configuration mapping.
            data_feature: dataset feature mapping.
            adj_mx: adjacency matrix handed to every DCGRU cell.
            device: device used for the cells and the default hidden state.
        """
        nn.Module.__init__(self)
        Seq2SeqAttrs.__init__(self, config, data_feature)
        self.device = device
        self.seq_len = int(config.get('input_window', 1))  # for the encoder

        # The first layer reads the raw features; every further layer reads the
        # hidden state of the layer below it.
        layer_input_dims = [self.input_dim] + [self.rnn_units] * (self.num_rnn_layers - 1)
        self.dcgru_layers = nn.ModuleList()
        for layer_input_dim in layer_input_dims:
            cell = DCGRUCell(layer_input_dim, self.rnn_units, adj_mx, self.max_diffusion_step,
                             self.num_nodes, self.device, filter_type=self.filter_type)
            self.dcgru_layers.append(cell)

    def forward(self, inputs, hidden_state=None):
        """
        Encoder forward pass.

        Args:
            inputs: shape (batch_size, self.num_nodes * self.input_dim)
            hidden_state: (num_layers, batch_size, self.hidden_state_size),
                optional, zeros if not provided, hidden_state_size = num_nodes * rnn_units

        Returns:
            tuple: tuple contains:
                output: shape (batch_size, self.hidden_state_size) \n
                hidden_state: shape (num_layers, batch_size, self.hidden_state_size) \n
                (lower indices mean lower layers)
        """
        batch_size, _ = inputs.size()
        if hidden_state is None:
            state_shape = (self.num_rnn_layers, batch_size, self.hidden_state_size)
            hidden_state = torch.zeros(state_shape, device=self.device)

        layer_states = []
        layer_input = inputs
        for index, cell in enumerate(self.dcgru_layers):
            # Each layer's new state, (batch_size, num_nodes * rnn_units), feeds the next layer.
            layer_input = cell(layer_input, hidden_state[index])
            layer_states.append(layer_input)
        return layer_input, torch.stack(layer_states)
class DecoderModel(nn.Module, Seq2SeqAttrs):
    """Stack of DCGRU cells followed by a per-node linear projection to ``output_dim``."""

    def __init__(self, config, data_feature, adj_mx, device):
        """
        Args:
            config: model configuration mapping ('output_dim' and 'output_window' are read here).
            data_feature: dataset feature mapping.
            adj_mx: adjacency matrix handed to every DCGRU cell.
            device: device used for the cells and the default hidden state.
        """
        nn.Module.__init__(self)
        Seq2SeqAttrs.__init__(self, config, data_feature)
        self.device = device
        self.output_dim = config.get('output_dim', 1)
        self.horizon = int(config.get('output_window', 1))
        self.projection_layer = nn.Linear(self.rnn_units, self.output_dim)
        self.dcgru_layers = nn.ModuleList()
        # The first layer reads the previous prediction (output_dim per node);
        # deeper layers read the hidden state of the layer below.
        self.dcgru_layers.append(DCGRUCell(self.output_dim, self.rnn_units, adj_mx, self.max_diffusion_step,
                                           self.num_nodes, self.device, filter_type=self.filter_type))
        for _ in range(1, self.num_rnn_layers):
            self.dcgru_layers.append(DCGRUCell(self.rnn_units, self.rnn_units, adj_mx, self.max_diffusion_step,
                                               self.num_nodes, self.device, filter_type=self.filter_type))

    def forward(self, inputs, hidden_state=None):
        """
        Decoder forward pass.

        Args:
            inputs: shape (batch_size, self.num_nodes * self.output_dim)
            hidden_state: (num_layers, batch_size, self.hidden_state_size),
                optional, zeros if not provided, hidden_state_size = num_nodes * rnn_units

        Returns:
            tuple: tuple contains:
                output: shape (batch_size, self.num_nodes * self.output_dim) \n
                hidden_state: shape (num_layers, batch_size, self.hidden_state_size) \n
                (lower indices mean lower layers)
        """
        if hidden_state is None:
            # Honour the documented default: previously a missing state was
            # indexed directly and raised a TypeError.
            hidden_state = torch.zeros(
                (self.num_rnn_layers, inputs.size(0), self.hidden_state_size), device=self.device)

        hidden_states = []
        output = inputs
        for layer_num, dcgru_layer in enumerate(self.dcgru_layers):
            next_hidden_state = dcgru_layer(output, hidden_state[layer_num])
            # next_hidden_state: (batch_size, self.num_nodes * self.rnn_units)
            hidden_states.append(next_hidden_state)
            output = next_hidden_state

        # Project every node's rnn_units-sized state to output_dim values.
        projected = self.projection_layer(output.view(-1, self.rnn_units))
        output = projected.view(-1, self.num_nodes * self.output_dim)
        return output, torch.stack(hidden_states)
class GTS(AbstractTrafficStateModel, Seq2SeqAttrs):
    """Sequence-to-sequence forecaster with an auxiliary graph-structure inference branch.

    A 1-D CNN turns each node's training series into an embedding, pairs of
    embeddings are scored by two linear layers, and the resulting edge logits
    are regularised against the prior adjacency matrix in ``calculate_loss``.
    """

    def __init__(self, config, data_feature):
        """
        Build the model.

        :param config: configuration dictionary merged from the various config sources
        :param data_feature: data-related features returned by the Dataset class's
            `get_data_feature()` interface
        """
        super().__init__(config, data_feature)
        self.config = config
        self.device = config.get('device', torch.device('cpu'))
        # Prior adjacency matrix: handed to the encoder/decoder cells and used as the
        # reference target when training the graph-structure inference branch.
        self.adj_mx = torch.Tensor(data_feature.get('adj_mx')).to(self.device)
        Seq2SeqAttrs.__init__(self, self.config, data_feature)
        self.seq_len = int(config.get('input_window', 1))  # for the encoder
        self.horizon = int(config.get('output_window', 1))  # for the decoder
        self.encoder_model = EncoderModel(self.config, data_feature, self.adj_mx, self.device)
        self.decoder_model = DecoderModel(self.config, data_feature, self.adj_mx, self.device)
        self._logger = getLogger()

        self.cl_decay_steps = config.get('cl_decay_steps', 1000)
        self.use_curriculum_learning = config.get('use_curriculum_learning', True)
        self.temperature = config.get('temperature', 0.5)
        self.epoch_use_regularization = config.get('epoch_use_regularization', 50)
        self.num_nodes = self.data_feature.get('num_nodes', 1)
        self.num_batches = self.data_feature.get('num_batches', 1)
        self._scaler = self.data_feature.get('scaler')
        self.feature_dim = self.data_feature.get('feature_dim', 1)
        self.output_dim = self.data_feature.get('output_dim', 1)
        self.ext_dim = self.data_feature.get('ext_dim', 1)

        train_feas = self.data_feature.get('train_data')  # (num_samples, num_nodes)
        self.node_feas = torch.Tensor(train_feas).to(self.device)

        self.kernal_size = config.get('kernal_size', 10)
        # Two stride-1 convolutions each shorten the series by kernal_size - 1;
        # the second one emits 16 channels, which are flattened before `fc`.
        self.dim_fc = (self.node_feas.shape[0] - 2 * self.kernal_size + 2) * 16
        self.embedding_dim = config.get('embedding_dim', 100)
        self.conv1 = torch.nn.Conv1d(1, 8, self.kernal_size, stride=1)
        self.conv2 = torch.nn.Conv1d(8, 16, self.kernal_size, stride=1)
        self.hidden_drop = torch.nn.Dropout(0.2)
        self.fc = torch.nn.Linear(self.dim_fc, self.embedding_dim)
        self.bn1 = torch.nn.BatchNorm1d(8)
        self.bn2 = torch.nn.BatchNorm1d(16)
        self.bn3 = torch.nn.BatchNorm1d(self.embedding_dim)
        self.fc_out = nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        self.fc_cat = nn.Linear(self.embedding_dim, 2)

        def encode_onehot(labels):
            classes = set(labels)
            # Build the identity once instead of once per class.
            identity = np.identity(len(classes))
            classes_dict = {c: identity[i, :] for i, c in enumerate(classes)}
            labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32)
            return labels_onehot

        # One-hot selectors over every ordered node pair: multiplying by rel_rec /
        # rel_send picks the first / second node's embedding of each pair.
        off_diag = np.ones([self.num_nodes, self.num_nodes])
        rel_rec = np.array(encode_onehot(np.where(off_diag)[0]), dtype=np.float32)
        rel_send = np.array(encode_onehot(np.where(off_diag)[1]), dtype=np.float32)
        self.rel_rec = torch.FloatTensor(rel_rec).to(self.device)
        self.rel_send = torch.FloatTensor(rel_send).to(self.device)

        self.input_dim = self.feature_dim

    def _compute_sampling_threshold(self, batches_seen):
        """Return ``k / (k + exp(batches_seen / k))`` with ``k = cl_decay_steps``."""
        return self.cl_decay_steps / (
            self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))

    def encoder(self, inputs):
        """
        Encoder forward pass

        :param inputs: shape (seq_len, batch_size, num_sensor * input_dim)
        :return: encoder_hidden_state: (num_layers, batch_size, self.hidden_state_size)
        """
        encoder_hidden_state = None
        for t in range(self.encoder_model.seq_len):
            _, encoder_hidden_state = self.encoder_model(inputs[t], encoder_hidden_state)
        return encoder_hidden_state

    def decoder(self, encoder_hidden_state, labels=None, batches_seen=None):
        """
        Decoder forward pass

        :param encoder_hidden_state: (num_layers, batch_size, self.hidden_state_size)
        :param labels: (self.horizon, batch_size, self.num_nodes * self.output_dim) [optional, not exist for inference]
        :param batches_seen: global step [optional, not exist for inference]
        :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
        """
        batch_size = encoder_hidden_state.size(1)
        go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim), device=self.device)
        decoder_hidden_state = encoder_hidden_state
        decoder_input = go_symbol

        # Scheduled sampling needs both the ground truth and the global step.
        # When either is missing the decoder feeds back its own predictions
        # instead of failing on `labels[t]` or on arithmetic with None.
        teacher_forcing = (self.training and self.use_curriculum_learning
                           and labels is not None and batches_seen is not None)

        outputs = []
        for t in range(self.decoder_model.horizon):
            decoder_output, decoder_hidden_state = \
                self.decoder_model(decoder_input, decoder_hidden_state)
            decoder_input = decoder_output
            outputs.append(decoder_output)
            if teacher_forcing:
                c = np.random.uniform(0, 1)
                if c < self._compute_sampling_threshold(batches_seen):
                    decoder_input = labels[t]
        outputs = torch.stack(outputs)
        return outputs

    def _prepare_data_x(self, x):
        """Reshape inputs to (seq_len, batch_size, num_nodes * input_dim) on the model device.

        :param x: shape (batch_size, seq_len, num_sensor, input_dim)
        """
        x = x.float()
        x = x.permute(1, 0, 2, 3)
        batch_size = x.size(1)
        x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
        # Move to the model device, matching what _prepare_data does for (x, y).
        return x.to(self.device)

    def _prepare_data(self, x, y):
        """Convert a raw (x, y) pair to time-major, flattened tensors on the model device."""
        x, y = self._get_x_y(x, y)
        x, y = self._get_x_y_in_correct_dims(x, y)
        return x.to(self.device), y.to(self.device)

    def _get_x_y(self, x, y):
        """
        :param x: shape (batch_size, seq_len, num_sensor, input_dim)
        :param y: shape (batch_size, horizon, num_sensor, input_dim)
        :returns x shape (seq_len, batch_size, num_sensor, input_dim)
                 y shape (horizon, batch_size, num_sensor, input_dim)
        """
        x = x.float()
        y = y.float()
        self._logger.debug("X: {}".format(x.size()))
        self._logger.debug("y: {}".format(y.size()))
        x = x.permute(1, 0, 2, 3)
        y = y.permute(1, 0, 2, 3)
        return x, y

    def _get_x_y_in_correct_dims(self, x, y):
        """
        :param x: shape (seq_len, batch_size, num_sensor, input_dim)
        :param y: shape (horizon, batch_size, num_sensor, input_dim)
        :return: x: shape (seq_len, batch_size, num_sensor * input_dim)
                 y: shape (horizon, batch_size, num_sensor * output_dim)
        """
        batch_size = x.size(1)
        x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
        y = y[..., :self.output_dim].contiguous().view(
            self.horizon, batch_size, self.num_nodes * self.output_dim)
        return x, y

    def forward(self, batch, batches_seen=None):
        """Run graph-structure inference and the encoder/decoder forecast.

        :param batch: mapping-like object with 'X' and 'y' entries ('y' may be None)
        :param batches_seen: global step forwarded to the decoder
        :return: tuple (predictions of shape (batch_size, horizon, num_nodes, output_dim),
                 edge logits for class 0 reshaped to (num_nodes, -1))
        """
        batch_size = batch['X'].size(0)
        if batch['y'] is not None:
            inputs, labels = self._prepare_data(batch['X'], batch['y'])
        else:
            inputs = self._prepare_data_x(batch['X'])
            labels = None

        # Graph-structure inference.
        x = self.node_feas.transpose(1, 0).view(self.num_nodes, 1, -1)  # (num_nodes, 1, num_samples)
        x = self.conv1(x)  # (num_nodes, 8, num_samples - kernal_size + 1)
        x = F.relu(x)
        x = self.bn1(x)
        x = self.conv2(x)  # (num_nodes, 16, num_samples - 2 * kernal_size + 2)
        x = F.relu(x)
        x = self.bn2(x)
        x = x.view(self.num_nodes, -1)  # (num_nodes, dim_fc)
        x = self.fc(x)
        x = F.relu(x)
        x = self.bn3(x)

        receivers = torch.matmul(self.rel_rec, x)
        senders = torch.matmul(self.rel_send, x)
        x = torch.cat([senders, receivers], dim=1)
        x = torch.relu(self.fc_out(x))
        x = self.fc_cat(x)  # two logits per ordered node pair

        # NOTE(review): `adj` is sampled and its diagonal cleared, but it is not passed
        # to the encoder/decoder below, whose cells were built from the fixed prior
        # adjacency. Confirm whether the sampled graph was meant to drive those cells.
        adj = gumbel_softmax(self.device, x, temperature=self.temperature, hard=True)
        adj = adj[:, 0].clone().reshape(self.num_nodes, -1)
        mask = torch.eye(self.num_nodes, self.num_nodes).bool().to(self.device)
        adj.masked_fill_(mask, 0)

        encoder_hidden_state = self.encoder(inputs)
        self._logger.debug("Encoder complete, starting decoder")
        outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
        self._logger.debug("Decoder complete")
        orig_out = outputs.view(self.horizon, batch_size, self.num_nodes, self.output_dim).permute(1, 0, 2, 3)
        return orig_out, x[:, 0].clone().reshape(self.num_nodes, -1)

    def calculate_loss(self, batch, batches_seen=None):
        """Masked MAE on the forecast, plus a BCE graph regulariser during early epochs.

        :param batch: mapping-like object with 'X' and 'y' entries
        :param batches_seen: global step; when None, step 0 is assumed for the epoch computation
        :return: scalar loss tensor
        """
        y_true = batch['y']
        # A missing step previously failed on `None // int`; treat it as step 0.
        step = 0 if batches_seen is None else batches_seen
        epoch = step // self.num_batches
        self._logger.debug(f"EPOCH = {epoch}, bep={batches_seen}")
        y_predicted, mid_output = self.forward(batch, batches_seen)
        y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
        y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
        loss_1 = loss.masked_mae_torch(y_predicted, y_true)

        # The regulariser is only added while the epoch index is below the threshold.
        if epoch < self.epoch_use_regularization:
            num_edges = mid_output.shape[0] * mid_output.shape[1]
            pred = torch.sigmoid(mid_output.view(num_edges))
            true_label = self.adj_mx.view(num_edges).to(self.device)
            compute_loss = torch.nn.BCELoss()
            loss_g = compute_loss(pred, true_label)
            self._logger.debug(f"loss_g = {loss_g}, loss_1 = {loss_1}")
            loss_t = loss_1 + loss_g
            return loss_t
        else:
            self._logger.debug(f"loss_1 = {loss_1}")
            return loss_1

    def predict(self, batch, batches_seen=None):
        """Return only the forecast tensor produced by ``forward``."""
        return self.forward(batch, batches_seen)[0]