Source code for libcity.data.batch

import torch
import numpy as np


class Batch(object):

    def __init__(self, feature_name):
        """A container that groups the features of a batch by name.

        Args:
            feature_name (dict): key is the corresponding feature's name, and
                the value is the feature's data type
        """
        self.data = {}
        self.feature_name = feature_name
        for key in feature_name:
            self.data[key] = []

    def __getitem__(self, key):
        if key in self.data:
            return self.data[key]
        else:
            raise KeyError('{} is not in the batch'.format(key))

    def __setitem__(self, key, value):
        if key in self.data:
            self.data[key] = value
        else:
            raise KeyError('{} is not in the batch'.format(key))

    def append(self, item):
        """Append a new item to the batch.

        Args:
            item (list): one group of input values, in the same order as
                feature_name; feature_name gives the name of each value
        """
        if len(item) != len(self.feature_name):
            raise KeyError('when appending an item, it must have the same length as feature_name')
        for i, key in enumerate(self.feature_name):
            self.data[key].append(item[i])

    def to_tensor(self, device):
        """Convert self.data to tensors and move them to device.

        Args:
            device (torch.device): the GPU/CPU device
        """
        for key in self.data:
            if self.feature_name[key] == 'int':
                self.data[key] = torch.LongTensor(np.array(self.data[key])).to(device)
            elif self.feature_name[key] == 'float':
                self.data[key] = torch.FloatTensor(np.array(self.data[key])).to(device)
            else:
                raise TypeError(
                    'Batch to_tensor only supports int and float, but you gave {}'.format(self.feature_name[key]))

    def to_ndarray(self):
        for key in self.data:
            if self.feature_name[key] in ('int', 'float'):
                self.data[key] = np.array(self.data[key])
            else:
                raise TypeError(
                    'Batch to_ndarray only supports int and float, but you gave {}'.format(self.feature_name[key]))
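
A minimal usage sketch (not part of the module): the feature names 'X' and 'y' and the sample values below are illustrative assumptions. Each appended item lists one value per feature, in the order of feature_name.

# Hypothetical example: two samples, each a float feature vector plus an int label.
feature_name = {'X': 'float', 'y': 'int'}
batch = Batch(feature_name)
batch.append([[0.1, 0.2], 3])
batch.append([[0.3, 0.4], 7])
batch.to_tensor(torch.device('cpu'))
print(batch['X'].shape)  # torch.Size([2, 2]), a FloatTensor
print(batch['y'])        # tensor([3, 7]), a LongTensor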

class BatchPAD(Batch):

    def __init__(self, feature_name, pad_item=None, pad_max_len=None):
        """A Batch that pads variable-length features to a common length.

        Args:
            feature_name (dict): key is the corresponding feature's name, and
                the value is the feature's data type
            pad_item (dict): key is the feature name, and value is the padding
                value. Only the features listed in pad_item are padded.
            pad_max_len (dict): key is the feature name, and value is the max
                length of the padded feature. Use this parameter to truncate
                the feature.
        """
        super().__init__(feature_name=feature_name)
        # By default each feature is padded to the longest length in the
        # batch; if a feature is longer than pad_max_len, it is truncated.
        self.pad_len = {}
        self.origin_len = {}  # original (pre-padding) length of each trajectory
        self.pad_max_len = pad_max_len if pad_max_len is not None else {}
        self.pad_item = pad_item if pad_item is not None else {}
        for key in feature_name:
            self.data[key] = []
            if key in self.pad_item:
                self.pad_len[key] = 0
                self.origin_len[key] = []

    def append(self, item):
        """Append a new item to the batch.

        Args:
            item (list): one group of input values, in the same order as
                feature_name; feature_name gives the name of each value
        """
        if len(item) != len(self.feature_name):
            raise KeyError('when appending an item, it must have the same length as feature_name')
        for i, key in enumerate(self.feature_name):
            # the features in item must follow the same order as the
            # feature_name dict passed at initialization
            self.data[key].append(item[i])
            if key in self.pad_item:
                self.origin_len[key].append(len(item[i]))
                if self.pad_len[key] < len(item[i]):
                    # keep pad_len equal to the maximum length seen so far
                    self.pad_len[key] = len(item[i])

    def padding(self):
        """Pad the batch; only one-dimensional (list) features are supported."""
        for key in self.pad_item:
            # only the features listed in pad_item are padded
            if key not in self.data:
                raise KeyError('when padding a batch, feature {} is not in the batch'.format(key))
            max_len = self.pad_len[key]
            if key in self.pad_max_len:
                max_len = min(self.pad_max_len[key], max_len)
            for i in range(len(self.data[key])):
                if len(self.data[key][i]) < max_len:
                    self.data[key][i] += [self.pad_item[key]] * \
                        (max_len - len(self.data[key][i]))
                else:
                    # truncate by dropping the earliest points,
                    # since these features are time series
                    self.data[key][i] = self.data[key][i][-max_len:]
                    # a truncated sequence cannot be restored, but keep
                    # origin_len consistent so downstream code does not break
                    self.origin_len[key][i] = max_len

    def get_origin_len(self, key):
        return self.origin_len[key]

    def to_tensor(self, device):
        """Convert self.data to tensors and move them to device.

        Args:
            device (torch.device): the GPU/CPU device
        """
        for key in self.data:
            if self.feature_name[key] == 'int':
                self.data[key] = torch.LongTensor(np.array(self.data[key])).to(device)
            elif self.feature_name[key] == 'float':
                self.data[key] = torch.FloatTensor(np.array(self.data[key])).to(device)
            elif self.feature_name[key] == 'array of int':
                for i in range(len(self.data[key])):
                    for j in range(len(self.data[key][i])):
                        self.data[key][i][j] = torch.LongTensor(np.array(self.data[key][i][j])).to(device)
            elif self.feature_name[key] == 'no_pad_int':
                for i in range(len(self.data[key])):
                    self.data[key][i] = torch.LongTensor(np.array(self.data[key][i])).to(device)
            elif self.feature_name[key] == 'no_pad_float':
                for i in range(len(self.data[key])):
                    self.data[key][i] = torch.FloatTensor(np.array(self.data[key][i])).to(device)
            elif self.feature_name[key] == 'no_tensor':
                pass
            else:
                raise TypeError(
                    'Batch to_tensor only supports int, float, array of int, no_pad_int, '
                    'no_pad_float and no_tensor, but you gave {}'.format(self.feature_name[key]))
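
A usage sketch for BatchPAD (illustrative feature names and values, not libCity conventions): variable-length trajectories are padded with the pad value up to the batch maximum, capped at pad_max_len; longer trajectories keep only their most recent points.

# Hypothetical example: 'loc' sequences padded/truncated to length 4.
feature_name = {'loc': 'int', 'target': 'int'}
batch = BatchPAD(feature_name, pad_item={'loc': 0}, pad_max_len={'loc': 4})
batch.append([[5, 9, 2], 2])        # length 3, will be padded
batch.append([[7, 1, 4, 6, 8], 8])  # length 5, truncated to the last 4 points
batch.padding()
print(batch['loc'])                   # [[5, 9, 2, 0], [1, 4, 6, 8]]
print(batch.get_origin_len('loc'))    # [3, 4]; truncated lengths are clipped to max_len
batch.to_tensor(torch.device('cpu'))  # 'loc' becomes a LongTensor of shape [2, 4]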