Source code for libcity.model.trajectory_loc_prediction.GeoSAN

# coding: utf-8
from __future__ import print_function
from __future__ import division

from torch.nn.utils.rnn import pad_sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoderLayer, TransformerEncoder
from libcity.model.abstract_model import AbstractModel
import math


class GeoSAN(AbstractModel):
    def __init__(self, config, data_feature):
        super().__init__(config, data_feature)
        self.device = config['device']
        # depend on dataset
        self.num_neg = config['executor_config']['train']['num_negative_samples']
        self.temperature = config['executor_config']['train']['temperature']
        # from dataset
        # from train_dataset!!
        nuser = data_feature['nuser']
        nloc = data_feature['nloc']
        ntime = data_feature['ntime']
        nquadkey = data_feature['nquadkey']
        # from config
        user_dim = int(config['model_config']['user_embedding_dim'])
        loc_dim = int(config['model_config']['location_embedding_dim'])
        time_dim = int(config['model_config']['time_embedding_dim'])
        reg_dim = int(config['model_config']['region_embedding_dim'])
        # nhid = int(config['model_config']['hidden_dim_encoder'])
        nhead_enc = int(config['model_config']['num_heads_encoder'])
        # nhead_dec = int(config['model_config']['num_heads_decoder'])
        nlayers = int(config['model_config']['num_layers_encoder'])
        dropout = float(config['model_config']['dropout'])
        extra_config = config['model_config']['extra_config']
        # print(f"nloc: {nloc} \t loc_dim: {loc_dim}")
        # essential
        self.emb_loc = Embedding(nloc, loc_dim, zeros_pad=True, scale=True)
        self.emb_reg = Embedding(nquadkey, reg_dim, zeros_pad=True, scale=True)
        # optional
        self.emb_user = Embedding(nuser, user_dim, zeros_pad=True, scale=True)
        self.emb_time = Embedding(ntime, time_dim, zeros_pad=True, scale=True)
        ninp = user_dim
        pos_encoding = extra_config.get("position_encoding", "transformer")
        if pos_encoding == "embedding":
            self.pos_encoder = PositionalEmbedding(loc_dim + reg_dim, dropout)
        elif pos_encoding == "transformer":
            self.pos_encoder = PositionalEncoding(loc_dim + reg_dim, dropout)
        self.enc_layer = TransformerEncoderLayer(loc_dim + reg_dim, nhead_enc,
                                                 loc_dim + reg_dim, dropout)
        self.encoder = TransformerEncoder(self.enc_layer, nlayers)
        self.region_pos_encoder = PositionalEmbedding(reg_dim, dropout, max_len=20)
        self.region_enc_layer = TransformerEncoderLayer(reg_dim, 1, reg_dim, dropout=dropout)
        self.region_encoder = TransformerEncoder(self.region_enc_layer, 2)
        if not extra_config.get("use_location_only", False):
            if extra_config.get("embedding_fusion", "multiply") == "concat":
                if extra_config.get("user_embedding", False):
                    self.lin = nn.Linear(user_dim + loc_dim + reg_dim + time_dim, ninp)
                else:
                    self.lin = nn.Linear(loc_dim + reg_dim, ninp)
        ident_mat = torch.eye(ninp)
        self.register_buffer('ident_mat', ident_mat)
        self.layer_norm = nn.LayerNorm(ninp)
        self.extra_config = extra_config
        self.dropout = dropout

    def predict(self, batch):
        """
        Args:
            batch (Batch): a batch of input

        Returns:
            torch.tensor: predict result of this batch
        """
        user, loc, time, region, trg, trg_reg, trg_nov, sample_probs, ds = batch
        user = user.to(self.device)
        loc = loc.to(self.device)
        time = time.to(self.device)
        region = region.to(self.device)
        trg = trg.to(self.device)
        trg_reg = trg_reg.to(self.device)
        sample_probs = sample_probs.to(self.device)
        src_mask = pad_sequence([torch.zeros(e, dtype=torch.bool).to(self.device) for e in ds],
                                batch_first=True, padding_value=True)
        att_mask = GeoSAN._generate_square_mask_(max(ds), self.device)
        if self.training:
            output = self.forward(user, loc, region, time, att_mask, src_mask,
                                  trg, trg_reg, att_mask.repeat(self.num_neg + 1, 1))
        else:
            output = self.forward(user, loc, region, time, att_mask, src_mask,
                                  trg, trg_reg, None, ds=ds)
        return output

    @staticmethod
    def _generate_square_mask_(sz, device):
        mask = (torch.triu(torch.ones(sz, sz).to(device)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
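
    # Worked example (illustrative, not in the original source): for sz = 3 the additive
    # attention mask returned above is
    #     [[0., -inf, -inf],
    #      [0.,   0., -inf],
    #      [0.,   0.,   0.]]
    # so position i may only attend to positions j <= i (a causal mask).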

    def calculate_loss(self, batch):
        """
        Args:
            batch (Batch): a batch of input

        Returns:
            torch.tensor: training loss of this batch
        """
        # only support "WeightedProbBinaryCELoss"
        user, loc, time, region, trg, trg_reg, trg_nov, sample_probs, ds = batch
        user = user.to(self.device)
        loc = loc.to(self.device)
        time = time.to(self.device)
        region = region.to(self.device)
        trg = trg.to(self.device)
        trg_reg = trg_reg.to(self.device)
        sample_probs = sample_probs.to(self.device)
        src_mask = pad_sequence([torch.zeros(e, dtype=torch.bool).to(self.device) for e in ds],
                                batch_first=True, padding_value=True)
        att_mask = self._generate_square_mask_(max(ds), self.device)
        if self.training:
            output = self.forward(user, loc, region, time, att_mask, src_mask,
                                  trg, trg_reg, att_mask.repeat(self.num_neg + 1, 1))
        else:
            output = self.forward(user, loc, region, time, att_mask, src_mask,
                                  trg, trg_reg, None, ds=ds)
        # shape: [(1+K)*L, N]
        output = output.view(-1, loc.size(0), loc.size(1)).permute(2, 1, 0)
        # shape: [N, L, 1+K]
        pos_score, neg_score = output.split([1, self.num_neg], -1)
        weight = F.softmax(neg_score / self.temperature - torch.log(sample_probs), -1)
        loss = -F.logsigmoid(pos_score.squeeze()) + torch.sum(F.softplus(neg_score) * weight, dim=-1)
        keep = pad_sequence([torch.ones(e, dtype=torch.float32).to(self.device) for e in ds],
                            batch_first=True)
        loss = torch.sum(loss * keep) / torch.sum(torch.tensor(ds).to(self.device))
        return loss
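
    # Note on the loss above (descriptive comment, not in the original source): with a
    # positive score s+ and K negative scores s-_k drawn with proposal probabilities q_k,
    # calculate_loss computes the importance-sampling weighted binary cross-entropy
    #     loss = -log sigmoid(s+) + sum_k softmax_k(s-_k / T - log q_k) * softplus(s-_k),
    # summed over the non-padded positions (the `keep` mask) and normalized by the
    # total number of valid steps.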

    def forward(self, src_user, src_loc, src_reg, src_time, src_square_mask,
                src_binary_mask, trg_loc, trg_reg, mem_mask, ds=None):
        loc_emb_src = self.emb_loc(src_loc)
        if self.extra_config.get("user_location_only", False):
            src = loc_emb_src
        else:
            user_emb_src = self.emb_user(src_user)
            # (L, N, LEN_QUADKEY, REG_DIM)
            reg_emb = self.emb_reg(src_reg)
            reg_emb = reg_emb.view(reg_emb.size(0) * reg_emb.size(1),
                                   reg_emb.size(2), reg_emb.size(3)).permute(1, 0, 2)
            # (LEN_QUADKEY, L * N, REG_DIM)
            reg_emb = self.region_pos_encoder(reg_emb)
            reg_emb = self.region_encoder(reg_emb)
            # avg pooling
            reg_emb = torch.mean(reg_emb, dim=0)
            # reg_emb, _ = self.region_gru_encoder(reg_emb, self.h_0.expand(4, reg_emb.size(1), -1).contiguous())
            # reg_emb = reg_emb[-1, :, :]
            # (L, N, REG_DIM)
            reg_emb = reg_emb.view(loc_emb_src.size(0), loc_emb_src.size(1), reg_emb.size(1))
            time_emb = self.emb_time(src_time)
            if self.extra_config.get("embedding_fusion", "multiply") == "multiply":
                if self.extra_config.get("user_embedding", False):
                    src = loc_emb_src * reg_emb * time_emb * user_emb_src
                else:
                    src = loc_emb_src * reg_emb * time_emb
            else:
                if self.extra_config.get("user_embedding", False):
                    src = torch.cat([user_emb_src, loc_emb_src, reg_emb, time_emb], dim=-1)
                else:
                    src = torch.cat([loc_emb_src, reg_emb], dim=-1)
        if self.extra_config.get("size_sqrt_regularize", True):
            src = src * math.sqrt(src.size(-1))
        src = self.pos_encoder(src)
        # shape: [L, N, ninp]
        src = self.encoder(src, mask=src_square_mask)
        # shape: [(1+K)*L, N, loc_dim]
        loc_emb_trg = self.emb_loc(trg_loc)
        # [(1+K)*L, N, LEN_QUADKEY, REG_DIM]
        reg_emb_trg = self.emb_reg(trg_reg)
        # (LEN_QUADKEY, (1+K)*L * N, REG_DIM)
        reg_emb_trg = reg_emb_trg.view(reg_emb_trg.size(0) * reg_emb_trg.size(1),
                                       reg_emb_trg.size(2), reg_emb_trg.size(3)).permute(1, 0, 2)
        reg_emb_trg = self.region_pos_encoder(reg_emb_trg)
        reg_emb_trg = self.region_encoder(reg_emb_trg)
        reg_emb_trg = torch.mean(reg_emb_trg, dim=0)
        # [(1+K)*L, N, REG_DIM]
        reg_emb_trg = reg_emb_trg.view(loc_emb_trg.size(0), loc_emb_trg.size(1), reg_emb_trg.size(1))
        loc_emb_trg = torch.cat([loc_emb_trg, reg_emb_trg], dim=-1)
        if self.extra_config.get("use_attention_as_decoder", False):
            # multi-head attention
            output, _ = F.multi_head_attention_forward(
                query=loc_emb_trg, key=src, value=src,
                embed_dim_to_check=src.size(2), num_heads=1,
                in_proj_weight=None, in_proj_bias=None,
                bias_k=None, bias_v=None, add_zero_attn=None,
                dropout_p=0.0, out_proj_weight=self.ident_mat, out_proj_bias=None,
                training=self.training, key_padding_mask=src_binary_mask,
                need_weights=False, attn_mask=mem_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.ident_mat, k_proj_weight=self.ident_mat,
                v_proj_weight=self.ident_mat
            )
            if self.training:
                src = src.repeat(loc_emb_trg.size(0) // src.size(0), 1, 1)
            else:
                src = src[torch.tensor(ds) - 1, torch.arange(len(ds)), :]
                src = src.unsqueeze(0).repeat(loc_emb_trg.size(0), 1, 1)
            output += src
            output = self.layer_norm(output)
        else:
            # No attention
            if self.training:
                output = src.repeat(loc_emb_trg.size(0) // src.size(0), 1, 1)
            else:
                output = src[torch.tensor(ds) - 1, torch.arange(len(ds)), :]
                output = output.unsqueeze(0).repeat(loc_emb_trg.size(0), 1, 1)
        # shape: [(1+K)*L, N]
        output = torch.sum(output * loc_emb_trg, dim=-1)
        return output

    def save(self, path):
        torch.save(self.state_dict(), path)

    def load(self, path):
        self.load_state_dict(torch.load(path))
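

# A minimal construction sketch (illustrative, not part of the original module), assuming
# AbstractModel.__init__ only initializes nn.Module. The config keys mirror the ones read
# in GeoSAN.__init__; the concrete sizes and counts below are arbitrary assumptions chosen
# so that the encoder dimension (loc_dim + reg_dim) matches ninp (= user_dim).
def _example_build_geosan():
    config = {
        'device': 'cpu',
        'executor_config': {'train': {'num_negative_samples': 5, 'temperature': 1.0}},
        'model_config': {
            'user_embedding_dim': 100,
            'location_embedding_dim': 50,
            'time_embedding_dim': 100,
            'region_embedding_dim': 50,
            'num_heads_encoder': 2,
            'num_layers_encoder': 2,
            'dropout': 0.1,
            'extra_config': {},
        },
    }
    data_feature = {'nuser': 10, 'nloc': 1000, 'ntime': 169, 'nquadkey': 50}
    return GeoSAN(config, data_feature)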


class Embedding(nn.Module):
    def __init__(self, vocab_size, num_units, zeros_pad=True, scale=True):
        '''Embeds a given tensor of ids.

        Args:
            vocab_size: An int. Vocabulary size.
            num_units: An int. Number of embedding hidden units.
            zeros_pad: A boolean. If True, all the values of the first row (id 0) are constant zeros.
            scale: A boolean. If True, the outputs are multiplied by sqrt(num_units).
        '''
        super(Embedding, self).__init__()
        self.vocab_size = vocab_size
        self.num_units = num_units
        self.zeros_pad = zeros_pad
        self.scale = scale
        self.lookup_table = nn.Parameter(torch.Tensor(vocab_size, num_units))
        nn.init.xavier_normal_(self.lookup_table.data)
        if self.zeros_pad:
            self.lookup_table.data[0, :].fill_(0)

    def forward(self, inputs):
        if self.zeros_pad:
            self.padding_idx = 0
        else:
            self.padding_idx = -1
        # copied from torch.nn.modules.sparse.py
        outputs = F.embedding(
            inputs, self.lookup_table, self.padding_idx, None, 2, False, False)
        if self.scale:
            outputs = outputs * (self.num_units ** 0.5)
        return outputs
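

# Usage sketch for Embedding (illustrative, not part of the original module): ten ids embedded
# into four dimensions, with id 0 reserved as zero padding and outputs scaled by sqrt(4).
def _example_embedding():
    emb = Embedding(vocab_size=10, num_units=4, zeros_pad=True, scale=True)
    ids = torch.tensor([[0, 3, 7], [2, 0, 9]])  # shape: [2, 3]
    out = emb(ids)  # shape: [2, 3, 4]; rows for id 0 are all zeros
    return out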


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
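

# Usage sketch for PositionalEncoding (illustrative, not part of the original module): the fixed
# sinusoidal table is added to a [seq_len, batch, d_model] input, so otherwise identical vectors
# become distinguishable by their position in the sequence.
def _example_positional_encoding():
    pos_enc = PositionalEncoding(d_model=8, dropout=0.0, max_len=50)
    x = torch.zeros(5, 2, 8)  # 5 time steps, batch of 2, model dimension 8
    return pos_enc(x)  # output[t, b, :] is the sinusoidal encoding of position t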


class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=120):
        super(PositionalEmbedding, self).__init__()
        self.pos_emb_table = Embedding(max_len, d_model, zeros_pad=False, scale=False)
        pos_vector = torch.arange(max_len)
        self.dropout = nn.Dropout(p=dropout)
        self.register_buffer('pos_vector', pos_vector)

    def forward(self, x):
        pos_emb = self.pos_emb_table(self.pos_vector[:x.size(0)].unsqueeze(1).repeat(1, x.size(1)))
        x += pos_emb
        return self.dropout(x)
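

# Usage sketch for PositionalEmbedding (illustrative, not part of the original module): unlike
# the fixed sinusoidal table above, each position gets a learned vector, and max_len bounds the
# longest supported sequence length.
def _example_positional_embedding():
    pos_emb = PositionalEmbedding(d_model=8, dropout=0.0, max_len=20)
    x = torch.zeros(5, 2, 8)  # sequence length 5 must not exceed max_len
    return pos_emb(x)  # adds the learned vector for position t to every batch element at step t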