#####################################################################
# DOC #
#####################################################################
"""
@author: F. Ramognino <federico.ramognino@polimi.it>
Last update: 12/06/2023
"""
#####################################################################
# IMPORT #
#####################################################################
from __future__ import annotations
from typing import Iterable, Literal, Callable
from enum import StrEnum
import pandas as pd
import numpy as np
from pandas import DataFrame
from libICEpost.src.base.Functions.typeChecking import checkType, checkArray, checkMap
from libICEpost.src.base.Utilities import Utilities
from scipy.interpolate import RegularGridInterpolator
import matplotlib as mpl
import matplotlib.colors as mcolors
from matplotlib import pyplot as plt
import matplotlib.ticker
import itertools
import warnings
from .BaseTabulation import BaseTabulation
#####################################################################
# AUXILIARY CLASSES #
#####################################################################
[docs]
class _OoBMethod(StrEnum):
    """Out-of-bounds methods"""
    # Closed set of names accepted for the 'outOfBounds' option of Tabulation.
    # Being a StrEnum, each member compares equal to its string value, and
    # _OoBMethod("<name>") raises ValueError for any other string.
    extrapolate = "extrapolate"
    nan = "nan"
    # Default used by Tabulation.__init__
    fatal = "fatal"
[docs]
class TabulationAccessWarning(Warning):
    """Warning category used for issues raised while accessing a Tabulation."""
#############################################################################
# AUXILIARY FUNCTIONS #
#############################################################################
[docs]
def toPandas(table:Tabulation) -> DataFrame:
    """
    Convert an instance of Tabulation to a pandas.DataFrame with all the points stored in the tabulation.
    The columns are the input variables (in nesting order) plus "output", which stores the
    tabulated values. There is one row per sampling point, with the last variable in
    'table.order' varying fastest.

    Args:
        table (Tabulation): The table to convert to a dataframe.

    Returns:
        DataFrame: The dataframe with columns [*table.order, "output"].
    """
    checkType(table, Tabulation, "table")

    order = list(table.order)
    ranges = table.ranges

    #Build the coordinates of every sampling point at once. With indexing="ij"
    #and C-order flattening, the rows follow the same order as the flattened data.
    grids = np.meshgrid(*[ranges[f] for f in order], indexing="ij")
    df = DataFrame({f:grid.ravel().astype(float) for f, grid in zip(order, grids)}, columns=order)
    df["output"] = table._data.flatten()
    return df

#Alias
to_pandas = toPandas
#############################################################################
[docs]
def insertDimension(table:Tabulation, variable:str, value:float, index:int=None, inplace:bool=False) -> Tabulation|None:
    """
    Insert an axis to the dimension-set of the table with a single value.
    This is useful to merge two tables with respect to an additional variable.

    Args:
        table (Tabulation): The table to modify.
        variable (str): The name of the variable to insert.
        value (float): The value for the range of the corresponding variable.
        index (int, optional): The index where to insert the variable in nesting order. If None, the variable is appended at the end. Defaults to None.
        inplace (bool, optional): If True, the operation is performed in-place. Defaults to False.

    Returns:
        Tabulation|None: The table with the inserted dimension if inplace is False, None otherwise.

    Raises:
        ValueError: If the variable is already in the table or if the index is out of range.

    Example:
        Create a table with two variable:
        ```
        >>> tab1 = Tabulation([1, 2, 3, 4], {"x":[0, 1], "y":[0, 1]}, ["x", "y"])
        >>> tab1.insertDimension("z", 0.0, 1)
        >>> tab1.ranges
        {"x":[0, 1], "z":[0.0], "y":[0, 1]}
        ```
        Create a second table with the same variables:
        ```
        >>> tab2 = Tabulation([5, 6, 7, 8], {"x":[0, 1], "y":[0, 1]}, ["x", "y"])
        >>> tab2.insertDimension("z", 1.0, 1)
        >>> tab2.ranges
        {"x":[0, 1], "z":[1.0], "y":[0, 1]}
        ```
        Concatenate the two tables:
        ```
        >>> tab1.concat(tab2, inplace=True)
        >>> tab1.ranges
        {"x":[0, 1], "z":[0.0, 1.0], "y":[0, 1]}
        ```
    """
    #Code implemented for inplace: work on a copy otherwise
    if not inplace:
        tab = table.copy()
        tab.insertDimension(variable, value, index, inplace=True)
        return tab

    #Check arguments
    table.checkType(variable, str, "variable")
    table.checkType(value, float, "value")
    table.checkType(index, int, "index", allowNone=True)
    table.checkType(inplace, bool, "inplace")

    if index is None:
        index = len(table.order)

    #Check if variable already exists
    if variable in table.order:
        raise ValueError(f"Variable '{variable}' already exists in the table.")

    #Check index
    if not (0 <= index <= table.ndim):
        raise ValueError(f"Index out of range. Must be between 0 and {table.ndim}.")

    #Insert variable
    table._order.insert(index, variable)
    #Store the range as a numpy array, like every other entry of _ranges:
    #a plain list breaks the boolean masking and fancy indexing used when
    #clipping or slicing the table.
    table._ranges[variable] = np.array([value])
    #Adding a length-1 axis does not change the number of elements
    table._data = table._data.reshape([len(table._ranges[f]) for f in table.order])
    table._createInterpolator()
