
graphium.utils

Module for utility functions

Argument Checker


graphium.utils.arg_checker

Argument checker module

check_arg_iterator(arg, enforce_type=None, enforce_subtype=None, cast_subtype=True)

Verify whether the argument is an iterator. If it is None, convert it to an empty list/tuple. If it is not a list/tuple/str, try to convert it to an iterator. If it is a str or cannot be converted to an iterator, wrap the arg inside an iterator. Optionally enforce the iterator type to list or tuple if enforce_type is not None. Optionally enforce the subtype of the elements if enforce_subtype is not None, and decide whether to cast the subtype or to throw an error.

Parameters:

Name Type Description Default
arg any type

The input to verify/convert to an iterator (list or tuple). If None, an empty iterator is returned.

required
enforce_type str or type

The type to enforce on the iterator. The valid choices are: None, list, tuple, 'none', 'list', 'tuple'. If None, then the iterator type is not enforced.

None
enforce_subtype type, np.dtype or str representing basic type

Verify if all the elements inside the iterator are the desired type. If None, then the sub-type is not enforced. Accepted strings are ['none', 'str', 'list', 'tuple', 'dict', 'int', 'float', 'complex', 'bool', 'callable']

None
cast_subtype bool

If True, then the type specified by enforce_subtype is used to cast the elements inside the iterator. If False, then an error is thrown if the types do not match.

True

Returns:

Name Type Description
output iterator

An iterator based on the input of the desired type (list or tuple) and the desired subtypes.
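
Example

A minimal sketch of typical calls, inferred from the description above; the return values are illustrative, not verified.

from graphium.utils.arg_checker import check_arg_iterator

check_arg_iterator(None)                           # -> empty list
check_arg_iterator(5)                              # -> [5], scalar wrapped in an iterator
check_arg_iterator((1, 2), enforce_type=list)      # -> [1, 2], tuple converted to list
check_arg_iterator([1, 2], enforce_subtype=float)  # -> [1.0, 2.0], elements cast to float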

check_columns_choice(dataframe, columns_choice, extra_accepted_cols=None, enforce_type='list')

Verify that the chosen columns columns_choice are present in the dataframe or in extra_accepted_cols. Otherwise, errors are thrown by the sub-functions.

Parameters:

Name Type Description Default
dataframe pd.DataFrame

The dataframe on which to verify if the column choice is valid.

required
columns_choice str or iterator(str)

The columns chosen from the dataframe.

required
extra_accepted_cols str or iterator(str)

A list of extra column names accepted in addition to the dataframe columns.

None
enforce_type str or type

The type to enforce on the iterator. The valid choices are: None, list, tuple, 'none', 'list', 'tuple'. If None, then the iterator type is not enforced.

'list'

Returns:

Name Type Description
output iterator

A str iterator based on the input, of the desired type (list or tuple).
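
Example

A minimal sketch, assuming a dataframe with the columns shown; the column names are hypothetical.

import pandas as pd
from graphium.utils.arg_checker import check_columns_choice

df = pd.DataFrame({"smiles": ["CCO", "CCC"], "target": [1.0, 2.0]})
check_columns_choice(df, "smiles")  # -> ["smiles"], enforced to a list by default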

check_list1_in_list2(list1, list2, throw_error=True)

Verify whether list1 (iterator) is included in list2 (iterator). If not, raise an error (when throw_error is True).

Parameters:

Name Type Description Default
list1, list2 list, tuple or object

A list or tuple containing the elements to verify the inclusion. If an object other than a list or tuple is provided, it is considered as a list of a single element.

required
throw_error bool

Whether to throw an error if list1 is not in list2.

True

Returns:

Name Type Description
list1_in_list2 bool

A boolean representing the inclusion of list1 in list2. It is returned if throw_error is set to False.
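
Example

A minimal sketch of the inclusion check described above.

from graphium.utils.arg_checker import check_list1_in_list2

check_list1_in_list2(["a"], ["a", "b", "c"], throw_error=False)  # -> True
check_list1_in_list2(["a", "z"], ["a", "b"], throw_error=False)  # -> False
check_list1_in_list2(["z"], ["a", "b"])                          # raises an error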

Decorators


graphium.utils.decorators

classproperty

Bases: property

Decorator used to declare a class property, accessible on the class itself without needing to instantiate an object.

Example

    from graphium.utils.decorators import classproperty

    class MyClass:
        @classproperty
        def my_class_property(cls):
            return 5
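
    # Accessed on the class itself, without instantiation:
    MyClass.my_class_property  # -> 5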

Dictionary of Tensors


graphium.utils.dict_tensor

DictTensor

Bases: dict

A class that combines the functionality of dict and torch.Tensor. Specifically, it is a dict of Tensor, but it has all the methods and attributes of a Tensor. When a given method or attribute is called, it will be called on each element of the dict, and a new dict is returned.

All methods from torch.Tensor and other functions are expected to work on DictTensor, with the following rules:
  • The function receives one DictTensor: The function will be applied on all values of the dictionary. Examples are torch.sum or torch.abs.
  • The function receives two DictTensor: The function will be applied on each pair of Tensor with matching keys. If the keys don't match, an error will be thrown. Examples are operators such as dict_tensor1 + dict_tensor2 or dict_tensor1 == dict_tensor2.

The output of the functions that act on the torch.Tensor level is a dict[Any]. If the output is a dict[torch.Tensor], it is converted automatically to DictTensor.

If a given method is available in both dict and torch.Tensor, then the one from Tensor is not available. Exceptions to the above rules are the __dict__ and all comparison methods:

  • __dict__: Not available for this class.
  • __lt__ or <: Supports two-way comparison between DictTensor and number, torch.Tensor or DictTensor.
  • __le__ or <=: Supports two-way comparison between DictTensor and number, torch.Tensor or DictTensor.
  • __eq__ or ==: Supports two-way comparison between DictTensor and number, torch.Tensor or DictTensor.
  • __ne__ or !=: Supports two-way comparison between DictTensor and number, torch.Tensor or DictTensor.
  • __gt__ or >: Supports two-way comparison between DictTensor and number, torch.Tensor or DictTensor.
  • __ge__ or >=: Supports two-way comparison between DictTensor and number, torch.Tensor or DictTensor.

The only major function from torch.Tensor that doesn't work (to my knowledge) is indexing. Indexing has to be done by manually looping the dictionary, as sketched after the examples below.

Example
# Summing a float to a DictTensor
dict_ten = DictTensor({
        "a": torch.zeros(5),
        "b": torch.zeros(2, 3),
        "c": torch.zeros(1, 2),
    })
dict_ten + 2
>>
{'a': tensor([2., 2., 2., 2., 2.]),
'b': tensor([[2., 2., 2.],
        [2., 2., 2.]]),
'c': tensor([[2., 2.]])}

# Summing a DictTensor to a DictTensor
dict_ten = DictTensor({
        "a": torch.zeros(5),
        "b": torch.zeros(2, 3) + 0.1,
        "c": torch.zeros(1, 2) + 0.5,
    })
dict_ten2 = DictTensor({
        "a": torch.zeros(5) + 0.2,
        "b": torch.zeros(2, 3) + 0.3,
        "c": torch.zeros(1, 2) + 0.4,
    })
dict_ten + dict_ten2
>>
{'a': tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000]),
'b': tensor([[0.4000, 0.4000, 0.4000],
        [0.4000, 0.4000, 0.4000]]),
'c': tensor([[0.9000, 0.9000]])}


# Summing across the first axis
dict_ten = DictTensor({
        "a": torch.zeros(5) + 0.1,
        "b": torch.zeros(2, 3) + 0.2,
        "c": torch.zeros(1, 2) + 0.5,
    })
dict_ten.sum(axis=0)
>>
{'a': tensor(0.5000),
'b': tensor([0.4000, 0.4000, 0.4000]),
'c': tensor([0.5000, 0.5000])}

# Getting the shape of each tensor
dict_ten = DictTensor({
        "a": torch.zeros(5) + 0.1,
        "b": torch.zeros(2, 3) + 0.2,
        "c": torch.zeros(1, 2) + 0.5,
    })
dict_ten.shape
>>
{'a': torch.Size([5]), 'b': torch.Size([2, 3]), 'c': torch.Size([1, 2])}
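
# Indexing is not forwarded to the underlying tensors (see the note above),
# so it must be done by looping the dictionary manually. A minimal sketch:
first_rows = DictTensor({key: tensor[0] for key, tensor in dict_ten.items()})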
__init__(dic)

Take a dictionary of torch.Tensor, and transform it into a DictTensor. Register all the required methods from torch.Tensor, but modify them to work on the dictionary instead.

apply(func, *args, **kwargs)

Apply a function on every Tensor "value" of the current dictionary

Parameters:

Name Type Description Default
func Callable

function to be called on each Tensor

required
*args, **kwargs

Additional parameters to the function. Note that the Tensor should be the first input to the function.

required

Returns:

Type Description
Union[DictTensor, Dict[Any]]

DictTensor if the output of the function is a torch.Tensor, or dict[X] if the output of the function is of type X.
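
Example

A minimal sketch of apply(), assuming the signature above; torch.clamp is just an illustrative function to map over the values.

import torch

dict_ten = DictTensor({"a": torch.randn(5), "b": torch.randn(2, 3)})
clamped = dict_ten.apply(torch.clamp, min=-1.0, max=1.0)  # clamp applied to each Tensor value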

File System


graphium.utils.fs

copy(source, destination, chunk_size=None, force=False, progress=False, leave_progress=True)

Copy one file to another location across different filesystems (local, S3, GCS, etc.).

Parameters:

Name Type Description Default
source Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile]

path or file-like object to copy from.

required
destination Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile]

path or file-like object to copy to.

required
chunk_size int

the chunk size to use. If progress is enabled and the chunk size is None, it is set to 2048.

None
force bool

whether to overwrite the destination file if it exists.

False
progress bool

whether to display a progress bar.

False
leave_progress bool

whether to keep the progress bar displayed once the copy is done.

True
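
Example

A minimal sketch; the local and S3 paths are hypothetical.

from graphium.utils import fs

fs.copy("data/train.csv", "s3://my-bucket/train.csv", force=True, progress=True)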

exists(path)

Check whether a file exists.

Parameters:

Name Type Description Default
path Union[str, os.PathLike, fsspec.core.OpenFile, io.IOBase]

a path supported by fsspec such as local, s3, gcs, etc.

required

exists_and_not_empty(path)

Check whether a directory exists and is not empty.

get_basename(path)

Get the basename of a file or a folder.

Parameters:

Name Type Description Default
path Union[str, os.PathLike]

a path supported by fsspec such as local, s3, gcs, etc.

required

get_cache_dir(suffix=None, create=True)

Get a local cache directory. You can append a suffix folder to it and optionally create the folder if it doesn't exist.

get_extension(path)

Get the extension of a file.

Parameters:

Name Type Description Default
path Union[str, os.PathLike]

a path supported by fsspec such as local, s3, gcs, etc.

required

get_mapper(path)

Get the fsspec mapper.

Parameters:

Name Type Description Default
path Union[str, os.PathLike]

a path supported by fsspec such as local, s3, gcs, etc.

required

get_size(file)

Get the size of a file given its path. Return None if the size can't be retrieved.

join(*paths)

Join paths together. The first element determines the filesystem to use (and therefore the separator).

Parameters:

Name Type Description Default
paths

a list of paths supported by fsspec such as local, s3, gcs, etc.

()

mkdir(path, exist_ok=True)

Create a directory, including potential parents.

rm(path, recursive=False, maxdepth=None)

Delete a file or a directory with all nested files.
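
Example

A minimal sketch combining the helpers above; all paths are hypothetical.

from graphium.utils import fs

cache_dir = fs.get_cache_dir(suffix="downloads")       # local cache folder, created if missing
path = fs.join("s3://my-bucket", "data", "train.csv")  # separator chosen from the first element
if not fs.exists(path):
    fs.mkdir("local_data", exist_ok=True)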

Hashing


graphium.utils.hashing

get_md5_hash(object)

MD5 hash of any object. The object is converted to a YAML string before being hashed. This allows for nested dictionaries/lists and for hashing of classes and their attributes.
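
Example

A minimal sketch: hashing a nested configuration dictionary.

from graphium.utils.hashing import get_md5_hash

config = {"model": "gnn", "hidden_dims": [64, 64], "dropout": 0.1}
hash_str = get_md5_hash(config)  # deterministic MD5 of the YAML-serialized object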

Moving Average Tracker


graphium.utils.moving_average_tracker

MUP


graphium.utils.mup

apply_infshapes(model, infshapes)

Modified from the regular mup.apply_infshapes by explicitly adding base_dim to the MuReadoutGraphium. This allows the code to work on IPUs.

set_base_shapes(model, base, rescale_params=True, delta=None, savefile=None, do_assert=True)

Sets the p.infshape attribute for each parameter p of model.

Code taken from the mup package from Microsoft (https://github.com/microsoft/mup). No change except that apply_infshapes uses the Graphium version instead of the one from mup.

Inputs

model: nn.Module instance
base: The base model. Can be nn.Module, a dict of shapes, a str, or None. If None, then defaults to model. If str, then treated as the filename for a yaml encoding of a dict of base shapes.
rescale_params: assuming the model is initialized using the default pytorch init (or He initialization, etc., that scales the same way with fan-in): if True (default), rescales parameters to have the correct (μP) variances.
do_assert:

Output

same object as model, after setting the infshape attribute of each parameter.
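
Example

A minimal sketch following the mup package convention (https://github.com/microsoft/mup); MyModel and its width argument are hypothetical.

from graphium.utils.mup import set_base_shapes

base = MyModel(width=64)              # base-width model
model = MyModel(width=256)            # scaled-up model
model = set_base_shapes(model, base)  # sets p.infshape on every parameter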

Read File


graphium.utils.read_file

Utilities for data parsing

file_opener(filename, mode='r')

File reader stream

parse_sdf_to_dataframe(sdf_path, as_cxsmiles=True, skiprows=None)

Reads an SDF file containing molecular information, converts it to a pandas DataFrame, and converts the molecules to SMILES. It also logs a warning listing all the molecules that couldn't be read.

Arguments
sdf_path: str
    The full path and name of the sdf file to read
as_cxsmiles: bool, optional
    Whether to use the CXSMILES notation, which preserves atomic coordinates,
    stereocenters, and much more.
    See `https://dl.chemaxon.com/marvin-archive/latest/help/formats/cxsmiles-doc.html`
    (Default = True)
skiprows: int, list
    The rows to skip from the dataset. The enumerate index starts from 1 instead of 0.
    (Default = None)
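
Example

A minimal sketch; the file path is hypothetical.

from graphium.utils.read_file import parse_sdf_to_dataframe

df = parse_sdf_to_dataframe("molecules.sdf", as_cxsmiles=True)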

read_file(filepath, as_ext=None, **kwargs)

Read different file formats and parse them into a MolecularDataFrame. Supported formats are:
  • csv (.csv, .smile, .smiles, .tsv)
  • txt (.txt)
  • xls (.xls, .xlsx, .xlsm, .xls*)
  • sdf (.sdf)
  • pkl (.pkl)

Arguments
filepath: str
    The full path and name of the file to read.
    It also supports the s3 url path.
as_ext: str, Optional
    The file extension used to read the file. If None, the extension is deduced
    from the extension of the file. Otherwise, no matter the file extension,
    the file will be read according to the specified ``as_ext``.
    (Default=None)
**kwargs: All the optional parameters required for the desired file reader.

TODO: unit test to make sure it works well with all extensions

Returns
df: pandas.DataFrame
    The ``pandas.DataFrame`` containing the parsed data
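
Example

A minimal sketch; the paths are hypothetical, and extra kwargs are forwarded to the underlying reader.

from graphium.utils.read_file import read_file

df = read_file("data/molecules.csv")                # extension deduced from the file
df = read_file("data/molecules.dat", as_ext="csv")  # force the csv reader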

Safe Run


graphium.utils.safe_run

SafeRun

__enter__()

Print that the with-statement started, if self.verbose >= 2

__exit__(type, value, traceback)

Handle the error. Raise it if self.raise_error==True, otherwise ignore it and print it if self.verbose >= 1. Also print that the with-statement is completed if self.verbose >= 2.

__init__(name, raise_error=True, verbose=2)

Run some code with error handling and some printing, using the with statement.

Example

In the example below, 2 + None raises an error, which will be caught and printed.

with SafeRun(name="Addition that fails", raise_error=False):
    2 + None

Parameters:

Name Type Description Default
name str

Name of the code, used for printing

required
raise_error bool

Whether to raise an error, or to catch it and print instead

True
verbose int

The level of verbosity. 0: do not print anything. 1: print only the traced stack when an error is caught and raise_error is False. 2: print headers and footers at the start and exit of the with statement.

2

Spaces


graphium.utils.spaces

Tensor


graphium.utils.tensor

ModuleListConcat

Bases: torch.nn.ModuleList

A list of neural modules similar to torch.nn.ModuleList, but where the modules are applied on the same input and concatenated together, instead of being applied sequentially.

Parameters:

Name Type Description Default
dim int

The dimension for the concatenation

-1
forward(*args, **kwargs)

Apply all layers on the args and kwargs, and concatenate their outputs along the dimension self.dim.
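
Example

A minimal sketch, assuming the list of modules is provided as in torch.nn.ModuleList; the layer sizes are illustrative.

import torch
from graphium.utils.tensor import ModuleListConcat

concat = ModuleListConcat([torch.nn.Linear(16, 8), torch.nn.Linear(16, 4)])
out = concat(torch.randn(32, 16))  # both layers see the same input; expected shape (32, 12)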

ModuleWrap

Bases: torch.nn.Module

Wrap a function into a torch.nn.Module, with possible *args and **kwargs

Parameters:

Name Type Description Default
func

function to wrap into a module

required
forward(*args, **kwargs)

Calls the function self.func with the arguments self.func(*self.args, *args, **self.kwargs, **kwargs)
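
Example

A minimal sketch, assuming keyword arguments given at construction are stored and forwarded as described above.

import torch
from graphium.utils.tensor import ModuleWrap

sum_last_dim = ModuleWrap(torch.sum, dim=-1)
out = sum_last_dim(torch.randn(4, 10))  # calls torch.sum(input, dim=-1); expected shape (4,)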

arg_in_func(fn, arg)

Check if a function takes the given argument.

Parameters
fn

The function in which to check for the argument.

arg: str

The name of the argument.

Returns
res: bool
    True if the function contains the argument, otherwise False.
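
Example

A minimal sketch of the argument check; my_fn is a hypothetical function.

from graphium.utils.tensor import arg_in_func

def my_fn(a, b=1):
    return a + b

arg_in_func(my_fn, "b")  # -> True
arg_in_func(my_fn, "c")  # -> False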

is_device_cuda(device, ignore_errors=False)

Check whether the given device is a cuda device.

Parameters:

Name Type Description Default
device str or torch.device

The device to check for cuda.

required
ignore_errors bool

Whether to ignore the error if the device is not recognized; in that case, False is returned instead of raising an error.

False

Returns:

Name Type Description
is_cuda bool

Whether the given device is a cuda device.
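
Example

A minimal sketch of the device check.

from graphium.utils.tensor import is_device_cuda

is_device_cuda("cuda:0")  # -> True
is_device_cuda("cpu")     # -> False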

is_dtype_numpy_array(dtype)

Verify if the dtype is a numpy dtype

Parameters:

Name Type Description Default
dtype Union[np.dtype, torch.dtype]

The dtype of a value. E.g. np.int32, str, torch.float

required

Returns:

Type Description
bool

A boolean saying if the dtype is a numpy dtype

is_dtype_torch_tensor(dtype)

Verify if the dtype is a torch dtype

Parameters:

Name Type Description Default
dtype Union[np.dtype, torch.dtype]

The dtype of a value. E.g. np.int32, str, torch.float

required

Returns:

Type Description
bool

A boolean saying if the dtype is a torch dtype

nan_mad(input, normal=True, **kwargs)

Return the median absolute deviation of all elements, while ignoring the NaNs.

Parameters:

Name Type Description Default
input Tensor

The input tensor.

required
normal bool

whether to multiply the result by 1.4826 to mimic the standard deviation for normal distributions.

True
dim int or tuple(int)

The dimension or dimensions to reduce.

required
keepdim bool

whether the output tensor has dim retained or not.

required
dtype torch.dtype

The desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

required

Returns:

Name Type Description
output Tensor

The resulting median absolute deviation of the tensor

nan_mean(input, *args, **kwargs)

Return the mean of all elements, while ignoring the NaNs.

Parameters:

Name Type Description Default
input Tensor

The input tensor.

required
dim int or tuple(int)

The dimension or dimensions to reduce.

required
keepdim bool

whether the output tensor has dim retained or not.

required
dtype torch.dtype

The desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

required

Returns:

Name Type Description
output Tensor

The resulting mean of the tensor
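
Example

A minimal sketch; the expected values follow from averaging the non-NaN entries.

import torch
from graphium.utils.tensor import nan_mean

x = torch.tensor([[1.0, float("nan"), 3.0],
                  [4.0, 5.0, float("nan")]])
nan_mean(x)         # -> tensor(3.2500), mean of [1, 3, 4, 5]
nan_mean(x, dim=0)  # per-column mean, ignoring NaNs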

nan_median(input, **kwargs)

Return the median of all elements, while ignoring the NaNs. Unlike torch.nanmedian, this function supports a list of dimensions, or dim=None, and does not return the index of the median.

Parameters:

Name Type Description Default
input Tensor

The input tensor.

required
dim int or tuple(int)

The dimension or dimensions to reduce.

required
keepdim bool

whether the output tensor has dim retained or not.

required
dtype torch.dtype

The desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

required

Returns:

Name Type Description
output Tensor

The resulting median of the tensor. Unlike torch.median, it does not return the index of the median.

nan_std(input, unbiased=True, **kwargs)

Return the standard deviation of all elements, while ignoring the NaNs. If unbiased is True, Bessel's correction will be used. Otherwise, the sample standard deviation is calculated, without any correction.

Parameters:

Name Type Description Default
input Tensor

The input tensor.

required
unbiased bool

whether to use Bessel's correction (δN = 1).

True
dim int or tuple(int)

The dimension or dimensions to reduce.

required
keepdim bool

whether the output tensor has dim retained or not.

required
dtype torch.dtype

The desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

required

Returns:

Name Type Description
output Tensor

The resulting standard deviation of the tensor

nan_var(input, unbiased=True, **kwargs)

Return the variance of all elements, while ignoring the NaNs. If unbiased is True, Bessel's correction will be used. Otherwise, the sample variance is calculated, without any correction.

Parameters:

Name Type Description Default
input Tensor

The input tensor.

required
unbiased bool

whether to use Bessel's correction (δN = 1).

True
dim int or tuple(int)

The dimension or dimensions to reduce.

required
keepdim bool

whether the output tensor has dim retained or not.

required
dtype torch.dtype

The desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.

required

Returns:

Name Type Description
output Tensor

The resulting variance of the tensor

one_of_k_encoding(val, classes)

Converts a single value to a one-hot vector.

Parameters:

Name Type Description Default
val Any

The class value to be converted into a one-hot vector.

required
classes iterator

A list or 1D array of allowed choices for val to take.

required

Returns:

Type Description
List[int]

A list of length len(classes) + 1
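
Example

A minimal sketch; the trailing slot for values outside classes is inferred from the returned length of len(classes) + 1, so the outputs are illustrative.

from graphium.utils.tensor import one_of_k_encoding

one_of_k_encoding(2, [0, 1, 2, 3])  # -> [0, 0, 1, 0, 0]
one_of_k_encoding(7, [0, 1, 2, 3])  # -> [0, 0, 0, 0, 1], out-of-set value in the last slot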

parse_valid_args(param_dict, fn)

Filter a dictionary of parameters, keeping only the arguments that the given function accepts.

Parameters
param_dict: dict

Dictionary of the arguments to check against the function.

fn

The function against which to check the arguments.

Returns
param_dict: dict
    Valid parameter dictionary for the given function.
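
Example

A minimal sketch of filtering a parameter dictionary; my_fn is a hypothetical function.

from graphium.utils.tensor import parse_valid_args

def my_fn(a, b=1):
    return a + b

params = {"a": 2, "b": 3, "unused": 4}
parse_valid_args(params, my_fn)  # expected: {"a": 2, "b": 3}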