"""
Example script for creating a PyTorch DataLoader that uses chunked sampling to access snapshot data efficiently.

This script demonstrates how to:
- Create a custom Dataset class that reads from the S3 bucket hosting the data
- Implement a ChunkSampler for efficient batch loading
- Load snapshot data from S3-hosted Zarr store
- Tune DataLoader parameters for optimal performance

Requirements:
- torch
- s3fs
- zarr
"""

import math
import random
import time

import torch
import zarr
from torch.utils.data import DataLoader, Dataset, Sampler


# Chunk sampler for efficient data loading
class ChunkSampler(Sampler):
    """Custom sampler that groups data indices into chunks for efficient batch loading."""

    def __init__(self, num_samples, chunk_size):
        self.num_samples, self.chunk_size = num_samples, chunk_size
        self.num_chunks = math.ceil(num_samples / chunk_size)

    def __iter__(self):
        # Shuffle chunks to randomise loading order
        chunk_indices = list(range(self.num_chunks))
        random.shuffle(chunk_indices)
        for chunk_idx in chunk_indices:
            # Generate indices for current chunk and shuffle them
            start_idx = chunk_idx * self.chunk_size
            end_idx = min(start_idx + self.chunk_size, self.num_samples)
            indices = list(range(start_idx, end_idx))
            random.shuffle(indices)
            yield from indices

    def __len__(self):
        return self.num_samples
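
# Illustrative example of the sampling order (hypothetical numbers, not used below):
# ChunkSampler(num_samples=6, chunk_size=3) yields all six indices grouped by chunk,
# e.g. 3, 5, 4, 0, 2, 1 -- chunks are visited in random order and indices are shuffled
# within each chunk, so consecutive reads stay within a single chunk's range.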


# Snapshot dataset with chunk-buffered access to the Zarr store
class SnapshotDataset(Dataset):
    """
    Custom Dataset for loading Zarr-formatted snapshot data with chunked access patterns.

    This dataset loads whole chunks into memory so that successive accesses to indices
    within the same chunk are served from the buffer rather than triggering repeated
    reads from S3, which is critical for throughput.
    """

    def __init__(self, path, array, num_samples, chunk_size, x_slice=slice(None)):
        self.path = path
        self.array = array
        self.num_samples = num_samples
        self.chunk_size = chunk_size
        self.x_slice = x_slice
        self.zarr_array = None
        self.current_chunk_id = None
        self.chunk_buffer = None

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Lazy initialisation of Zarr array connection
        if self.zarr_array is None:
            self.zarr_array = zarr.open(
                self.path,
                mode="r",
                use_consolidated=True,                          # Use consolidated metadata for faster access
                storage_options={
                    "anon": True,                               # Anonymous access (no credentials required)
                    "endpoint_url": "https://s3.eidf.ac.uk",    # EIDF S3 endpoint
                },
            )["snapshots"][self.array]

        # Determine which chunk this index belongs to
        chunk_id, local_idx = divmod(idx, self.chunk_size)

        # Load new chunk if needed
        if chunk_id != self.current_chunk_id:
            chunk_start = chunk_id * self.chunk_size
            chunk_end = min(chunk_start + self.chunk_size, self.num_samples)
            self.chunk_buffer = self.zarr_array[chunk_start:chunk_end, self.x_slice]
            self.current_chunk_id = chunk_id

        return torch.from_numpy(self.chunk_buffer[local_idx])
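
# Quick sanity check of the chunk buffering (hypothetical usage, assuming the store
# configured below is reachable):
#   ds = SnapshotDataset(path=..., array="u", num_samples=2432, chunk_size=38)
#   first = ds[0]   # triggers a single chunk read over the network
#   second = ds[1]  # same chunk, served from the in-memory buffer with no extra I/O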


# Dataset parameters
num_samples = 2432  # Total number of snapshots
chunk_size = 38     # Chunk size along the time dimension

# DataLoader parameters (tune for optimal performance)
batch_size = chunk_size  # Align batch size with chunk size so each batch maps to a single chunk read
num_workers = 4          # Number of worker processes for the DataLoader
prefetch_factor = 2      # Number of batches to prefetch in each worker
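# As a rough rule of thumb, increase num_workers (and possibly prefetch_factor) until the
# measured throughput below keeps pace with the simulated model step; the best values
# depend on the available CPU cores and network bandwidth.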

# Simulate model training
model_time = 1  # Simulated time per model training step (seconds)

# Create dataset and custom chunk sampler
dataset = SnapshotDataset(
    path="s3://eidf198-highres-snapshots-sublayer-dns-tbl-re2400/data.zarr",
    array="u",
    num_samples=num_samples,
    chunk_size=chunk_size,
    x_slice=slice(0, 8),    # Extract a specific x-coordinate/Reynolds number range
)
sampler = ChunkSampler(num_samples=len(dataset), chunk_size=dataset.chunk_size)

# Create DataLoader with the custom sampler
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=num_workers,
    prefetch_factor=prefetch_factor,
    in_order=False,             # Serve samples as soon as they are ready
    persistent_workers=True,    # Keep workers alive for multiple epochs
)
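
# Note: the in_order flag is only available in relatively recent PyTorch releases.
# Also, with num_workers > 0, platforms that spawn worker processes (e.g. Windows and
# macOS) generally expect the executable part of a script to sit under an
# `if __name__ == "__main__":` guard; with the fork start method used by default on
# Linux, the script runs as written.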

# Get a single batch
batch = next(iter(dataloader))
print("\nBatch Info:")
print(f"shape: {batch.shape}, dtype: {batch.dtype}, nbytes: {batch.nbytes / 1024**2:.2f} MiB")

# Print ideal (model) throughput
# Note that data throughput will ultimately be constrained by local network bandwidth
ideal_throughput = batch.nbytes / model_time / 1024**2
print(f"\nIdeal (Model) Throughput = {ideal_throughput:.2f} MiB/s")

# Iterate through batches
t0 = time.perf_counter()
tm1 = t0
for i, batch in enumerate(dataloader):
    # Simulate model and calculate throughput
    time.sleep(model_time)
    ti = time.perf_counter()
    throughput_batch = batch.nbytes / (ti - tm1) / 1024**2
    throughput_batch_pct = 100 * throughput_batch / ideal_throughput
    throughput_total = (i + 1) * batch.nbytes / (ti - t0) / 1024**2
    throughput_total_pct = 100 * throughput_total / ideal_throughput
    batch_str = f"Batch Throughput = {throughput_batch:.2f} MiB/s ({throughput_batch_pct:.2f}%)"
    total_str = f"Total Throughput = {throughput_total:.2f} MiB/s ({throughput_total_pct:.2f}%)"
    print(f"Batch {i + 1} / {len(dataloader)}: {batch_str}, {total_str}")
    tm1 = time.perf_counter()
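
# Worked example (purely illustrative numbers): if a batch occupies ~100 MiB and arrives
# 1.25 s after the previous one (1 s simulated model step + 0.25 s waiting on data), the
# batch throughput is ~80 MiB/s, i.e. ~80% of the ~100 MiB/s ideal (model-bound) throughput.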