#############################################################################
[docs]
def concat(table:Tabulation, *tables:Tabulation, inplace:bool=False, fillValue:float=None, overwrite:bool=False) -> Tabulation|None:
    """
    Extend the table with the data of other tables. The tables must have the same variables but
    not necessarily in the same order. The data of the other tables is merged into the data
    of the first table, preserving the order of the variables of the first table.

    If fillValue is not given, the merged grid must be fully covered by the given tables.
    If fillValue is given, the sampling points that are not covered by any table are set
    to that value.

    Args:
        table (Tabulation): The table to which the data is appended.
        *tables (Tabulation): The tables to append.
        inplace (bool, optional): If True, the operation is performed in-place. Defaults to False.
        fillValue (float, optional): The value to fill missing sampling points. Defaults to None.
        overwrite (bool, optional): If True, overwrite the data of the first table with the data
            of the second table in overlapping regions. Otherwise raise an error. Defaults to False.

    Returns:
        Tabulation|None: The concatenated table if inplace is False, None otherwise.

    Raises:
        ValueError: If the tables do not share the same variables, if they overlap and
            'overwrite' is False, or if sampling points are missing and 'fillValue' is not given.
    """
    #Check arguments
    checkType(table, Tabulation, "table")
    checkArray(tables, Tabulation, "tables")
    checkType(inplace, bool, "inplace")
    checkType(overwrite, bool, "overwrite")
    if fillValue is not None:
        checkType(fillValue, float, "fillValue")

    #Code implemented for inplace: work on a copy otherwise
    if not inplace:
        tab = table.copy()
        concat(tab, *tables, inplace=True, fillValue=fillValue, overwrite=overwrite)
        return tab

    order = table.order
    ranges = table.ranges
    for ii, tab in enumerate(tables):
        #Check compatibility
        if set(order) != set(tab.order):
            raise ValueError(f"Tables must have the same input variables to concatenate (table[{ii}] incompatible).")

        #Merge ranges
        tabRanges = tab.ranges
        ranges = {f:sorted(set(ranges[f]).union(tabRanges[f])) for f in order}

    #Merged data-set, pre-filled with fillValue (NaN when not given).
    #Note: multiplying a NaN array by fillValue would leave NaN everywhere.
    shape = [len(ranges[f]) for f in order]
    data = np.full(shape, float("nan") if fillValue is None else fillValue, dtype=float)
    written = np.zeros(shape, dtype=bool) #Check if data has been written
    merged = {f:np.asarray(ranges[f]) for f in order}

    for tab in [table, *tables]:
        tabRanges = tab.ranges
        tabOrder = tab.order

        #Position of each sampling point of this table in the merged ranges.
        #The merged ranges are a sorted union, so every value is found exactly.
        positions = [np.searchsorted(merged[f], tabRanges[f]) for f in order]
        block = np.ix_(*positions)

        if not overwrite:
            clashes = np.argwhere(written[block])
            if len(clashes) > 0:
                #Report the first clashing point, as an index in the merged table
                jj = tuple(int(pos[k]) for pos, k in zip(positions, clashes[0]))
                raise ValueError(f"Data already written at index {jj}. Cannot overwrite.")

        #Bring the axes of this table to the order of the first table and copy
        data[block] = tab._data.transpose(tuple(tabOrder.index(f) for f in order))
        written[block] = True

    #Check for missing sampling points
    if fillValue is None and not np.all(written):
        raise ValueError("Missing sampling points in the concatenated tables. Cannot concatenate without 'fillValue' argument.")

    #Update the table
    table._ranges = {v:np.array(ranges[v]) for v in order}
    table._data = data
    table._createInterpolator()

#Alias
merge = concat
#############################################################################
[docs]
def squeeze(table:Tabulation, *, inplace:bool=False) -> Tabulation|None:
    """
    Drop every input-variable that is sampled at a single point.

    Args:
        table (Tabulation): The table to squeeze.
        inplace (bool, optional): If True, the operation is performed in-place. Defaults to False.

    Returns:
        Tabulation|None: The squeezed tabulation if inplace is False, None otherwise.
    """
    if inplace:
        #Names of the variables with more than one sampling point
        names = table.order
        kept = [names[axis] for axis, numPoints in enumerate(table.shape) if numPoints > 1]

        #Keep only those variables and drop the length-1 axes of the data
        table._order = kept
        table._ranges = {name:table._ranges[name] for name in kept}
        table._data = table._data.squeeze()

        #Update interpolator
        table._createInterpolator()
        return None

    #Not in-place: squeeze a copy and return it
    result = table.copy()
    result.squeeze(inplace=True)
    return result
#########################################################################
[docs]
def sliceTable(table:Tabulation, *, slices:Iterable[slice|Iterable[int]|int]=None, ranges:dict[str,float|Iterable[float]]=None, inplace=False, **argv) -> Tabulation|None:
    """
    Extract a table with sliced dataset. Can access in two ways:
        1) by slicer ('slices'): one entry per input-variable, in nesting order.
        2) by sub-set of sampling points ('ranges'). Keyword arguments also accepted.

    Exactly one of the two must be used. The sliced table keeps the nesting order
    and the ascending order of the sampling points of the original table.

    Args:
        table (Tabulation): The table
        ranges (dict[str,float|Iterable[float]], optional): Ranges of sliced table. Defaults to None.
        slices (Iterable[slice|Iterable[int]|int]): The slicers for each input-variable.
            An integer selects a single sampling point (the dimension is retained with length 1).
        inplace (bool, optional): If True, the operation is performed in-place. Defaults to False.
        **argv: Ranges given as keyword arguments (variable=values).

    Returns:
        Tabulation|None: The sliced table if inplace is False, None otherwise.

    Raises:
        ValueError: If neither or both of 'slices' and 'ranges' are given, or if a value
            in 'ranges' is not a sampling point of the table.
        IndexError: If the number of slices is wrong or an index is out of range.
        TypeError: If an entry of 'slices' has an unsupported type.
    """
    checkType(table, Tabulation, "table")
    checkType(inplace, bool, "inplace")

    #Code implemented for inplace
    if not inplace:
        tab = table.copy()
        tab.slice(slices=slices, ranges=ranges, inplace=True, **argv)
        return tab

    #Update ranges with keyword arguments (on a private copy, so that the
    #dictionary given by the caller is never modified)
    ranges = dict() if ranges is None else dict(ranges)
    ranges.update(argv)
    if len(ranges) == 0:
        ranges = None

    if (slices is None) and (ranges is None):
        raise ValueError("Must provide either 'slices' or 'ranges' to slice the table.")
    elif not(slices is None) and not(ranges is None):
        raise ValueError("Cannot provide both 'slices' and 'ranges' to slice the table.")

    #Switch access
    if slices is not None: #By slices
        checkType(slices, Iterable, "slices")
        if isinstance(slices, str):
            raise TypeError("Type mismatch. Attempting to slice with entry of type 'str'.")
        slices = list(slices) #Cast to list (mutable)

        order = table.order
        shape = table.shape

        #Check types
        if not(len(slices) == len(order)):
            raise IndexError("Given {} slices, while table has {} variables ({}).".format(len(slices), len(order), order))

        #Convert every entry to a list of indexes
        for ii, ss in enumerate(slices):
            if isinstance(ss, slice):
                slices[ii] = list(range(*ss.indices(shape[ii])))

            elif isinstance(ss,(int, np.integer)):
                if ss >= shape[ii]:
                    raise IndexError(f"Index out of range for slices[{ii}] ({ss} >= {shape[ii]})")
                #Wrap in a list: np.ix_ only accepts 1-D sequences
                slices[ii] = [ss]

            elif isinstance(ss, Iterable):
                checkArray(ss, (int, np.integer), f"slices[{ii}]")
                for ind in ss: #Check range
                    if ind >= shape[ii]:
                        raise IndexError(f"Index out of range for variable {ii}:{order[ii]} ({ind} >= {shape[ii]})")
                slices[ii] = sorted(ss) #Sort

            else:
                raise TypeError("Type mismatch. Attempting to slice with entry of type '{}'.".format(ss.__class__.__name__))

        #Create ranges (table.ranges returns copies: access it only once)
        tableRanges = table.ranges
        newRanges = {var:np.array(tableRanges[var][indexes]) for var, indexes in zip(order, slices)}

        #Extract the data on the outer product of the selected indexes
        data = table._data[np.ix_(*slices)]

        #Update table
        table._data = data
        table._ranges = newRanges
        table._createInterpolator()

    else: #By ranges
        #Check arguments:
        checkMap(ranges, str, (Iterable, float), entryName="ranges")

        #Start from the original ranges
        tableRanges = table.ranges
        newRanges = dict(tableRanges)

        for var, values in ranges.items():
            #A single value selects a single sampling point
            values = [values] if isinstance(values, float) else list(values)
            for val in values:
                if not(val in tableRanges[var]):
                    raise ValueError(f"Sampling value '{val}' not found in range for variable '{var}' with points:\n{tableRanges[var]}")
            newRanges[var] = values

        #Create slicers to access by index
        slices = [np.where(np.isin(tableRanges[item], newRanges[item]))[0] for item in table.order]

        #Slice by index
        table.slice(slices=tuple(slices), inplace=True)
