sc2_datasets.torch.datasets.sc2_dataset
Classes
SC2Dataset – Inherits from PyTorch Dataset and ensures that the SC2EGSet dataset is downloaded.
Module Contents
- class SC2Dataset(names_urls: list[sc2_datasets.available_replaypacks.DatasetProperties], unpack_dir: pathlib.Path | str = Path('./data/unpack').resolve(), download_dir: pathlib.Path | str = Path('./data/download').resolve(), download: bool = True, unpack_n_workers: int = 16, transform: Callable | None = None, validator: Callable | None = None)
Bases: torch.utils.data.Dataset
Inherits from PyTorch Dataset and ensures that the dataset for SC2EGSet is downloaded.
- Parameters:
names_urls (list[DatasetProperties]) – Specifies the names and URLs of the replaypacks that will be downloaded.
unpack_dir (Path | str) – Specifies the directory where the dataset files will be unpacked, defaults to ./data/unpack resolved to an absolute path.
download_dir (Path | str) – Specifies the directory where the dataset files will be downloaded, defaults to ./data/download resolved to an absolute path.
download (bool, optional) – Specifies whether the dataset files should be downloaded, defaults to True.
unpack_n_workers (int, optional) – Specifies the number of workers used to unpack the archives, defaults to 16.
transform (Callable | None, optional) – PyTorch-style transform that takes an SC2ReplayData and returns an arbitrary object, defaults to None.
validator (Callable | None, optional) – Specifies a validator for the fetched data, defaults to None.
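The transform receives each SC2ReplayData and may return any object; for PyTorch training a common choice is a (features, label) pair, which matches the tuple[Any, Any] return type of __getitem__. The sketch below is illustrative only: FakeReplayData is a hypothetical stand-in, and its fields game_length and winner_id are assumptions, not documented SC2ReplayData attributes. Substitute the attributes your replays actually expose.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical stand-in for SC2ReplayData; the field names below are
# illustrative assumptions, not the real SC2ReplayData attributes.
@dataclass
class FakeReplayData:
    game_length: int  # hypothetical: game duration in game loops
    winner_id: int    # hypothetical: player ID of the winner (1 or 2)


def replay_to_sample(replay: FakeReplayData) -> Tuple[float, int]:
    """Example transform: map one replay to a (feature, label) pair."""
    # Feature: game length in thousands of game loops.
    feature = replay.game_length / 1000.0
    # Label: 0 if player 1 won, 1 otherwise.
    label = 0 if replay.winner_id == 1 else 1
    return feature, label


sample = replay_to_sample(FakeReplayData(game_length=22400, winner_id=2))
print(sample)  # (22.4, 1)
```

With the real class, a function like this would be passed through the transform argument, and every dataset[index] lookup would then return its result instead of the raw SC2ReplayData.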
- transform = None
- download_dir
- unpack_dir
- names_urls
- download = True
- unpack_n_workers = 16
- validator = None
- skip_files: dict[str, set[str]]
- len = 0
- ensure_downloaded()
Ensures that the dataset has been downloaded before __len__ or __getitem__ is accessed.
- __len__() → int
Returns the number of items in the dataset.
- __getitem__(index: Any) → tuple[Any, Any] | sc2_datasets.replay_data.sc2_replay_data.SC2ReplayData
Retrieves a single parsed item via dataset[index].
- Parameters:
index (Any) – Specifies the index of an item that should be retrieved.
- Raises:
IndexError – Negative indices are supported by offsetting them once by the dataset length; if the index is still negative after that offset, IndexError is raised.
IndexError – If the index exceeds the length of the dataset, IndexError is raised.
- Returns:
A parsed SC2ReplayData from the underlying SC2ReplaypackDataset, or the result of the transform passed to the dataset.
- Return type:
tuple[Any, Any] | SC2ReplayData
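The download guard and indexing rules described above can be illustrated with a small, torch-free stand-in. This is a minimal sketch under stated assumptions, not the SC2Dataset source: LazyReplayDataset and fetch_items are hypothetical names, the fetch callable stands in for downloading and unpacking replaypacks, and the upper bound check uses Python's usual index >= len convention.

```python
from typing import Any, Callable, List, Optional


class LazyReplayDataset:
    """Hypothetical sketch of a lazily downloaded map-style dataset.

    A real implementation would subclass torch.utils.data.Dataset; this
    sketch relies only on __len__ and __getitem__.
    """

    def __init__(
        self,
        fetch_items: Callable[[], List[Any]],
        transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        self._fetch_items = fetch_items  # hypothetical download/unpack step
        self.transform = transform
        self._items: Optional[List[Any]] = None
        self.len = 0

    def ensure_downloaded(self) -> None:
        # Fetch on first access only; later calls are no-ops.
        if self._items is None:
            self._items = self._fetch_items()
            self.len = len(self._items)

    def __len__(self) -> int:
        self.ensure_downloaded()
        return self.len

    def __getitem__(self, index: int) -> Any:
        self.ensure_downloaded()
        # Negative indexing: offset once by the length; if the index is
        # still negative, it is out of range.
        if index < 0:
            index += self.len
            if index < 0:
                raise IndexError("dataset index out of range")
        # Indices at or beyond the length are out of range (sketch's choice).
        if index >= self.len:
            raise IndexError("dataset index out of range")
        item = self._items[index]
        # Apply the transform if one was given, else return the raw item.
        return self.transform(item) if self.transform is not None else item


calls = []


def fake_fetch() -> List[str]:
    calls.append(1)  # record each simulated download
    return ["replay_a", "replay_b", "replay_c"]


ds = LazyReplayDataset(fake_fetch, transform=str.upper)
print(len(ds))     # 3
print(ds[-1])      # REPLAY_C
print(len(calls))  # 1
```

Deferring the fetch to the first __len__ or __getitem__ call keeps construction cheap, and because ensure_downloaded is idempotent, repeated lookups never trigger a second download.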