Regression with LSTM
Image source: Deformation pattern caused by a seismic event over the L'Aquila area in central Italy. This interferogram was generated by Italy’s Istituto per il Rilevamento Elettromagnetico dell’ Ambiente (IREA-CNR) in Naples, Italy. Registered using an Envisat Advanced Synthetic Aperture Radar (ASAR) on 12 April 2009.
Problem definition: This use case focuses on predicting the impact that extreme events may have on the population. Specifically, we focus on the normalized total damages caused by earthquakes worldwide. To do so, we will use climate data for each region.
This notebook gives users practical experience with the toolbox. It also serves as a guide to tackling regression problems (referred to as Impact Assessment tasks) using 1D datasets.
To gain a comprehensive understanding of the toolbox's structure and detailed usage instructions, we highly recommend referring to the user guide available in the "Read the Docs" reference. Familiarising yourself with the user guide will ensure you have a solid foundation to make the most out of the tutorial and leverage the toolbox effectively.
1. PREREQUISITES
2. DATA
3. MODEL
4. TRAINING
5. EVALUATION
6. EXPLAINABLE AI
1. PREREQUISITES OF THE TOOLBOX
Please start by running the following cells to check out the contents of the toolbox, install its dependencies and import the required libraries.
%load_ext autoreload
%autoreload 2
Check out the contents of the AIDE toolbox
a = %pwd
if a.split("/")[-1] != "AIDE":
    %cd ../
%ls -h
/home/maria/Documents/AIDE_private/AIDE backbones/ databases/ experiments/ tutorials/ utils/ configs/ evaluators/ main.py* user_defined/
Folders with tunable functionalities:
- configs is a directory including config files with the tuneable parameters regarding the data, architecture, training and evaluation;
- databases stores available example datasets, as well as the datasets for which you would like to use the AIDE toolbox;
- tutorials contains several notebooks to serve as example of usage of the AIDE toolbox;
- user_defined includes the user defined files and functions;
Internal code:
- backbones defines the core classes of the toolbox;
- evaluators creates the necessary modules for a better assessment of results;
- utils contains generic functions used by the toolbox;
- main.py is the main script to execute the pipeline of the toolbox using the command window.
Import required libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import itertools
from tqdm.auto import tqdm
from datetime import datetime
from scipy.interpolate import interp1d
from sklearn.preprocessing import RobustScaler
import time
import seaborn as sns
# PYTORCH
from torch.utils.data import DataLoader
# PYTORCH LIGHTNING
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelSummary, ModelCheckpoint, EarlyStopping
# TORCHMETRICS
import torchmetrics
# DATASET CLASSES
from databases import *
# MODEL TEMPLATE CLASS
from backbones import *
# EVALUATION
from evaluators import *
# METRICS
from utils.misc import *
from utils.setup_config import setup
import yaml
from os.path import dirname, abspath
from pathlib import Path
import sys
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 300
pd.set_option('display.max_rows', 50)
sns.set_style("darkgrid")
from PIL import Image
/home/maria/Documents/AIDE_private/aide_env/lib/python3.8/site-packages/torch/cuda/__init__.py:107: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:109.) return torch._C._cuda_getDeviceCount() > 0
2. DATA
2.1 The database
The dataset that we will use is the DATABASE OF EXTREME CLIMATE EVENTS AND ASSOCIATED IMPACTS ON ECOSYSTEMS AND SOCIETY (named XAIDA database in our Git repository) prepared by the members of Universität Leipzig inside the XAIDA project. The database summarizes the reported impacts of extreme climate and weather events listed by EM-DAT and it couples them with two-year time series of climate variables aggregated over the event locations.
Characteristics of the climate data:
- Temporal coverage: 1989 to 2021, with each event's time series spanning one year before and one year after the event.
- Spatial resolution: 0.25°
- Variables:
  - ERA5 reanalysis data: t2m, t2mmax, t2mmin, tp, pev.
  - Calculated from ERA5: pet, SPEI30, SPEI90, SPEI180
  - GLEAM: smsurf
In more detail: multidimensional array of aggregated climate variables for the disaster events
| # | Coordinates/ dimensions | Description |
|---|---|---|
| 1 | Event | Event identifier corresponding to disaster_number_country in the event data frame |
| 2 | Landcover | The spatial aggregation was performed on the geometry where the disaster was reported, and also on the Urban/ agricultural and natural ecosystem land covers within that geometry |
| 3 | Time | Time axis of each climate variables for each event starting at the day -365 before event onset (corresponding to the start aggregation date in the event data frame), and ending 365 days after event onset (corresponding to the end aggregation date in the event dataframe) |
| 4 | Units | Climatological variables were aggregated per event and land cover type both in original units and normalized anomalies |
| # | Climatological variables | Description |
|---|---|---|
| 1 | pet | Potential evapotranspiration |
| 2 | pev | Potential evaporation (ERA5) |
| 3 | spei_30, spei_90, spei_180 | Standardized Precipitation - Evapotranspiration Index for 30, 90, and 180 days |
| 6 | t2m, t2mmax, t2mmin | 2m Air temperature: daily mean, maximum and minimum (ERA5) |
| 9 | tp | Total daily precipitation (ERA5) |
| 10 | smsurf | Surface moisture (GLEAM) |
Societal impact data:
- Human losses: extracted from EM-DAT
- Financial damages: extracted from EM-DAT
- Population data: the gridded population data GPW3 and GPW4
- GDP data: the product used is the gridded global dataset for Gross Domestic Product (Kummu et al. 2018)
The database has data available from 9 disaster types, further defined by their respective subtypes. The number of events per disaster type considered in the societal dataset and their yearly breakdown are shown in Figure 1 and Figure 2 respectively.
2.2 Configuration files
Configuration files are essential in deep learning for their ability to separate settings and parameters from the code, promoting modularity and reproducibility. They provide flexibility and customisation options, allowing easy adjustments to hyperparameters, model architectures, and other aspects.
YOUR_CONFIG_FILE_NAME = "config_XAIDA_Earthquakes_IA"
EXPERIMENT_ID = "jupyter_" + YOUR_CONFIG_FILE_NAME.split("_")[1] + "_"+ YOUR_CONFIG_FILE_NAME.split("_")[2] + "_" + str(time.time())
current_d = dirname(abspath("__file__"))
config_path = current_d + "/configs/" + YOUR_CONFIG_FILE_NAME + ".yaml"
config = setup(config_path)
config['experiment_id'] = EXPERIMENT_ID
We first load the default configuration and display it:
config
{'name': 'AIDE',
'task': 'ImpactAssessment',
'from_scratch': True,
'best_run_path': '',
'save_path': 'experiments/',
'debug': False,
'data': {'name': 'XAIDA_IA',
'seed': 42,
'data_dim': 1,
'root': './databases/XAIDA/',
'index_data': 'index_data/xaida_impacts_disaster_dataframe.csv',
'clim_data': 'clim_data/DB_extreme_clim_events_impacts.zarr/',
'num_targets': 1,
'whole_ts': False,
'random_lead_time': 0,
'before_event': -5,
'after_event': 5,
'event_duration': 1,
'steps_per_batch': 'all',
'landcover': ['urban', 'natural'],
'input_size': 10,
'train_slice': {'start': 1989, 'end': 2010},
'val_slice': {'start': 2011, 'end': 2014},
'test_slice': {'start': 2015, 'end': 2021},
'units': 'original',
'impact_var': 'normalized_total_damages',
'extremes': ['Earthquake'],
'continents': ['Asia', 'Africa', 'Americas', 'Europe', 'Oceania'],
'features': ['pet',
'pev',
'spei_30',
'spei_90',
'spei_180',
't2m',
't2mmax',
't2mmin',
'tp',
'smsurf',
'label'],
'features_selected': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
'normalize': False,
'standardize_target': True},
'arch': {'user_defined': True,
'type': 'UD_LSTM_IA',
'params': {'input_size': 20,
'hidden_dim': 32,
'hidden_layers': 1,
'fc_hidden_dim': 16,
'dropout_p': 0.1},
'input_model_dim': 2,
'output_model_dim': 1},
'implementation': {'loss': {'user_defined': False,
'type': 'HuberLoss',
'package': 'torch.nn',
'activation': {'type': 'linear'},
'masked': False,
'params': {'reduction': 'none', 'delta': 0.5}},
'optimizer': {'type': 'Adam',
'lr': 0.0001,
'weight_decay': 0,
'gclip_value': -1},
'trainer': {'accelerator': 'cpu',
'devices': 1,
'epochs': 1000,
'batch_size': 16,
'monitor': {'split': 'val', 'metric': 'loss'},
'monitor_mode': 'min',
'early_stop': 50,
'save_dir': 'experiments/'},
'data_loader': {'num_workers': 16, 'sampler': 'None'}},
'evaluation': {'metrics': {'MeanSquaredError': {},
'PearsonCorrCoef': {},
'SpearmanCorrCoef': {}},
'xai': {'activate': True, 'params': {'type': 'Saliency', 'params': None}}},
'experiment_id': 'jupyter_XAIDA_Earthquakes_1694256291.7084587'}
We can now modify important settings to fit our needs. We will focus on selecting worldwide Earthquake events from 1989 until 2021:
#Modify config parameters: Data preprocessor
train_start, train_end= 1989, 2010
val_start, val_end= 2011, 2014
test_start, test_end= 2015, 2021
#Extremes: ['Flood' 'Storm' 'Landslide' 'Earthquake' 'Volcanic activity'
# 'Mass movement (dry)' 'Extreme temperature' 'Drought' 'Wildfire']
#Continents: ['Asia' 'Africa' 'Oceania' 'Europe' 'Americas']
extremes= ['Earthquake']
continents= ['Asia', 'Africa', 'Oceania', 'Europe', 'Americas']
# Configuration of the length of our time series
config['data']['random_lead_time']= 0
config['data']['whole_ts']= False
config['data']['before_event']= -5 #-20
config['data']['after_event']= 5
config['data']['event_duration']= 1
config['data']['landcover']= ['urban', 'natural'] #'all'
# Configuration of the start-end years for our three splits (train, validation and test)
config['data']['train_slice']['start']= train_start
config['data']['train_slice']['end']= train_end
config['data']['test_slice']['start']= test_start
config['data']['test_slice']['end']= test_end
config['data']['val_slice']['start']= val_start
config['data']['val_slice']['end']= val_end
# Choice of type of extremes and target continents
config['data']['extremes']= extremes
config['data']['continents']= continents
# Variable that we are interested in predicting
config['data']['units'] = 'original'
config['data']['impact_var'] = 'normalized_total_damages' # 'normalized_total_damages_adjusted'
# Climate data used as input to the model
config['data']['features']= ['pet', 'pev', 'spei_30', 'spei_90', 'spei_180','t2m','t2mmax','t2mmin','tp', 'smsurf', 'label']
config['data']['features_selected']= [0,1,2,3,4,5,6,7,8,9]
config['data']['input_size']= 10
# Data pre-processing
config['data']['normalize']= True
config['data']['standardize_target']= True
# Architecture parameters
config['arch']['params']['input_size'] = 20
config['arch']['params']['hidden_dim'] = 32
config['arch']['params']['hidden_layers'] = 1
config['arch']['params']['fc_hidden_dim'] = 16
config['arch']['params']['dropout_p'] = 0.1
# Loss parameters
config['implementation']['loss']['params']['delta'] = 0.5 # 1, 0.1, 0.5
# Optimizer parameters
config['implementation']['optimizer']['gclip_value'] = -1
config['implementation']['optimizer']['lr'] = 0.0001 #0.001
# Training specifications: time and resources
config['implementation']['trainer']['accelerator']= 'cpu'
config['implementation']['trainer']['devices']= 1
config['implementation']['trainer']['epochs']= 1000
config['implementation']['trainer']['early_stop']= 50
config['implementation']['trainer']['batch_size']= 16
# Create the experiment folder structure if it does not already exist
if not Path(config['save_path']).name == config['experiment_id'].replace('/', ''):
    save_path= Path(config['save_path']) / Path(config['experiment_id'])
    config['save_path']= str(save_path.resolve())
    save_path.mkdir(exist_ok=True, parents=True)
else:
    save_path= Path(config['save_path'])
print(f'Logging experiment data at: {save_path}')
Logging experiment data at: experiments/jupyter_XAIDA_Earthquakes_1694256291.7084587
This modification can be done directly on the corresponding configuration file in the configs folder. The previous selection of parameters has been explicitly extracted for demonstration purposes.
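The override pattern used in the cell above can be sketched with a small helper. This is only an illustration under stated assumptions: `override` is a hypothetical function that is not part of the AIDE toolbox, and `defaults` is a hand-copied excerpt of the configuration shown earlier.

```python
from copy import deepcopy


def override(config, dotted_key, value):
    """Return a copy of `config` with the nested key `dotted_key` set to `value`.

    Hypothetical helper for illustration; it is not part of the AIDE toolbox.
    """
    updated = deepcopy(config)
    node = updated
    *parents, leaf = dotted_key.split(".")
    for key in parents:
        node = node[key]  # raises KeyError if a section name is misspelled
    if leaf not in node:
        raise KeyError(f"unknown config key: {dotted_key}")
    node[leaf] = value
    return updated


# A small excerpt of the configuration shown above.
defaults = {
    "data": {"normalize": False, "extremes": ["Earthquake"]},
    "implementation": {"optimizer": {"lr": 0.0001}},
}

custom = override(defaults, "data.normalize", True)
custom = override(custom, "implementation.optimizer.lr", 0.001)

print(custom["data"]["normalize"])    # True
print(defaults["data"]["normalize"])  # False: the defaults are left untouched
```

Rejecting unknown keys is a deliberate choice in this sketch: a misspelled key such as `data.normalise` fails loudly instead of silently adding a setting that nothing reads.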
2.3 Data loading and visualization
We will use the PytorchBackbone to hold the main functions of the pipeline:
The toolbox can be run from end-to-end through the terminal by executing the file main.py. For more details, please read the README provided in the AIDE GitHub repository.
regressor = PytorchBackbone(config)
The first step of the toolbox is the loading of the data. This process will retrieve the information as specified in the custom dataset class, provided in the databases folder by the user. All the variables specified in this section for the data exploration are custom-named.
# Load the data
regressor.load_data()
# Check loaded data: training event tags and count. This variable is specific to the custom dataset class.
regressor.train_loader.dataset.X.index.levels[0].unique()
Index(['1990-0009-PHL', '1990-0029-PER', '1990-0034-IRN', '1990-0040-PHL',
'1990-0099-CHN', '1990-0118-IRN', '1990-0170-CRI', '1990-0187-USA',
'1990-0189-PAK', '1990-0250-ITA',
...
'2010-0002-TJK', '2010-0017-HTI', '2010-0027-USA', '2010-0064-CHN',
'2010-0091-CHL', '2010-0158-MEX', '2010-0169-CHN', '2010-0463-NZL',
'2010-0574-SRB', '2010-0668-TWN'],
dtype='object', name='event', length=207)
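The training index above only contains events from 1990 to 2010, which is consistent with the train slice (1989 to 2010) set in the configuration. A minimal sketch of such a year-based split follows. It assumes that the year is the first field of the event identifier and that both slice bounds are inclusive; `SLICES` and `split_for` are hypothetical names, and the dataset class in the databases folder may implement the split differently.

```python
# Year ranges copied from the configuration above (assumed inclusive).
SLICES = {"train": (1989, 2010), "val": (2011, 2014), "test": (2015, 2021)}


def split_for(event_id, slices=SLICES):
    """Return the split whose year range contains the event, or None."""
    year = int(event_id.split("-")[0])  # '1990-0009-PHL' -> 1990
    for name, (start, end) in slices.items():
        if start <= year <= end:
            return name
    return None


for event_id in ["1990-0009-PHL", "2010-0668-TWN", "2015-0465-HRV"]:
    print(event_id, "->", split_for(event_id))
```

Splitting by year rather than at random keeps the test events strictly later than the training events, so the evaluation reflects how the model would be used on future disasters.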
# Define plotting function to visualise our time series
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib.colors as mcolors
def plot_event(data, event_name, predicted=None, figsize=(10,5), limits=None,
               color_list=list(mcolors.TABLEAU_COLORS)):
    start_from = 1
    # Target value and input time series of the chosen event
    y= data.y.loc[event_name]
    x= data.X.loc[event_name]
    x= x.droplevel(level=0) if isinstance(x.index, pd.MultiIndex) and len(x.index.levels) == 2 else x
    # Per-timestep event label
    event_tag = data.dataframe[np.array('label')][event_name]
    event_tag = event_tag.droplevel(level=0) if isinstance(event_tag.index, pd.MultiIndex) and len(event_tag.index.levels) == 2 else event_tag
    classes= data.classes
    event_class = data.type_extreme.loc[event_name]
    fig, axs = plt.subplots(1, 1, figsize=figsize)
    axs.plot(x, label=x.columns)
    if limits is not None:
        axs.set_ylim(limits[0], limits[1])
    ymin, ymax= axs.get_ylim()
    # Shade the timesteps labelled as part of the event
    y_onehot= np.eye(len(classes))[event_tag.values]
    axs.stackplot(x.index, y_onehot[:, start_from:].swapaxes(0,1)*ymax,
                  labels=['GT: ' + cl for cl in classes[start_from:]],
                  alpha=0.3, colors=[np.minimum(np.array(mcolors.to_rgb(c))* 1.5, 1.) for c in color_list])
    axs.legend(loc='upper left', bbox_to_anchor=(1.01, 1))
    axs.grid(True)
    axs.set_xlabel('Days relative to the event')
    axs.set_title(event_name + ' - Target variable: ' + str(round(y,6)))
# Choose an event
idx = 15
#print(len(regressor.train_loader.dataset.selected_events))
#for idx in np.arange(len(regressor.train_loader.dataset.selected_events)):
event_example = regressor.train_loader.dataset.selected_events.index[idx]
# Visualise chosen event
plot_event(regressor.train_loader.dataset, event_example)
We now visualise the distribution of the variable that we want to predict:
# Plot the distribution of the impact variable that we want to predict
fig, axs = plt.subplots(1, 1, figsize=(5,3))
axs.set_title('Impact variable histogram')
axs.set_xlabel('Variable ' + regressor.train_loader.dataset.impacts.name)
axs.set_ylabel('Number of samples')
_ = axs.hist(regressor.train_loader.dataset.y, bins=100)
3. MODEL
Our input data is the set of selected climatological variables, which we explored in the previous section. To exploit the temporal information, we design a simple recurrent neural network. The core element is a Long Short-Term Memory (LSTM) cell, capable of learning long-term dependencies in the data. After that, we append two fully connected layers to further process the information.
According to the parameters of choice (see Section 2.2), our model's summary is as follows:
regressor.implement_model()
regressor.model
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback. GPU available: True (cuda), used: False TPU available: False, using: 0 TPU cores IPU available: False, using: 0 IPUs HPU available: False, using: 0 HPUs `Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
PytorchModel(
(model): UD_LSTM_IA(
(lstm): LSTM(20, 32, batch_first=True)
(lstm_ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
(fc): Linear(in_features=32, out_features=16, bias=True)
(fc_ln): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
(fc_top): Linear(in_features=16, out_features=1, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(loss): HuberLoss()
)
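The summary lists HuberLoss as the training objective, configured in Section 2.2 with delta = 0.5. As a minimal sketch of what that loss computes for a single prediction (a plain-Python restatement of the standard Huber definition, not the torch.nn implementation itself):

```python
def huber(pred, target, delta=0.5):
    """Huber loss for one prediction/target pair.

    Quadratic for small errors, linear for errors larger than `delta`.
    """
    error = abs(pred - target)
    if error <= delta:
        return 0.5 * error ** 2
    return delta * (error - 0.5 * delta)


for err in (0.2, 0.5, 2.0):
    print(f"error={err}: loss={huber(err, 0.0):.4f}")
```

With delta = 0.5, an error of 2.0 costs 0.875 rather than the 2.0 that the quadratic branch alone would give, so a few events with very large damages pull on the gradient less than they would under a squared-error loss.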
4. TRAINING
We make use of TensorBoard to see how our model has learnt. The train_val_losses drop-down will show a graphic depicting the performance of our model during all training steps.
# Load the TensorBoard notebook extension
%load_ext tensorboard
tensorboard_path = save_path / 'lightning_logs'
%tensorboard --logdir="$tensorboard_path" --port=6012
The first run of the tensorboard cell will provide a prompt saying "No dashboards are active for the current data set". This is the correct behaviour of the cell. Please, use the refresh symbol in the prompt once the training of the model has started (cell below) to follow in real time the learning process of the model.
model= regressor.train()
Missing logger folder: /home/maria/Documents/AIDE_private/AIDE/experiments/jupyter_XAIDA_Earthquakes_1694256291.7084587/lightning_logs
  | Name          | Type       | Params
---------------------------------------------
0 | model         | UD_LSTM_IA | 7.6 K
1 | model.lstm    | LSTM       | 6.9 K
2 | model.lstm_ln | LayerNorm  | 64
3 | model.fc      | Linear     | 528
4 | model.fc_ln   | LayerNorm  | 32
5 | model.fc_top  | Linear     | 17
6 | model.dropout | Dropout    | 0
7 | loss          | HuberLoss  | 0
---------------------------------------------
7.6 K     Trainable params
0         Non-trainable params
7.6 K     Total params
0.030     Total estimated model params size (MB)
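The parameter counts in the summary above can be checked by hand from the layer sizes. The sketch below uses the standard parameter formulas for a single-layer unidirectional LSTM, a linear layer and a layer normalisation with affine weights. It assumes that the LSTM input size of 20 comes from the 10 selected features aggregated over the 2 land covers ('urban' and 'natural'), which matches the configuration but is not stated explicitly in the toolbox output.

```python
def lstm_params(input_size, hidden_dim):
    # 4 gates, each with input weights, recurrent weights and two bias vectors
    return 4 * (input_size * hidden_dim + hidden_dim * hidden_dim + 2 * hidden_dim)


def linear_params(in_features, out_features):
    # weight matrix plus bias vector
    return in_features * out_features + out_features


def layernorm_params(features):
    # one scale and one shift per feature (elementwise_affine=True)
    return 2 * features


n_features, n_landcovers = 10, 2        # assumption: features x land covers
input_size = n_features * n_landcovers  # 20
hidden_dim, fc_hidden_dim = 32, 16

counts = {
    "model.lstm": lstm_params(input_size, hidden_dim),       # 6912
    "model.lstm_ln": layernorm_params(hidden_dim),           # 64
    "model.fc": linear_params(hidden_dim, fc_hidden_dim),    # 528
    "model.fc_ln": layernorm_params(fc_hidden_dim),          # 32
    "model.fc_top": linear_params(fc_hidden_dim, 1),         # 17
}
total = sum(counts.values())
print(counts)
print("total:", total)  # 7553, reported as 7.6 K in the summary
```

Almost all of the capacity sits in the LSTM; the two fully connected layers add fewer than 600 parameters between them.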
5. EVALUATION
For evaluating the model, we employ some of the most common metrics for regression tasks, available in the TorchMetrics library [Detlefsen, N. S., 2022]:
- Mean Squared Error (MSE): Average squared difference between the estimated values and the actual values. Range: [0, +inf), values closer to zero indicate better performance.
- Pearson Correlation Coefficient: Statistical measure of the linear correlation between two variables. Range: [-1,1], higher values indicate higher concordance between variables.
- Spearman Correlation Coefficient: Statistical measure of the strength of a monotonic relationship between paired data. Range: [-1,1], higher values indicate higher concordance between variables.
This list is defined in the evaluation section of the config file.
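To make the three definitions above concrete, here is a minimal plain-Python sketch of each metric. It is meant for building intuition only: it is not the TorchMetrics implementation, and the helper names (`mse`, `pearson`, `ranks`, `spearman`) and the toy data are made up for this example.

```python
import math


def mse(pred, target):
    """Mean of the squared differences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def pearson(x, y):
    """Linear correlation: covariance divided by the product of the spreads."""
    mx, my = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        average_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            result[order[k]] = average_rank
        i = j + 1
    return result


def spearman(x, y):
    """Monotonic correlation: Pearson correlation of the ranks."""
    return pearson(ranks(x), ranks(y))


pred = [1.0, 2.0, 3.0, 4.0]
target = [1.0, 4.0, 9.0, 16.0]  # monotonic but non-linear in pred
print("MSE:", mse(pred, target))
print("Pearson:", pearson(pred, target))
print("Spearman:", spearman(pred, target))
```

On this toy data the Spearman coefficient is 1 because the ordering is perfectly preserved, while the Pearson coefficient stays below 1 because the relationship is not linear. Skewed impact variables can show the same gap between the two coefficients.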
regressor.test()
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Test metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test_MeanSquaredError 0.012496540322899818
test_PearsonCorrCoef 0.2792527973651886
test_SpearmanCorrCoef 0.4161713421344757
test_loss_epoch 0.005564543418586254
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Verification of results
For a better assessment of the model's performance, we plot the model's predictions against the ground-truth values of the events that make up the test set (worldwide earthquakes between 2015 and 2021).
regressor.config['evaluation']['xai']['activate'] = False
output_data= regressor.inference(subset='test')
scaler = regressor.test_loader.dataset.get_target_scaler()
all_pred = []
all_labels = []
for pred, label in zip(output_data['outputs'], output_data['labels']):
    pred = np.exp(scaler.inverse_transform(3*pred.reshape(-1, 1))) - 1
    label = np.exp(scaler.inverse_transform(3*label.reshape(-1, 1))) - 1
    #pred = np.exp(scaler.inverse_transform(pred.reshape(-1, 1))) - 1
    #label = np.exp(scaler.inverse_transform(label.reshape(-1, 1))) - 1
    all_pred.append(np.squeeze(pred))
    all_labels.append(np.squeeze(label))
    print('Prediction: ', round(pred.item(),10)*100, ', Groundtruth: ', round(label.item(),10)*100)
plt.plot(all_pred, label='preds')
plt.plot(all_labels, label='GT')
plt.legend()
Infering Dataloader: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 65/65 [00:00<00:00, 123.90it/s]
Prediction: 0.00097046 , Groundtruth: 5.95e-06
Prediction: 0.00636343 , Groundtruth: 0.00026736
Prediction: 0.00194929 , Groundtruth: 0.03340641
Prediction: 0.00642674 , Groundtruth: 0.0
Prediction: 0.00026819 , Groundtruth: 1.3e-06
Prediction: 0.01012827 , Groundtruth: 0.00026958
Prediction: 0.0098428 , Groundtruth: 0.00047497999999999996
Prediction: 0.00150219 , Groundtruth: 0.00110324
Prediction: 0.00603101 , Groundtruth: 0.00064717
Prediction: 0.0034506199999999997 , Groundtruth: 0.00043580999999999996
Prediction: 0.008312990000000001 , Groundtruth: 0.00012552
Prediction: 0.00198607 , Groundtruth: 0.0035738999999999996
Prediction: 0.00433769 , Groundtruth: 0.00654365
Prediction: 0.01026217 , Groundtruth: 0.01323092
Prediction: 0.00351557 , Groundtruth: 0.026366729999999998
Prediction: 0.015814170000000002 , Groundtruth: 0.00337781
Prediction: 0.0033751899999999997 , Groundtruth: 0.00015527
Prediction: 0.014835449999999998 , Groundtruth: 0.1694356
Prediction: 0.01473533 , Groundtruth: 0.05055859
Prediction: 0.0016438499999999999 , Groundtruth: 1.034e-05
Prediction: 0.0060726700000000005 , Groundtruth: 0.021995709999999998
Prediction: 0.00077359 , Groundtruth: 7.127e-05
Prediction: 0.00040242 , Groundtruth: 2.2300000000000002e-06
Prediction: 0.01149356 , Groundtruth: 0.00454709
Prediction: 0.015187039999999999 , Groundtruth: 1.092e-05
Prediction: 0.00811506 , Groundtruth: 0.0021135999999999998
Prediction: 0.00100408 , Groundtruth: 1.778e-05
Prediction: 0.0051971000000000005 , Groundtruth: 0.0005264299999999999
Prediction: 0.01146134 , Groundtruth: 0.00306652
Prediction: 0.00164437 , Groundtruth: 4.325e-05
Prediction: 0.00382826 , Groundtruth: 4.52e-06
Prediction: 0.00488591 , Groundtruth: 6e-07
Prediction: 0.006506509999999999 , Groundtruth: 0.00025791
Prediction: 0.00664641 , Groundtruth: 0.035108719999999996
Prediction: 0.00416623 , Groundtruth: 0.0028067599999999997
Prediction: 0.0039426 , Groundtruth: 0.00122546
Prediction: 0.01088672 , Groundtruth: 0.01285937
Prediction: 0.00295124 , Groundtruth: 1.628e-05
Prediction: -0.0008002899999999999 , Groundtruth: 0.00363181
Prediction: 0.00587447 , Groundtruth: 0.00071623
Prediction: 0.00259355 , Groundtruth: 0.00051957
Prediction: -0.0022530700000000002 , Groundtruth: 7.7e-06
Prediction: 0.01150704 , Groundtruth: 0.00023435
Prediction: 0.00855302 , Groundtruth: 0.01154638
Prediction: -0.00718415 , Groundtruth: 4.304e-05
Prediction: 0.00830927 , Groundtruth: 7.308e-05
Prediction: 0.00358255 , Groundtruth: 0.00028531
Prediction: 0.00768201 , Groundtruth: 0.00209244
Prediction: 0.0037029000000000003 , Groundtruth: 0.013843610000000001
Prediction: 0.01753345 , Groundtruth: 0.00102128
Prediction: 0.0014042 , Groundtruth: 1.3e-07
Prediction: 0.00462666 , Groundtruth: 0.00152468
Prediction: -0.00582752 , Groundtruth: 0.00039185
Prediction: 0.0139784 , Groundtruth: 0.0009567
Prediction: 0.00591156 , Groundtruth: 0.00104522
Prediction: 0.01028509 , Groundtruth: 0.00373315
Prediction: 0.00947045 , Groundtruth: 0.01625901
Prediction: 0.00727966 , Groundtruth: 0.0028929100000000003
Prediction: 0.01104137 , Groundtruth: 0.00099832
Prediction: 0.00405477 , Groundtruth: 0.020151119999999998
Prediction: 0.00717497 , Groundtruth: 0.00219026
Prediction: 0.0076355500000000005 , Groundtruth: 3.598e-05
Prediction: 0.0017737800000000002 , Groundtruth: 0.00022579
Prediction: 0.01163593 , Groundtruth: 0.00125373
Prediction: 0.01259306 , Groundtruth: 0.01562514
<matplotlib.legend.Legend at 0x7f6bf5ca5d00>
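The loop above maps model outputs back to the original units with an inverse scaling followed by exp(x) - 1. That pattern suggests the target was transformed with log(1 + y) and then scaled before training. The sketch below illustrates such a round trip under that assumption. `SimpleRobustScaler` is a hypothetical median/IQR scaler written for this example, the damage values are made up, and the extra factor of 3 used in the cell is not reproduced because its origin is not documented here.

```python
import math
from statistics import median, quantiles


class SimpleRobustScaler:
    """Hypothetical stand-in: centre on the median, scale by the interquartile range."""

    def fit(self, values):
        q1, _, q3 = quantiles(values, n=4, method="inclusive")
        self.center = median(values)
        self.scale = (q3 - q1) or 1.0  # avoid dividing by zero on constant data
        return self

    def transform(self, values):
        return [(v - self.center) / self.scale for v in values]

    def inverse_transform(self, values):
        return [v * self.scale + self.center for v in values]


# Made-up normalized damages, heavily skewed towards zero.
damages = [0.0, 0.0002, 0.001, 0.01, 0.15]

# Forward: compress the long tail with log(1 + y), then scale.
logged = [math.log1p(d) for d in damages]
target_scaler = SimpleRobustScaler().fit(logged)
scaled = target_scaler.transform(logged)

# Inverse: undo the scaling, then undo the logarithm with exp(x) - 1.
recovered = [math.expm1(v) for v in target_scaler.inverse_transform(scaled)]

print(scaled)
print(recovered)
```

Because the model is trained on the scaled values, its raw outputs can be negative, and after the inverse transform a prediction can still fall slightly below zero, as a few of the printed predictions above do.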
6. EXPLAINABLE AI
To complete the assessment of our model, we activate the optional feature of the eXplainable AI (XAI) module. This module allows us to go beyond the intuitive understanding of the model's performance and gain insight into its inner workings. The backbone of this module is Python’s Captum library (https://captum.ai/) [Kokhlikyan et al., 2020], a well-known open-source library for model interpretability built on PyTorch, that efficiently implements the vast majority of the attribution methods proposed in the literature.
In this example, we select Saliency [Simonyan, 2013], one of the simplest attribution methods. This method is based on the computation of the gradient of the output with respect to the input.
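The idea behind Saliency can be illustrated without any deep learning library. The sketch below approximates the gradient of a toy model's output with respect to each input value by central finite differences, which stands in for the automatic differentiation that Captum relies on, and keeps the absolute value as the attribution. `toy_model`, `saliency` and the input values are all made up for this example.

```python
def toy_model(x):
    """Toy scalar model over a series of timesteps with two features each."""
    return sum(3.0 * step[0] - 0.5 * step[1] ** 2 for step in x)


def saliency(model, x, eps=1e-5):
    """Absolute gradient of the output w.r.t. every input value.

    Uses central finite differences instead of autograd.
    """
    attributions = []
    for t, step in enumerate(x):
        row = []
        for f in range(len(step)):
            up = [list(s) for s in x]
            down = [list(s) for s in x]
            up[t][f] += eps
            down[t][f] -= eps
            grad = (model(up) - model(down)) / (2 * eps)
            row.append(abs(grad))
        attributions.append(row)
    return attributions


x = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.0]]  # 3 timesteps, 2 features
attr = saliency(toy_model, x)

# Average importance of each feature over all timesteps.
per_feature = [sum(row[f] for row in attr) / len(attr) for f in range(2)]
print(attr)
print(per_feature)
```

The first feature has a constant gradient of 3 at every timestep, while the attribution of the second feature depends on its value at each timestep. Averaging over time, as in the last step, gives one importance score per input variable.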
regressor.config['evaluation']['xai']['activate'] = True
regressor.config['evaluation']['xai']['type'] = 'Saliency'
xai_output= regressor.inference(subset='test')
Infering Dataloader: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 65/65 [00:00<00:00, 124.70it/s] Explaining Dataloader: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 65/65 [00:00<00:00, 117.64it/s] Visualizing explanations: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 65/65 [03:01<00:00, 2.80s/it]
This method provides two types of visualizations per event. The first one shows the importance of each input variable, averaged over all input timesteps. For this particular event (2015-0465-HRV), the most relevant feature is tp_urban, the total daily precipitation (tp) aggregated over urban areas.
# For one event plot side by side the XAI images obtained
event_name= xai_output['event_names'][0]
path= list((Path(regressor.config['save_path']) / 'xai').glob(f'{event_name}*.png'))
# First visualization result for the XAI module
img= np.asarray(Image.open(path[0]))
plt.figure(figsize=(15,15))
plt.imshow(img)
plt.axis(False);
plt.title(path[0].name)
plt.show()
The second visualization breaks the attribution down per time step. The first two rows show the GT and the predicted values. Notice that the predicted values are negative rather than percentages; this is due to the normalisation applied to the target when training the architecture. The remaining rows show each input variable as a blue line, together with the attribution of that variable at each timestep (i.e., how much it has contributed to the predicted output).
# Second visualization result for the XAI module
img= np.asarray(Image.open(path[1]))
plt.figure(figsize=(15,15))
plt.imshow(img)
plt.axis(False);
plt.title(path[1].name)
plt.show()