#############################################################################
[docs]
def clipTable(table:Tabulation, ranges:dict[str,tuple[float|None,float|None]]=None, *, inplace:bool=False, **kwargs) -> Tabulation|None:
    """
    Clip the table to the given ranges. The ranges are given as a dictionary with the
    variable names as keys and a tuple with the minimum and maximum values (both included).

    Args:
        table (Tabulation): The table to clip.
        ranges (dict[str,tuple[float|None,float|None]], optional): The ranges to clip for each input-variable. If min or max is None, the range is unbounded.
        inplace (bool, optional): If True, the operation is performed in-place. Defaults to False.
        **kwargs: Can access also by keyword arguments.

    Returns:
        Tabulation|None: The clipped table if inplace is False, None otherwise.

    Raises:
        ValueError: If no range is given, if a variable is given both in 'ranges' and as keyword,
            if a variable is not in the table, if a range is not a (min, max) pair, or if
            the clipping would leave a variable without sampling points.
    """
    checkType(table, Tabulation, "table")
    checkType(inplace, bool, "inplace")
    checkType(ranges, dict, "ranges", allowNone=True)

    #Code implemented for inplace
    if not inplace:
        tab = table.copy()
        tab.clip(ranges, inplace=True, **kwargs)
        return tab

    #Update ranges with keyword arguments (on a private copy, so that the
    #dictionary given by the caller is never modified)
    ranges = dict() if ranges is None else dict(ranges)
    for kw in kwargs:
        if kw in ranges:
            raise ValueError(f"Keyword argument '{kw}' is already present in 'ranges'.")
    ranges.update(kwargs)

    if len(ranges) == 0:
        raise ValueError("Must provide 'ranges' to clip the table.")

    #Check arguments
    table.checkMap(ranges, str, tuple, entryName="ranges")

    #Compute clipped ranges
    newRanges = {}
    for var, bounds in ranges.items():
        if not var in table.order:
            raise ValueError(f"Variable '{var}' not found in table.")
        if not len(bounds) == 2:
            raise ValueError(f"Invalid range for variable '{var}'. Must be a tuple with two values (min, max).")

        lower, upper = bounds
        if (lower is None) and (upper is None):
            continue #Unbounded on both sides: nothing to clip

        points = np.asarray(table._ranges[var])
        if lower is not None:
            points = points[points >= lower]
        if upper is not None:
            points = points[points <= upper]
        newRanges[var] = points

    if any([len(newRanges[var]) == 0 for var in newRanges]):
        raise ValueError("Clipping would result in empty table (zero-size range).")

    #Clip
    for ii, var in enumerate(table.order):
        if var in newRanges:
            keep = np.isin(table._ranges[var], newRanges[var]).nonzero()[0]
            table._data = table._data.take(keep, axis=ii)
            table._ranges[var] = newRanges[var]

    #Update interpolator, which otherwise would still refer to the old grid
    table._createInterpolator()
