sc2_datasets.torch.datasets.sc2_dataset

Classes

SC2Dataset

Inherits from PyTorch Dataset and ensures that the dataset for SC2EGSet is downloaded.

Module Contents

class SC2Dataset(names_urls: list[sc2_datasets.available_replaypacks.DatasetProperties], unpack_dir: pathlib.Path | str = Path('./data/unpack').resolve(), download_dir: pathlib.Path | str = Path('./data/download').resolve(), download: bool = True, unpack_n_workers: int = 16, transform: Callable | None = None, validator: Callable | None = None)

Bases: torch.utils.data.Dataset

Inherits from PyTorch Dataset and ensures that the dataset for SC2EGSet is downloaded.

Parameters:
  • names_urls (list[DatasetProperties]) – Specifies the names and URLs of the replaypacks that will be downloaded.

  • unpack_dir (Path | str) – Specifies the path of a directory where the dataset files will be unpacked.

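  • download_dir (Path | str) – Specifies the path of a directory where the dataset files will be downloaded.

  • download (bool, optional) – Specifies whether the dataset files should be downloaded, defaults to True.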

  • unpack_n_workers (int, optional) – Specifies the number of workers that will be used for unpacking the archive, defaults to 16.

  • transform (Callable[[SC2ReplayData], T] | None, optional) – PyTorch transform function that takes an SC2ReplayData object and returns the transformed item, defaults to None.

  • validator (Callable | None, optional) – Specifies the validation option for fetched data, defaults to None.
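The `transform` parameter accepts any callable that takes an SC2ReplayData object. The sketch below is hypothetical: `ReplayStub` and its `game_length` field stand in for SC2ReplayData (whose real fields are not described here), and the labelling rule is purely illustrative. It shows a transform returning an `(item, label)` pair, which matches the `tuple[Any, Any]` return type of `__getitem__`.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass
class ReplayStub:
    """Hypothetical stand-in for SC2ReplayData, used only for illustration."""
    game_length: int  # hypothetical field, not taken from SC2ReplayData


def make_transform(threshold: int) -> Callable[[Any], Tuple[Any, int]]:
    """Build a transform that pairs each replay with a binary label."""

    def transform(replay: Any) -> Tuple[Any, int]:
        # Illustrative rule only: label games longer than `threshold` as 1.
        label = int(replay.game_length > threshold)
        return replay, label

    return transform


transform = make_transform(threshold=10_000)
sample, label = transform(ReplayStub(game_length=12_000))
```

A callable like this would be passed as `transform=transform` when constructing SC2Dataset, in which case indexing the dataset yields the transformed pair instead of the raw SC2ReplayData.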

transform = None
download_dir
unpack_dir
names_urls
download = True
unpack_n_workers = 16
validator = None
skip_files: dict[str, set[str]]
len = 0
ensure_downloaded()

Ensures that the dataset was downloaded before accessing the __len__ or __getitem__ methods.
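One way to implement this kind of guard is to fetch the data lazily on the first call to `__len__` or `__getitem__`. The sketch below is an assumption, not the library's implementation: `LazyDataset` and its `fetch` argument are hypothetical stand-ins for SC2Dataset's download and unpack logic. It follows the same `__len__`/`__getitem__` protocol as a map-style PyTorch Dataset but does not import torch, so it runs standalone.

```python
from typing import Callable, List, Optional


class LazyDataset:
    """Hypothetical dataset that downloads its data on first access."""

    def __init__(self, fetch: Callable[[], List[str]]) -> None:
        # `fetch` stands in for the real download/unpack step.
        self._fetch = fetch
        self._items: Optional[List[str]] = None

    def ensure_downloaded(self) -> None:
        # Fetch only once; subsequent calls are no-ops.
        if self._items is None:
            self._items = self._fetch()

    def __len__(self) -> int:
        self.ensure_downloaded()
        assert self._items is not None
        return len(self._items)

    def __getitem__(self, index: int) -> str:
        self.ensure_downloaded()
        assert self._items is not None
        return self._items[index]


calls: List[int] = []


def fake_fetch() -> List[str]:
    # Record each call so we can check the fetch happens only once.
    calls.append(1)
    return ["a.SC2Replay", "b.SC2Replay"]


ds = LazyDataset(fake_fetch)
```

Deferring the fetch this way keeps construction cheap, while guaranteeing that any access path triggers the download first.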

__len__() → int

Returns the number of items in the dataset.

__getitem__(index: Any) → tuple[Any, Any] | sc2_datasets.replay_data.sc2_replay_data.SC2ReplayData

Exposes the logic of retrieving a single parsed item via dataset[index].

Parameters:

index (Any) – Specifies the index of an item that should be retrieved.

Raises:
  • IndexError – Negative indices are supported by adding the dataset length to them; if the index is still less than zero after this adjustment, IndexError is raised.

  • IndexError – If the index is greater than or equal to the length of the dataset, IndexError is raised.
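The two IndexError conditions correspond to the usual Python sequence convention for negative indices. The function below is a hypothetical sketch of that convention (`normalize_index` is not part of sc2_datasets); it treats an index equal to the length as out of range, as is standard for 0-based indexing.

```python
def normalize_index(index: int, length: int) -> int:
    """Map a possibly negative index onto [0, length), Python-sequence style."""
    if index < 0:
        # Negative indices count from the end, e.g. -1 -> length - 1.
        index += length
        if index < 0:
            # Still negative after wrapping: out of range.
            raise IndexError(f"index {index - length} out of range for length {length}")
    if index >= length:
        raise IndexError(f"index {index} out of range for length {length}")
    return index
```

For a dataset of length 5, `normalize_index(-1, 5)` gives 4, while `-6` and `5` both raise IndexError.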

Returns:

Returns a parsed SC2ReplayData from the underlying SC2ReplaypackDataset, or the result of applying the transform that was passed to the dataset.

Return type:

tuple[Any, Any] | SC2ReplayData