JustPaste
Home · Categories · About · Donate · Contact · Terms of Use · Privacy Policy
JustPaste

Free online notepad — write and share instantly

Navigate

  • Home
  • Timeline
  • Categories

Info

  • About
  • Donate
  • Contact

Legal

  • Terms of Use
  • Privacy Policy

© 2026 JustPaste.app. All rights reserved.

Made with ♥ by JustPaste

vis-snippet | JustPaste.app
about 1 month ago · 0 views
👨‍💻Programming

vis-snippet

from pathlib import Path
  import matplotlib.pyplot as plt
  import torch

  from pipeline.dataset import ProcessedDataset
  from pipeline import augmentation as aug
  from pipeline import features as feat


  def unaugmented_fields(ds, idx):
      """Load sample `idx` of `ds` as un-augmented, per-field tensors.

      Reproduces the steps ProcessedDataset performs internally *before* it
      concatenates the fields into ``sample['x']`` and *before* it picks a
      group element, so the caller can apply augmentation by hand. No
      augmentation is applied here.

      Args:
          ds: a ProcessedDataset; its private ``_index``, ``_read_fields``
              and ``_to_channel_first`` helpers are used directly.
          idx: flat sample index; ``ds._index[idx]`` yields the
              (file index, z) pair that is read.

      Returns:
          ``(inputs, targets)``: two dicts keyed by feature name (in the
          order of ``ds.input_features`` / ``ds.output_targets``), each value
          being ``ds._to_channel_first`` applied to the raw field with the
          kind registered in ``pipeline.features``.

      Raises:
          KeyError: if the raw record has no ``"__attrs__"`` entry or lacks
              a requested field.
      """
      file_no, z = ds._index[idx]
      record = ds._read_fields(ds.h5_files[file_no], z)
      # "__attrs__" is not a field (presumably the HDF5 attributes); drop it.
      record.pop("__attrs__")

      inputs = {}
      for name in ds.input_features:
          inputs[name] = ds._to_channel_first(
              name, record[name], feat.INPUT_FEATURES[name]["kind"]
          )

      targets = {}
      for name in ds.output_targets:
          targets[name] = ds._to_channel_first(
              name, record[name], feat.OUTPUT_TARGETS[name]["kind"]
          )

      return inputs, targets


  def apply_group_element(inputs, targets, g, dim):
      """Transform every input and target field by the same group element.

      Args:
          inputs: dict of input-field name -> tensor (e.g. the first value
              returned by ``unaugmented_fields``).
          targets: dict of target-field name -> tensor.
          g: one element of the group built by ``aug.build_group``.
          dim: spatial dimensionality, forwarded to ``aug.apply_to_field``.

      Returns:
          ``(new_inputs, new_targets)``: fresh dicts with the same keys. Each
          tensor is passed through ``aug.apply_to_field`` together with the
          field kind registered in ``pipeline.features``. The inputs are
          processed before the targets. The argument dicts are not modified.
      """
      new_inputs = {}
      for name, tensor in inputs.items():
          new_inputs[name] = aug.apply_to_field(
              tensor, feat.INPUT_FEATURES[name]["kind"], g, dim
          )

      new_targets = {}
      for name, tensor in targets.items():
          new_targets[name] = aug.apply_to_field(
              tensor, feat.OUTPUT_TARGETS[name]["kind"], g, dim
          )

      return new_inputs, new_targets


  # 2D is the easiest case to read spatially. Set DIM = 3 to explore D4
  # (the build_group call below adds the mirror element only for 3D datasets).
  DIM = 2
  ds = ProcessedDataset(
      h5_files=[Path("processed_data/basal_001.h5")],
      dim=DIM,
      input_features=["fip", "strain", "plastic_strain", "stress", "grid"],
      output_targets=["quaternion"],
      augment=False,    # group elements are applied by hand, one at a time
      normalize=False,  # NOTE(review): presumably keeps unnormalised values for plotting
  )

  # Flat dataset index; unaugmented_fields resolves it through ds._index to a
  # (file, z) pair. With a single file this is presumably the z-slice
  # (expected range 0..29) — TODO confirm.
  z_idx = 15
  inputs, targets = unaugmented_fields(ds, z_idx)

  # Match the dataset's wiring: the mirror element only has meaning in 3D.
  group = aug.build_group(include_mirror=(ds.dim == 3))
  element_names = [g.name for g in group]
  print(f"{len(group)} elements: {element_names}")

  # Cell 2 — visualize the effect of every group element

  import numpy as np  # NaN-aware min/max for the shared colour limits

  # Pick one informative channel per field kind.
  #   scalar (fip)      : any rotation is visible directly in the map.
  #   voigt  (strain)   : channel 3, shown as the xy shear. If apply_to_field
  #                       transforms tensor components, a 90°/270° rotation
  #                       flips its sign (and swaps xx/yy), while 180° leaves
  #                       it unchanged — on top of rotating the map itself.
  #                       NOTE(review): xy is index 3 only in an
  #                       (xx, yy, zz, xy, ...) ordering; textbook Voigt puts
  #                       xy at 5 — confirm against pipeline.features.
  #   vector (grid)     : x-component — becomes ∓y under 90° / 270° and -x
  #                       under 180°.
  #   quaternion        : w-channel. NOTE(review): if group elements act by a
  #                       Hamilton product, w mixes with (x, y, z) as well, so
  #                       q_w is not expected to be invariant.
  CHAN = {"fip": 0, "strain": 3, "grid": 0, "quaternion": 0}
  COLS = [
      ("fip",        "scalar",      "fip"),
      ("strain",     "voigt",       "strain_xy (Voigt idx 3)"),
      ("grid",       "vector",      "grid x"),
      ("quaternion", "quaternion",  "q_w"),
  ]
  # Columns drawn with a diverging colormap centred on 0.
  DIVERGING = {"quaternion"}


  def _panel_image(fields, name):
      """Return the 2-D array plotted for field `name` (its CHAN channel).

      The tensor is detached and moved to CPU first, so .numpy() also works
      for GPU or grad-tracking tensors. If the channel still has more than
      two axes (e.g. a dim=3 volume), the middle index of each leading axis
      is kept so imshow receives a 2-D image.
      NOTE(review): assumes the last two axes are the in-plane ones — confirm
      against ProcessedDataset's channel-first layout.
      """
      img = fields[name][CHAN[name]].detach().cpu().numpy()
      while img.ndim > 2:
          img = img[img.shape[0] // 2]
      return img


  # Transform once per group element up front, so that every column can share
  # one colour scale across rows. With per-panel autoscaling, each image is
  # stretched to its own range: magnitude changes between group elements are
  # hidden, and sign flips are easy to misread.
  panels = []
  for g in group:
      aug_in, aug_tg = apply_group_element(inputs, targets, g, ds.dim)
      fields = {**aug_in, **aug_tg}
      panels.append({name: _panel_image(fields, name) for name, _, _ in COLS})

  limits = {}
  for name, _, _ in COLS:
      lo = min(float(np.nanmin(p[name])) for p in panels)
      hi = max(float(np.nanmax(p[name])) for p in panels)
      if name in DIVERGING:
          # Symmetric limits put 0 at the colormap's neutral midpoint.
          bound = max(abs(lo), abs(hi))
          lo, hi = -bound, bound
      limits[name] = (lo, hi)

  # squeeze=False keeps `axes` 2-D even for a single row or a single column.
  fig, axes = plt.subplots(len(group), len(COLS),
                           figsize=(2.8 * len(COLS), 2.8 * len(group)),
                           squeeze=False)

  for r, (g, imgs) in enumerate(zip(group, panels)):
      for c, (name, _, label) in enumerate(COLS):
          ax = axes[r, c]
          vmin, vmax = limits[name]
          ax.imshow(
              imgs[name],
              cmap="RdBu_r" if name in DIVERGING else "viridis",
              vmin=vmin,
              vmax=vmax,
          )
          if r == 0:
              ax.set_title(label)
          if c == 0:
              ax.set_ylabel(g.name, rotation=0, ha="right", va="center", labelpad=40)
          ax.set_xticks([])
          ax.set_yticks([])
  fig.suptitle(f"Augmentation effects, dim={ds.dim}, slice={z_idx}")
  fig.tight_layout()
  plt.show()
← Back to timeline