# Source code for libcity.executor.geosan_executor

import torch
import torch.optim as optim
import numpy as np
import os
from tqdm import tqdm
import time as Time
from libcity.executor.abstract_executor import AbstractExecutor
import random
from libcity.utils import get_evaluator


class GeoSANExecutor(AbstractExecutor):
    """Executor for the GeoSAN next-location recommendation model.

    Drives training, evaluation, and model (de)serialization according to
    the experiment ``config``. Per-epoch weight snapshots are written to a
    temporary checkpoint directory and cleaned up when training finishes.
    """

    def __init__(self, config, model, data_feature):
        """Set up device placement, evaluator, and cache directories.

        Args:
            config(ConfigParser): experiment configuration; read for
                'device', 'exp_id', and the nested 'executor_config'
            model(AbstractModel): model to train/evaluate
            data_feature(dict): dataset statistics (unused here, kept for
                the common executor constructor signature)
        """
        self.config = config
        self.device = self.config.get('device', torch.device('cpu'))
        self.model = model.to(self.device)
        self.evaluator = get_evaluator(config)
        self.exp_id = self.config.get('exp_id', None)
        # Persistent caches are keyed by experiment id.
        self.cache_dir = './libcity/cache/{}/model_cache'.format(self.exp_id)
        self.evaluate_res_dir = './libcity/cache/{}/evaluate_cache'.format(self.exp_id)
        # Scratch directory for per-epoch snapshots; removed after train().
        self.tmp_path = './libcity/tmp/checkpoint/'

    def train(self, train_dataloader, eval_dataloader):
        """
        use data to train model with config

        Args:
            train_dataloader(torch.Dataloader): Dataloader
            eval_dataloader(torch.Dataloader): None
        """
        if not os.path.exists(self.tmp_path):
            os.makedirs(self.tmp_path)
        num_epochs = self.config['executor_config']['train']['num_epochs']
        # betas follow the GeoSAN reference implementation (0.9, 0.98).
        optimizer = optim.Adam(
            self.model.parameters(),
            lr=float(self.config['executor_config']['optimizer']['learning_rate']),
            betas=(0.9, 0.98))
        self.model.train()
        for epoch_idx in range(num_epochs):
            start_time = Time.time()
            running_loss = 0.
            processed_batch = 0
            batch_iterator = tqdm(enumerate(train_dataloader),
                                  total=len(train_dataloader), leave=True)
            for batch_idx, batch in batch_iterator:
                optimizer.zero_grad()
                loss = self.model.calculate_loss(batch)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                processed_batch += 1
                batch_iterator.set_postfix_str(f"loss={loss.item():.4f}")
            # Snapshot weights after each epoch into the temporary dir
            # (use os.path.join rather than string concatenation so the
            # path is correct even without a trailing separator).
            save_name_tmp = 'ep_' + str(epoch_idx) + '.m'
            torch.save(self.model.state_dict(),
                       os.path.join(self.tmp_path, save_name_tmp))
            epoch_time = Time.time() - start_time
            print("epoch {:>2d} completed.".format(epoch_idx + 1))
            print("time taken: {:.2f} sec".format(epoch_time))
            print("avg. loss: {:.4f}".format(running_loss / processed_batch))
            print("epoch={:d}, loss={:.4f}".format(
                epoch_idx + 1, running_loss / processed_batch))
        # The per-epoch snapshots are transient: delete every file under
        # tmp_path and then the directory itself.
        for rt, dirs, files in os.walk(self.tmp_path):
            for name in files:
                remove_path = os.path.join(rt, name)
                os.remove(remove_path)
        os.rmdir(self.tmp_path)
        print("training completed!")

    def evaluate(self, test_dataloader):
        """
        use model to test data

        Args:
            test_dataloader(torch.Dataloader): Dataloader
        """
        self.evaluator.clear()
        self.model.eval()
        # Fix the RNG state so negative sampling during prediction is
        # reproducible across evaluation runs.
        GeoSANExecutor.reset_random_seed(42)
        with torch.no_grad():
            for _, batch in enumerate(test_dataloader):
                output = self.model.predict(batch)  # shape: [(1+K)*L, N]
                self.evaluator.collect(output)
        self.evaluator.save_result(self.evaluate_res_dir)

    def load_model(self, cache_name):
        """Load a cached model checkpoint.

        Args:
            cache_name(str): name of the saved file
        """
        self.model.load(cache_name)

    def save_model(self, cache_name):
        """Save the current model to a file.

        Args:
            cache_name(str): name of the file to save to
        """
        if not os.path.exists(self.cache_dir):
            os.makedirs(self.cache_dir)
        self.model.save(cache_name)

    @staticmethod
    def reset_random_seed(seed):
        """Reset the random seeds of random, numpy, and torch (CPU).

        Args:
            seed(int): seed value
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)