#############################################################################
#Plot:
[docs]
def plotTable(  table:Tabulation,
                x:str, c:str, iso:dict[str,float]=None,
                *,
                ax:plt.Axes=None,
                colorMap:str=None,
                xlabel:str=None,
                ylabel:str=None,
                clabel:str=None,
                title:str=None,
                xlim:tuple[float|None,float|None]=None,
                ylim:tuple[float|None,float|None]=None,
                clim:tuple[float|None,float|None]=None,
                figsize:tuple[float]=None,
                **kwargs) -> plt.Axes:
    """
    Plot the tabulated values against one input-variable, drawing one curve for each
    sampling point of a second input-variable. The curves are coloured through a color-map.

    Args:
        table (Tabulation): The table to plot.
        x (str): The x-axis variable.
        c (str): The color variable.
        iso (dict[str,float], optional): The iso-values to plot. If the table has only 2 variables, this argument is not needed. Defaults to None.
        ax (plt.Axes, optional): The axis to plot on. Defaults to None.
        colorMap (str, optional): The color-map to use. Defaults to None. Equivalent keys are [`cmap`, `colormap`]
        xlabel (str, optional): The x-axis label. Defaults to None. Equivalent keys are [`x_label`, `xLabel`]
        ylabel (str, optional): The y-axis label. Defaults to None. Equivalent keys are [`y_label`, `yLabel`]
        clabel (str, optional): The color-bar label. Defaults to None. Equivalent keys are [`c_label`, `cLabel`]
        title (str, optional): The title of the plot. Defaults to None.
        xlim (tuple[float], optional): The x-axis limits. Defaults to None. Equivalent keys are [`x_lim`, `xLim`]
        ylim (tuple[float], optional): The y-axis limits. Defaults to None. Equivalent keys are [`y_lim`, `yLim`]
        clim (tuple[float], optional): The color-bar limits. Defaults to None. Equivalent keys are [`c_lim`, `cLim`]
        figsize (tuple[float], optional): The size of the figure. Defaults to None.
        **kwargs: Additional arguments to pass to the plot

    Returns:
        plt.Axes: The axis of the plot.
    """
    #Accepted spellings for each option
    equivalentKeys:dict[str,list[str]] = {
        "xlabel":["xlabel", "x_label", "xLabel"],
        "ylabel":["ylabel", "y_label", "yLabel"],
        "clabel":["clabel", "c_label", "cLabel"],
        "xlim":["xlim", "x_lim", "xLim"],
        "ylim":["ylim", "y_lim", "yLim"],
        "clim":["clim", "c_lim", "cLim"],
        "colorMap":["colorMap", "cmap", "colormap"],
    }

    #Collect the options given by name together with the free keyword arguments
    named = {"xlabel":xlabel, "ylabel":ylabel, "clabel":clabel, "xlim":xlim, "ylim":ylim, "clim":clim, "colorMap":colorMap}
    fullkwargs = {**kwargs}
    fullkwargs.update({name:given for name, given in named.items() if given is not None})

    allSpellings = [spelling for spellings in equivalentKeys.values() for spelling in spellings]
    foundKeys = set(fullkwargs.keys()).intersection(allSpellings)

    #Spellings actually used for each option: at most one is allowed
    keyMap:dict[str,list] = {name:[key for key in foundKeys if key in spellings] for name, spellings in equivalentKeys.items()}
    for name, used in keyMap.items():
        if len(used) > 1:
            raise ValueError(f"Key '{name}' found multiple times in kwargs: {used}")

    def _option(name:str, default=None):
        """Value given for the option under any of its spellings, or the default."""
        used = keyMap[name]
        return fullkwargs[used[0]] if len(used) > 0 else default

    xlabel = _option("xlabel")
    ylabel = _option("ylabel")
    clabel = _option("clabel")
    xlim = _option("xlim", (None, None))
    ylim = _option("ylim", (None, None))
    clim = _option("clim", (None, None))
    colorMap = _option("colorMap")

    #The remaining keyword arguments are forwarded to the plot
    for key in foundKeys:
        kwargs.pop(key, None)

    #Check arguments
    checkType(table, Tabulation, "table")
    checkType(x, str, "x")
    checkType(c, str, "c")
    checkType(iso, dict, "iso", allowNone=True)
    if iso is None:
        iso = dict()
    checkMap(iso, str, float, "iso")
    checkType(ax, plt.Axes, "ax", allowNone=True)
    checkType(colorMap, str, "colorMap", allowNone=True)
    checkType(xlabel, str, "xlabel", allowNone=True)
    checkType(ylabel, str, "ylabel", allowNone=True)
    checkType(clabel, str, "clabel", allowNone=True)
    checkType(title, str, "title", allowNone=True)
    checkType(xlim, tuple, "xlim")
    checkType(ylim, tuple, "ylim")
    checkType(clim, tuple, "clim")
    checkType(figsize, tuple, "figsize", allowNone=True)

    #Check variables
    for var in (x, c):
        if var not in table.order:
            raise ValueError(f"Variable '{var}' not found in table.")

    #Check iso-values
    for f in iso:
        if f not in table.order:
            raise ValueError(f"Variable '{f}' not found in table.")
        if iso[f] not in table.ranges[f]:
            raise ValueError(f"Iso-value for variable '{f}' not found in the table.")

    #Every variable must be either x, c, or fixed through an iso-value
    covered = set(iso.keys()).union({x, c})
    if set(table.order) != covered:
        raise ValueError("Iso-values must be given for all but x and c variables ({}).".format(", ".join(set(table.order) - covered)))

    #Create the axis
    if ax is None:
        ax = plt.subplots(figsize=figsize)[1]

    #Default plot style
    if ("marker" not in kwargs) and ("m" not in kwargs):
        kwargs["marker"] = "o"
    if ("linestyle" not in kwargs) and ("ls" not in kwargs):
        kwargs["linestyle"] = "--"

    #Slice the data-set
    tab = table.slice(ranges={f:[iso[f]] for f in iso}) if (len(iso) > 0) else table

    #Color-bar limits default to the extremes of the color variable
    vmin = clim[0] if clim[0] is not None else tab.ranges[c].min()
    vmax = clim[1] if clim[1] is not None else tab.ranges[c].max()

    #Plot one curve for each sampling point of the color variable
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    sm = plt.cm.ScalarMappable(cmap=colorMap, norm=norm)
    cmap = sm.cmap
    sm.set_array([])
    for val in tab.ranges[c]:
        curve = tab.slice(ranges={c:[val]})
        ax.plot(
            curve.ranges[x],
            curve.data.flatten(),
            color=cmap(norm(val)),
            **kwargs)

    #Color-bar
    cbar = plt.colorbar(sm, ax=ax)
    cbar.set_label(c if clabel is None else clabel)

    #Labels
    ax.set_xlabel(x if xlabel is None else xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(" - ".join([f"{f}={iso[f]}" for f in iso]) if title is None else title)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    return ax
#############################################################################
[docs]
def plotTableHeatmap(   table:Tabulation,
                        x:str, y:str, iso:dict[str,float]=None,
                        *,
                        ax:plt.Axes=None,
                        colorMap:str=None,
                        xlabel:str=None,
                        ylabel:str=None,
                        clabel:str=None,
                        title:str=None,
                        xlim:tuple[float|None,float|None]=None,
                        ylim:tuple[float|None,float|None]=None,
                        clim:tuple[float|None,float|None]=None,
                        figsize:tuple[float,float]=None,
                        isolines_kwargs:dict[str,object]=None,
                        **kwargs) -> plt.Axes:
    """
    Plot the tabulated values as a heat-map (filled contours with iso-lines) over two
    input-variables. The table and the dictionaries given as arguments are not modified.

    Args:
        table (Tabulation): The table to plot.
        x (str): The x-axis variable.
        y (str): The y-axis variable.
        iso (dict[str,float], optional): The iso-values to plot. If the table has only 2 variables, this argument is not needed. Defaults to None.
        ax (plt.Axes, optional): The axis to plot on. If None, a new figure is created. Defaults to None.
        colorMap (str, optional): The color-map to use. Defaults to None. Equivalent keys are [`cmap`, `colormap`]
        xlabel (str, optional): The x-axis label. If None, the name of the x variable is used. Defaults to None. Equivalent keys are [`x_label`, `xLabel`]
        ylabel (str, optional): The y-axis label. If None, the name of the y variable is used. Defaults to None. Equivalent keys are [`y_label`, `yLabel`]
        clabel (str, optional): The color-bar label. Defaults to None. Equivalent keys are [`c_label`, `cLabel`]
        title (str, optional): The title of the plot. Defaults to None.
        xlim (tuple[float|None,float|None], optional): The x-axis limits. Defaults to None. Equivalent keys are [`x_lim`, `xLim`]
        ylim (tuple[float|None,float|None], optional): The y-axis limits. Defaults to None. Equivalent keys are [`y_lim`, `yLim`]
        clim (tuple[float|None,float|None], optional): The color-bar limits. Defaults to None. Equivalent keys are [`c_lim`, `cLim`]
        figsize (tuple[float,float], optional): The size of the figure. Defaults to None.
        isolines_kwargs (dict[str,object], optional): The keyword arguments to pass to contour() for the isolines. Defaults to None.
        **kwargs: Additional arguments to pass to the contourf plot.

    Returns:
        plt.Axes: The axis of the plot.
    """
    #Check for equivalent keys
    equivalentKeys:dict[str,list[str]] = {
        "xlabel":["xlabel", "x_label", "xLabel"],
        "ylabel":["ylabel", "y_label", "yLabel"],
        "clabel":["clabel", "c_label", "cLabel"],
        "xlim":["xlim", "x_lim", "xLim"],
        "ylim":["ylim", "y_lim", "yLim"],
        "clim":["clim", "c_lim", "cLim"],
        "colorMap":["colorMap", "cmap", "colormap"],
    }

    fullkwargs = {**kwargs}
    if xlabel is not None: fullkwargs["xlabel"] = xlabel
    if ylabel is not None: fullkwargs["ylabel"] = ylabel
    if clabel is not None: fullkwargs["clabel"] = clabel
    if xlim is not None: fullkwargs["xlim"] = xlim
    if ylim is not None: fullkwargs["ylim"] = ylim
    if clim is not None: fullkwargs["clim"] = clim
    if colorMap is not None: fullkwargs["colorMap"] = colorMap

    foundKeys = set(fullkwargs.keys()).intersection(sum(equivalentKeys.values(), start=[]))

    #Check for multiple entries that are equivalent
    keyMap:dict[str,list] = {v:[] for v in equivalentKeys.keys()}
    for key in foundKeys:
        for k in equivalentKeys:
            if key in equivalentKeys[k]:
                keyMap[k].append(key)
    for key in keyMap:
        if len(keyMap[key]) > 1:
            raise ValueError(f"Key '{key}' found multiple times in kwargs: {keyMap[key]}")

    #Set equivalent keys
    xlabel = fullkwargs[keyMap["xlabel"][0]] if len(keyMap["xlabel"]) > 0 else None
    ylabel = fullkwargs[keyMap["ylabel"][0]] if len(keyMap["ylabel"]) > 0 else None
    clabel = fullkwargs[keyMap["clabel"][0]] if len(keyMap["clabel"]) > 0 else None
    xlim = fullkwargs[keyMap["xlim"][0]] if len(keyMap["xlim"]) > 0 else (None, None)
    ylim = fullkwargs[keyMap["ylim"][0]] if len(keyMap["ylim"]) > 0 else (None, None)
    clim = fullkwargs[keyMap["clim"][0]] if len(keyMap["clim"]) > 0 else (None, None)
    colorMap = fullkwargs[keyMap["colorMap"][0]] if len(keyMap["colorMap"]) > 0 else None

    #Remove from kwargs
    for key in foundKeys:
        if key in kwargs: kwargs.pop(key)

    #Keyword arguments for isolines (private copy: do not modify the caller's dictionary)
    isolines_kwargs = dict() if isolines_kwargs is None else dict(isolines_kwargs)
    if not any(k in isolines_kwargs for k in ["c", "colors"]):
        isolines_kwargs.update(colors="black")
    if "norm" in kwargs:
        isolines_kwargs.update(norm=kwargs["norm"])

    #Check arguments
    checkType(table, Tabulation, "table")
    checkType(x, str, "x")
    checkType(y, str, "y")
    checkType(iso, dict, "iso", allowNone=True)
    if iso is None: iso = dict()
    checkMap(iso, str, float, "iso")
    checkType(ax, plt.Axes, "ax", allowNone=True)
    checkType(colorMap, str, "colorMap", allowNone=True)
    checkType(xlabel, str, "xlabel", allowNone=True)
    checkType(ylabel, str, "ylabel", allowNone=True)
    checkType(clabel, str, "clabel", allowNone=True)
    checkType(title, str, "title", allowNone=True)
    checkType(xlim, tuple, "xlim")
    checkType(ylim, tuple, "ylim")
    checkType(clim, tuple, "clim")
    checkType(figsize, tuple, "figsize", allowNone=True)
    checkMap(isolines_kwargs, str, object, "isolines_kwargs")

    #Check variables
    if not x in table.order:
        raise ValueError(f"Variable '{x}' not found in table.")
    if not y in table.order:
        raise ValueError(f"Variable '{y}' not found in table.")

    #Check iso-values
    for f in iso:
        if not f in table.order:
            raise ValueError(f"Variable '{f}' not found in table.")
        if not iso[f] in table.ranges[f]:
            raise ValueError(f"Iso-value for variable '{f}' not found in the table.")
    if not (set(table.order) == set(iso.keys()).union({x, y})):
        raise ValueError("Iso-values must be given for all but x and y variables ({}).".format(", ".join(set(table.order) - set(iso.keys()).union({x, y}))))

    #Create the axis, or retrieve the figure of the given one (needed for the color-bar)
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    #Slice the data-set. Without iso-values work on a copy, since the
    #table is squeezed and reordered in-place below.
    tab = table.slice(ranges={f:[iso[f]] for f in iso}) if (len(iso) > 0) else table.copy()
    tab.squeeze(inplace=True)
    tab.order = [x, y]

    #Update color-bar limits
    if clim[0] is None:
        clim = (np.min(tab._data), clim[1])
    if clim[1] is None:
        clim = (clim[0], np.max(tab._data))

    #Plot (data are stored as [x,y], while contourf expects [y,x])
    cs = ax.contourf(
        tab.ranges[x],
        tab.ranges[y],
        tab.data.T,
        levels=np.linspace(clim[0], clim[1], 256),
        cmap=colorMap,
        **kwargs)

    #Color-bar
    import matplotlib.cm as cm
    sm = cm.ScalarMappable(norm=kwargs.get("norm"), cmap=colorMap)
    sm.norm.vmin = clim[0]
    sm.norm.vmax = clim[1]
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label(clabel)

    #Isolines at the ticks of the color-bar
    cs = ax.contour(
        tab.ranges[x],
        tab.ranges[y],
        tab._data.T,
        levels=cbar.get_ticks(),
        vmin=clim[0],
        vmax=clim[1],
        **isolines_kwargs)

    #If one day we want to add labels to the isolines (quite ugly)
    # ax.clabel(cs, cs.levels, fmt=cbar.formatter if levelsfmt is None else levelsfmt, fontsize=levelssize)

    #Labels
    ax.set_xlabel(xlabel if not xlabel is None else x)
    ax.set_ylabel(ylabel if not ylabel is None else y)
    ax.set_title(title if not title is None else " - ".join([f"{f}={iso[f]}" for f in iso]))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    return ax
#############################################################################
# MAIN CLASSES #
#############################################################################
#Class used for storing and handling a generic tabulation:
[docs]
class Tabulation(BaseTabulation):
"""
Class used for storing and handling a tabulation from a structured grid in an n-dimensional space of input-variables.
"""
_ranges:dict[str,np.ndarray]
"""The sampling points for each input-variable"""
_data:np.ndarray
"""The n-dimensional dataset of the table"""
_outOfBounds:_OoBMethod
"""How to handle out-of-bounds access to table."""
_interpolator:RegularGridInterpolator
"""The interpolator."""
#########################################################################
#Class methods:
[docs]
@classmethod
def from_pandas(cls, data:DataFrame, order:Iterable[str], field:str, **kwargs) -> Tabulation:
    """
    Construct a tabulation from a pandas.DataFrame with n+x columns where n is len(order).
    The data-frame must contain exactly one row for each combination of the values found
    in the columns of the input variables (structured grid); the rows can be in any order.

    Args:
        data (DataFrame): The data-frame to use.
        order (Iterable[str]): The order in which the input variables are nested.
        field (str): The name of the field containing the output values.
        **kwargs: Additional arguments to pass to the constructor.

    Returns:
        Tabulation: The tabulation.

    Raises:
        ValueError: If a column is missing or if the rows of the data-frame do not
            form a complete structured grid.
    """
    #Argument checking:
    cls.checkType(data, DataFrame, "data")
    cls.checkArray(order, str, "order")
    cls.checkType(field, str, "field")

    if not len(data.columns) > len(order):
        raise ValueError("DataFrame must have n+x columns, where n is the number of input variables.")
    for f in order:
        if not f in data.columns:
            raise ValueError(f"Variable '{f}' not found in DataFrame.")
    if not field in data.columns:
        raise ValueError(f"Field '{field}' not found in DataFrame.")

    #Create ranges:
    ranges = {f:np.array(sorted(data[f].unique())) for f in order}

    #Sort data in the correct order
    data_sorted = data.sort_values(by=list(order), ascending=True, ignore_index=True)

    #Expected coordinates of the sampling points, in nesting order
    #(indexing="ij" and C-order flattening: last variable varies fastest)
    grids = [grid.ravel() for grid in np.meshgrid(*[ranges[f] for f in order], indexing="ij")]
    numPoints = grids[0].size if len(grids) > 0 else 1

    #Check that the rows match the sampling points, and report the first mismatch
    numChecked = min(numPoints, len(data_sorted))
    firstMismatch = None
    for f, grid in zip(order, grids):
        wrong = np.flatnonzero(data_sorted[f].to_numpy()[:numChecked] != grid[:numChecked])
        if (wrong.size > 0) and ((firstMismatch is None) or (wrong[0] < firstMismatch[0])):
            firstMismatch = (int(wrong[0]), f)
    if firstMismatch is not None:
        ii, f = firstMismatch
        sp = tuple(grid[ii] for grid in grids)
        raise ValueError(f"Data not consistent with sampling points. Expected {sp} at index {ii} for variable '{f}'.")

    #Check that there are no missing or exceeding rows
    if not (len(data_sorted) == numPoints):
        raise ValueError(f"Data not consistent with sampling points. Expected {numPoints} rows, while DataFrame has {len(data_sorted)}.")

    #Create data and return
    return cls(data_sorted[field].values, ranges, order, **kwargs)

#Alias
fromPandas = from_pandas
#########################################################################
#Properties:
@property
def outOfBounds(self) -> str:
    """Name of the method currently used to handle out-of-bounds access to the tabulation."""
    method = self._outOfBounds
    return method.value

@outOfBounds.setter
def outOfBounds(self, outOfBounds:Literal["extrapolate", "fatal", "nan"]):
    self.checkType(outOfBounds, str, "outOfBounds")
    #Names that are not members of the enumeration are rejected by its constructor
    method = _OoBMethod(outOfBounds)
    self._outOfBounds = method

    #Rebuild the interpolator after changing the option
    self._createInterpolator()
####################################
@BaseTabulation.order.setter
def order(self, order:Iterable[str]):
    #Remember the current nesting order before delegating to the base-class setter
    previous = self.order
    BaseTabulation.order.fset(self, order)

    #Permute the axes of the data so that they follow the new order
    axes = [previous.index(name) for name in order]
    self._data = self._data.transpose(*axes)

    #Update interpolator
    self._createInterpolator()
####################################
@property
def ranges(self):
    """
    The sampling points of each input-variable, as a new dictionary holding
    copies of the stored arrays (read-only).
    """
    return {name:points.copy() for name, points in self._ranges.items()}
#######################################
#Get data:
@property
def data(self):
    """
    A copy of the n-dimensional data-set stored in the table (read-only).
    """
    values = self._data
    return values.copy()
#######################################
#Get interpolator:
@property
def interpolator(self) -> RegularGridInterpolator:
    """
    The interpolator currently stored in the table (the object itself, not a copy).
    """
    current = self._interpolator
    return current
#######################################
@property
def ndim(self) -> int:
    """
    The number of dimensions of the table.
    """
    values = self._data
    return values.ndim
#######################################
@property
def shape(self) -> tuple[int]:
    """
    The shape of the stored data-set, i.e., the number of sampling points of each input-variable.
    """
    values = self._data
    return values.shape
#######################################
@property
def size(self) -> int:
    """
    The total number of data-points stored in the table.
    """
    values = self._data
    return values.size
#########################################################################
#Constructor:
def __init__(self, data:Iterable[float]|Iterable, ranges:dict[str,Iterable[float]], order:Iterable[str], *, outOfBounds:Literal["extrapolate", "fatal", "nan"]="fatal"):
    """
    Build a tabulation from the values at the sampling points, the sampling
    points of each input variable, and the nesting order of the input variables.

    Args:
        data (Iterable[float]|Iterable): Data structure containing the interpolation values at
            sampling points of the tabulation.
            - If 1-dimensional array is given, data are stored as a list by recursively looping over the ranges stored in 'ranges', following variable
            hierarchy set in 'order'.
            - If n-dimensional array is given, shape must be consistent with 'ranges'.
        ranges (dict[str,Iterable[float]]): Sampling points used in the tabulation for each input variable.
        order (Iterable[str]): Order in which the input variables are nested.
        outOfBounds (Literal["extrapolate", "nan", "fatal"], optional): How to handle out-of-bound access to the tabulation. Defaults to "fatal".

    Raises:
        TypeError: If data is a DataFrame. Use 'from_pandas' method to create a Tabulation from a DataFrame.
        ValueError: If a range is not sorted in ascending order, or if 'ranges', 'order' and
            'data' are not consistent with each other.
    """
    if isinstance(data, DataFrame):
        raise TypeError("Use 'from_pandas' method to create a Tabulation from a DataFrame.")

    #Data: checked and cast to numpy
    self.checkType(data, Iterable, entryName="data")
    data = np.array(data)

    #Ranges: checked one by one
    self.checkMap(ranges, str, Iterable, entryName="ranges")
    for var in ranges:
        self.checkArray(ranges[var], float, f"ranges[{var}]")

    #Sampling points must be given in ascending order
    for var in ranges:
        points = ranges[var]
        if list(points) != sorted(points):
            raise ValueError(f"Range for variable '{var}' not sorted in ascending order.")

    #Order
    self.checkArray(order, str,entryName="order")

    #The variables in 'order' must be the keys of 'ranges'
    if len(ranges) != len(order):
        raise ValueError("Length missmatch. Keys of 'ranges' must be the same of the elements of 'order'.")
    for key in ranges:
        if key not in order:
            raise ValueError(f"key '{key}' not found in entry 'order'. Keys of 'ranges' must be the same of the elements of 'order'.")

    #Data must hold one value for each sampling point
    numEl = np.prod([len(ranges[var]) for var in ranges])
    shape = tuple(len(ranges[var]) for var in order)
    if data.ndim <= 1:
        if len(data) != numEl:
            raise ValueError("Size of 'data' is not consistent with the data-set given in 'ranges'.")
    else:
        if data.size != numEl:
            raise ValueError("Size of 'data' is not consistent with the data-set given in 'ranges'.")
        if data.shape != shape:
            raise ValueError("Shape of 'data' is not consistent with the data-set given in 'ranges'.")

    #Private copies of ranges (as numpy arrays) and order
    storedRanges = {var:np.array(list(ranges[var][:])) for var in ranges}
    storedOrder = list(order[:])

    self._ranges = storedRanges
    self._order = storedOrder
    #A flat list of values is reshaped following the nesting order
    self._data = data.reshape(shape) if data.ndim == 1 else data

    #Options
    self._outOfBounds = _OoBMethod(outOfBounds)
    self._createInterpolator()
#########################################################################
#Private member functions:
[docs]
def _createInterpolator(self) -> None:
    """
    (Re)build the RegularGridInterpolator from the stored ranges, data and
    out-of-bounds method, and store it in ``self._interpolator``.

    Input variables sampled at a single point are left out of the grid, and
    the corresponding axes are removed from the data.

    Raises:
        ValueError: If the out-of-bounds method is not one of the known ones.
    """
    #Grid axes: keep only the variables with more than one sampling point
    axes = tuple(
        self._ranges[var]
        for var in self.order
        if len(self._ranges[var]) > 1
    )

    #Data without the single-point directions
    values = self._data.squeeze()

    #Translate the out-of-bounds method into interpolator options
    method = self.outOfBounds
    if method == _OoBMethod.fatal:
        options = {"bounds_error": True}
    elif method == _OoBMethod.nan:
        options = {"bounds_error": False, "fill_value": float('nan')}
    elif method == _OoBMethod.extrapolate:
        options = {"bounds_error": False, "fill_value": None}
    else:
        raise ValueError(f"Unexpecred out-of-bound method {method}")

    self._interpolator = RegularGridInterpolator(axes, values, **options)
#########################################################################
#Public member functions:
#Table-manipulation methods. Each class attribute below is bound to an object
#that must already exist when the class body is executed -- presumably
#functions defined earlier in this file taking the table as first argument
#(TODO confirm against their definitions).
#'append', 'merge' and 'concat' are three names for the same object.
append = merge = concat = concat
insertDimension = insertDimension
#NOTE: from here on, the bare name 'slice' at class-body level refers to
#sliceTable and no longer to the builtin. Method bodies are not affected.
slice = sliceTable
clip = clipTable
squeeze = squeeze
[docs]
def copy(self):
    """
    Create a copy of the tabulation.

    A new Tabulation is constructed from the data, ranges, order and
    out-of-bounds method of this one.

    Returns:
        Tabulation: The newly constructed table.
    """
    return Tabulation(
        self.data,
        self.ranges,
        self.order,
        outOfBounds=self.outOfBounds,
    )
#Conversion
#Both names are bound to the module-level function toPandas(table), which
#converts a Tabulation into a pandas.DataFrame.
toPandas = to_pandas = toPandas
#Plotting
#Bound to plotTable / plotTableHeatmap, which must already exist when the
#class body is executed -- presumably plotting functions defined earlier in
#this file (TODO confirm).
plot = plotTable
plotHeatmap = plotTableHeatmap
#Access
[docs]
def setRange(self, variable:str, range:Iterable[float]) -> None:
    """
    Replace the sampling points of one input variable and rebuild the
    interpolator. The stored data are left untouched.

    Args:
        variable (str): Name of the input variable to modify.
        range (Iterable[float]): The new sampling points. Must have the same
            length as the current range, contain no duplicates and be sorted
            in ascending order.

    Raises:
        ValueError: If the variable is not in the tabulation, or if the new
            range has the wrong length, contains duplicates or is not sorted.
    """
    self.checkType(variable, str, "variable")
    self.checkArray(range, float, "range")

    #The variable must be one of the inputs of the table
    if variable not in self.order:
        raise ValueError(f"Variable '{variable}' not found in the tabulation.")

    #Same number of sampling points as before (data are not resized)
    if len(range) != len(self._ranges[variable]):
        raise ValueError(f"Length of new range for variable '{variable}' not consistent with the current range.")

    #No repeated points
    if len(set(range)) != len(range):
        raise ValueError(f"New range for variable '{variable}' contains duplicates.")

    #Ascending order
    if list(range) != sorted(range):
        raise ValueError(f"New range for variable '{variable}' not sorted in ascending order.")

    #Store and refresh the interpolator
    self._ranges[variable] = np.array(range)
    self._createInterpolator()
#########################################################################
#Dunder methods
#Interpolation
[docs]
def __call__(self, *args:tuple[float,...]|tuple[tuple[float,...],...], outOfBounds:str|None=None) -> float|np.ndarray[float]:
    """
    Multi-linear interpolation from the tabulation. The input data must be consistent
    with the number of input-variables stored in the tabulation.

    Args:
        *args (tuple[float,...] | Iterable[tuple[float,...]]): The input data to interpolate.
            - If tuple[float,...] is given, returns float.
            - If tuple[tuple[float,...]] is given, returns np.ndarray[float], where each
              entry is the result of the interpolation.
        outOfBounds (str, optional): Overwrite the out-of-bounds method for this call only.
            The previous method is restored afterwards, also when the interpolation
            fails. Defaults to None (use the method of the table).

    Returns:
        float | np.ndarray[float]: The interpolated value if the result has a single
            entry, otherwise the array of interpolated values.

    Raises:
        ValueError: If an entry does not have one value per dimension of the table.

    Warns:
        TabulationAccessWarning: If a value different from the stored one is given for
            a variable with a single sampling point (the value is ignored).
    """
    #Check arguments
    self.checkType(args, (tuple, Iterable), "args")

    #Check for single entry: table(x, y, ...) is treated as table((x, y, ...))
    if not isinstance(args[0], Iterable):
        args = [args]

    #Pre-processing: check for dimension and extract active dimensions
    entries = []
    self.checkArray(args, Iterable, "args")
    for ii, entry in enumerate(args):
        self.checkArray(entry, float, f"args[{ii}]")

        #Check for dimension
        if len(entry) != self.ndim:
            raise ValueError("Number of entries not consistent with number of dimensions stored in the tabulation ({} expected, while {} found).".format(self.ndim, len(entry)))

        #Extract active dimensions: variables with one sampling point are
        #not part of the interpolation grid
        active = []
        for jj, f in enumerate(self.order):
            if len(self._ranges[f]) > 1:
                active.append(entry[jj])
            elif entry[jj] != self._ranges[f][0]:
                warnings.warn(
                    TabulationAccessWarning(
                        f"Variable '{f}' with only one data-point, cannot " +
                        "interpolate along that dimension. Entry for that " +
                        "variable will be ignored.")
                    )
        entries.append(active)

    #Compute
    if outOfBounds is None:
        returnValue = self.interpolator(entries)
    else:
        #Temporary out-of-bounds method: always restore the previous one,
        #even if the interpolation raises (e.g. bounds error with "fatal")
        oldOoB = self.outOfBounds
        try:
            self.outOfBounds = outOfBounds
            returnValue = self.interpolator(entries)
        finally:
            self.outOfBounds = oldOoB

    #Give results
    if len(returnValue) == 1:
        return returnValue[0]
    return returnValue
#######################################
[docs]
def __getitem__(self, index:int|Iterable[int]|slice) -> float|np.ndarray[float]:
    """
    Get one or more elements of the table.

    Args:
        index (int | Iterable[int] | slice | Iterable[slice]): Either:
            - An index to access the table (flattened).
            - A tuple of the x,y,z,... indices to access the table.
            - A slice to access the table (flattened).
            - A tuple of slices to access the table.

    Returns:
        float | np.ndarray[float]: The value(s) at the index/indices:
            - A single integer or a tuple of integers gives one value.
            - A slice gives an array taken from the flattened data.
            - Any other index is passed unchanged to the stored numpy array.
    """
    integerTypes = (int, np.integer)

    #Single integer or slice -> access to the flattened data
    if isinstance(index, integerTypes + (slice,)):
        flatData = self._data.flatten()
        return flatData[index]

    #Tuple made only of integers -> convert to the position in the flattened data
    if isinstance(index, tuple) and all(isinstance(i, integerTypes) for i in index):
        position = np.ravel_multi_index(index, self.shape)
        return self._data.flatten()[position]

    #Anything else (e.g. tuple of slices) -> numpy indexing of the nested data
    return self._data[index]
#######################################
[docs]
def __setitem__(self, index:int|Iterable[int]|slice|tuple[int|Iterable[int]|slice], value:float|np.ndarray[float]) -> None:
    """
    Set the interpolation values at an index/slice of the table, then rebuild the interpolator.

    - If tuple[int|Iterable[int]|slice] is given, the value is set at the index/indices
      of the nested dataset through numpy item assignment. The tuple must have one
      entry per dimension of the table.
    - If int|Iterable[int]|slice is given, it is interpreted as index/indices in the
      flattened dataset. These are converted to nested indices with
      ``self._computeIndex`` and the values are assigned one by one.

    Args:
        index: Where to set the value(s), see above.
        value (float | np.ndarray[float]): The value(s) to set. For flattened access,
            one value per index is required.

    Raises:
        ValueError: If setting the items fails for any reason. The original error is
            reported in the message and chained as ``__cause__``.
    """
    try:
        #Nested access
        if isinstance(index, tuple):
            if len(index) != self.ndim:
                raise ValueError("Number of entries not consistent with number of dimensions stored in the tabulation ({} expected, while {} found).".format(self.ndim, len(index)))
            self._data[index] = value

        #Flattened access
        elif isinstance(index, (int, np.integer, slice, Iterable)):
            if isinstance(index, Iterable):
                self.checkArray(index, (int, np.integer), "index")

            nestedId = self._computeIndex(index)
            if isinstance(nestedId, tuple): #Single index -> convert to list[tuple]
                nestedId = [nestedId]
            if not isinstance(value, Iterable): #Single value -> convert to list
                value = [value]

            #One value per index
            if len(value) != len(nestedId):
                raise ValueError("Number of values not consistent with number of indices ({} expected, while {} found).".format(len(nestedId), len(value)))

            for idx, val in zip(nestedId, value):
                self._data[idx] = val
        else:
            raise TypeError("Cannot access with index of type '{}'.".format(index.__class__.__name__))
    except Exception as err:
        #Only ordinary errors are converted: KeyboardInterrupt, SystemExit and
        #other BaseException-only signals must propagate untouched
        raise ValueError("Failed setting items in Tabulation: {}".format(err)) from err

    #Update interpolator
    self._createInterpolator()
#######################################
[docs]
def __eq__(self, value:Tabulation) -> bool:
    """
    Check if two tabulations store the same ranges, order and data.
    The out-of-bounds method is not part of the comparison.

    Args:
        value (Tabulation): The table to compare with.

    Returns:
        bool: True if ranges, order and data are all equal.

    Raises:
        NotImplementedError: If 'value' is not a Tabulation.
    """
    # NOTE(review): the data model suggests returning NotImplemented here;
    # the exception is kept because callers may rely on it.
    if not isinstance(value, Tabulation):
        raise NotImplementedError("Cannot compare Tabulation with object of type '{}'.".format(value.__class__.__name__))

    #Ranges: same variables...
    if self._ranges.keys() != value._ranges.keys():
        return False
    #...with the same sampling points
    if any(not np.array_equal(value._ranges[var], self._ranges[var]) for var in self._ranges):
        return False

    #Order
    if self._order != value._order:
        return False

    #Data
    if not np.array_equal(value._data, self._data):
        return False

    return True
#######################################
[docs]
def __str__(self):
    """
    String representation of the tabulation: the string given by the parent
    class followed by the table converted to a DataFrame and printed with
    ``DataFrame.to_string``.
    """
    header = super().__str__()
    body = self.to_pandas().to_string()
    return header + body
[docs]
def __repr__(self):
    """
    Representation of the tabulation: the representation given by the parent
    class followed by ``data=<data>)``.
    """
    prefix = super().__repr__()
    suffix = f"data={self.data})"
    return prefix + suffix