from pathlib import Path
import matplotlib.pyplot as plt
import torch
from pipeline.dataset import ProcessedDataset
from pipeline import augmentation as aug
from pipeline import features as feat
def unaugmented_fields(ds, idx):
    """Fetch sample ``idx`` of ``ds`` as raw per-field dicts, with no augmentation.

    This reproduces the per-field conversion ProcessedDataset performs
    internally, stopping *before* the fields are concatenated into
    ``sample['x']`` and *before* any group element is chosen.

    Args:
        ds: a ProcessedDataset. This relies on its private ``_index``,
            ``_read_fields`` and ``_to_channel_first`` helpers.
        idx: dataset index. ``ds._index[idx]`` yields a (file index, z) pair.

    Returns:
        (inputs, targets): dicts keyed by ``ds.input_features`` and
        ``ds.output_targets`` respectively. Each value is whatever
        ``ds._to_channel_first`` returns (described as channel-first tensors).
    """
    file_idx, slice_idx = ds._index[idx]
    fields = ds._read_fields(ds.h5_files[file_idx], slice_idx)
    # "__attrs__" is not a field (presumably file/slice metadata). Drop it so
    # only real fields remain. This raises KeyError if it is absent.
    fields.pop("__attrs__")

    def channel_first(name, registry):
        # Look up the field's "kind" in the given feature registry and convert.
        return ds._to_channel_first(name, fields[name], registry[name]["kind"])

    inputs = {}
    for name in ds.input_features:
        inputs[name] = channel_first(name, feat.INPUT_FEATURES)

    targets = {}
    for name in ds.output_targets:
        targets[name] = channel_first(name, feat.OUTPUT_TARGETS)

    return inputs, targets
def apply_group_element(inputs, targets, g, dim):
    """Transform every input and target field by the same group element ``g``.

    Each field goes through ``aug.apply_to_field``. Its kind comes from
    ``feat.INPUT_FEATURES`` (for inputs) or ``feat.OUTPUT_TARGETS`` (for
    targets), so that scalars, tensors, vectors, etc. are each transformed
    according to their kind.

    Returns:
        (transformed_inputs, transformed_targets): new dicts with the same keys
        and key order as ``inputs`` and ``targets``.
    """
    transformed_inputs = {}
    for name, tensor in inputs.items():
        kind = feat.INPUT_FEATURES[name]["kind"]
        transformed_inputs[name] = aug.apply_to_field(tensor, kind, g, dim)

    transformed_targets = {}
    for name, tensor in targets.items():
        kind = feat.OUTPUT_TARGETS[name]["kind"]
        transformed_targets[name] = aug.apply_to_field(tensor, kind, g, dim)

    return transformed_inputs, transformed_targets
# --- Load one raw (unaugmented) sample ---------------------------------------
# 2-D demo: the fields are easiest to read spatially. Pass dim=3 instead to
# inspect the 3-D configuration, where the group built below also gets mirror
# elements.
ds = ProcessedDataset(
    h5_files=[Path("processed_data") / "basal_001.h5"],
    dim=2,
    input_features=["fip", "strain", "plastic_strain", "stress", "grid"],
    output_targets=["quaternion"],
    # Turn off the dataset's own augmentation and normalization, so every
    # transformation shown below is applied explicitly.
    augment=False,
    normalize=False,
)

# Dataset index to inspect. It maps to a (file, z) pair. With a single HDF5
# file this picks one z-slice (assumed to run 0..29 here; confirm).
z_idx = 15
inputs, targets = unaugmented_fields(ds, z_idx)

# Match the dataset's own group construction: mirror elements are enabled only
# for 3-D data.
group = aug.build_group(include_mirror=ds.dim == 3)
print(f"{len(group)} elements:", [g.name for g in group])
# Cell 2 — visualize the effect of every group element
# Pick one informative channel per field kind:
#   scalar     (fip)    : the value itself; a spatial rotation/flip is visible directly
#   voigt      (strain) : channel 3, labelled strain_xy. Under 90° rotations it mixes
#                         with other components. NOTE(review): the index-3 == xy
#                         convention comes from the label; confirm in pipeline.features.
#   vector     (grid)   : x-component. A 90° rotation should move it into ±y, and a
#                         180° rotation should flip its sign.
#   quaternion          : w-channel. The other 3 components presumably change via a
#                         Hamilton product with the element's rotation.
CHAN = {"fip": 0, "strain": 3, "grid": 0, "quaternion": 0}
COLS = [
    ("fip", "scalar", "fip"),
    ("strain", "voigt", "strain_xy (Voigt idx 3)"),
    ("grid", "vector", "grid x"),
    ("quaternion", "quaternion", "q_w"),
]
# Columns drawn with a diverging colormap. Their limits are made symmetric
# about 0 so the neutral (white) color really means zero.
DIVERGING = {"quaternion"}

# Augment once per group element and keep only the channel shown in each column.
# NOTE(review): assumes each selected channel is a 2-D (H, W) image, which
# holds for the dim=2 demo; a 3-D volume would need a slice before imshow.
panels = []
for g in group:
    aug_in, aug_tg = apply_group_element(inputs, targets, g, ds.dim)
    fields = {**aug_in, **aug_tg}
    panels.append([fields[name][CHAN[name]].numpy() for name, _, _ in COLS])

# Share one color scale per column across all group elements. With per-image
# autoscaling, rows are not comparable and magnitude changes are hidden.
clims = []
for c, (name, _, _) in enumerate(COLS):
    lo = min(float(row[c].min()) for row in panels)
    hi = max(float(row[c].max()) for row in panels)
    if name in DIVERGING:
        bound = max(abs(lo), abs(hi))
        lo, hi = -bound, bound
    clims.append((lo, hi))

# squeeze=False keeps `axes` 2-D even when there is a single row or column.
fig, axes = plt.subplots(
    len(group),
    len(COLS),
    figsize=(2.8 * len(COLS), 2.8 * len(group)),
    squeeze=False,
)
for r, g in enumerate(group):
    for c, (name, _, label) in enumerate(COLS):
        ax = axes[r, c]
        vmin, vmax = clims[c]
        ax.imshow(
            panels[r][c],
            cmap="RdBu_r" if name in DIVERGING else "viridis",
            vmin=vmin,
            vmax=vmax,
        )
        if r == 0:
            ax.set_title(label)
        if c == 0:
            ax.set_ylabel(g.name, rotation=0, ha="right", va="center", labelpad=40)
        ax.set_xticks([])
        ax.set_yticks([])
fig.suptitle(f"Augmentation effects, dim={ds.dim}, slice={z_idx}")
fig.tight_layout()
plt.show()