18. Multi-dimensional Data Analysis with Xarray#

18.1. Introduction#

18.2. Learning Objectives#

18.3. Understanding Xarray’s Data Model#

18.3.1. Core Data Structures#
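
As a minimal, standalone sketch of these two structures (the values and names below are invented for illustration), a DataArray bundles an N-dimensional array with named dimensions, coordinate labels, and attributes, while a Dataset is a dictionary-like container of DataArrays that share coordinates:

import numpy as np
import xarray as xr

# Build a tiny 2 x 3 temperature grid with labeled dimensions and coordinates
values = 15 + 8 * np.random.randn(2, 3)
da = xr.DataArray(
    values,
    dims=("lat", "lon"),
    coords={"lat": [10.0, 20.0], "lon": [100.0, 110.0, 120.0]},
    attrs={"units": "degC", "long_name": "air temperature"},
)

# A Dataset maps variable names to DataArrays that share coordinates
example_ds = xr.Dataset({"temperature": da})
print(da.dims, da.sizes)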

18.3.2. Why This Structure Matters#

18.4. Setting Up Your Environment#

18.4.1. Installing Required Packages#

%pip install xarray pooch pygis

18.4.2. Importing Libraries and Configuration#

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Configure Xarray to keep attributes through operations and use a compact display
xr.set_options(keep_attrs=True, display_expand_data=False)

# Configure NumPy display for cleaner output
np.set_printoptions(threshold=10, edgeitems=2)

# Render figures at a higher resolution
plt.rcParams["figure.dpi"] = 150

18.5. Loading and Exploring Real Climate Data#

18.5.1. Loading Tutorial Data#

# Load a climate dataset with air temperature measurements
ds = xr.tutorial.open_dataset("air_temperature")
ds
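
The tutorial helper downloads a small sample file the first time it runs and caches it locally (this is what the pooch dependency is for). Opening your own data works the same way through xr.open_dataset; the path below is only a placeholder:

# Opening a local NetCDF file directly (placeholder path):
# ds = xr.open_dataset("path/to/your_file.nc")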

18.6. Working with DataArrays#

18.6.1. Accessing DataArrays from Datasets#

# Extract the air temperature DataArray using dictionary notation
temperature = ds["air"]
temperature
# Same result using attribute access
temperature = ds.air
temperature
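
Arithmetic on a DataArray returns a new DataArray with the same dimensions and coordinates, leaving the original Dataset untouched. As a quick sketch, the Kelvin values can be converted to degrees Celsius:

# Convert from Kelvin to Celsius; the result is a new DataArray
air_celsius = ds["air"] - 273.15
# keep_attrs=True copies the Kelvin units over, so update them on the copy
air_celsius.attrs["units"] = "degC"
air_celsius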

18.6.2. Exploring DataArray Components#

# Examine the actual data values (a NumPy array)
print("Data shape:", temperature.values.shape)
print("Data type:", temperature.values.dtype)
print("First few values:", temperature.values.flat[:5])
# Understand the dimension structure
print("Dimensions:", temperature.dims)
print("Dimension sizes:", temperature.sizes)
# Explore the coordinate information
print("Coordinates:")
for name, coord in temperature.coords.items():
    print(f"  {name}: {coord.values[:3]}... (showing first 3 values)")
# Examine metadata attributes
print("Attributes:")
for key, value in temperature.attrs.items():
    print(f"  {key}: {value}")

18.7. Intuitive Data Selection and Indexing#

18.7.1. Label-Based Selection#

# Select data for a specific location on a given date (all time steps that day)
point_data = temperature.sel(time="2013-01-01", lat=40.0, lon=260.0)
point_data

18.7.2. Time Range Selection#

# Select all data for January 2013
january_data = temperature.sel(time=slice("2013-01-01", "2013-01-31"))
print(f"January 2013 data shape: {january_data.shape}")
print(f"Time range: {january_data.time.values[0]} to {january_data.time.values[-1]}")

18.7.3. Nearest Neighbor Selection#

# Select data nearest to a location that may not fall exactly on the grid
nearest_data = temperature.sel(lat=40.5, lon=255.7, method="nearest")
print("Requested: lat=40.5, lon=255.7")
print(f"Actual: lat={nearest_data.lat.values}, lon={nearest_data.lon.values}")

18.8. Performing Operations on Multi-Dimensional Data#

18.8.1. Statistical Operations Across Dimensions#

# Calculate the temporal mean (average temperature at each location)
mean_temperature = temperature.mean(dim="time")
print(f"Original data shape: {temperature.shape}")
print(f"Time-averaged data shape: {mean_temperature.shape}")
print(
    f"Temperature range: {mean_temperature.min().values:.1f} to {mean_temperature.max().values:.1f} K"
)
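
Other reductions follow the same pattern. For example, the standard deviation over time shows where temperatures vary most through the year:

# Temporal variability at each grid point
temperature_std = temperature.std(dim="time")
print(f"Largest temporal variability: {temperature_std.max().values:.1f} K")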

18.8.2. Computing Anomalies#

# Calculate temperature anomalies by subtracting the time mean from each time step
anomalies = temperature - mean_temperature
print(f"Anomaly range: {anomalies.min().values:.1f} to {anomalies.max().values:.1f} K")

# Find the location and time of the largest positive anomaly
max_anomaly = anomalies.max()
max_location = anomalies.where(anomalies == max_anomaly, drop=True)
print(f"Largest positive anomaly: {max_anomaly.values:.1f} K")

18.8.3. Spatial Statistics#

# Calculate the unweighted spatial mean for each time step
# (an area-weighted version appears in Section 18.13.3)
spatial_mean = temperature.mean(dim=["lat", "lon"])
print(f"Spatial mean temperature time series shape: {spatial_mean.shape}")

# Find the warmest and coldest time periods
warmest_date = spatial_mean.time[spatial_mean.argmax()]
coldest_date = spatial_mean.time[spatial_mean.argmin()]
print(f"Warmest period: {warmest_date.values}")
print(f"Coldest period: {coldest_date.values}")

18.9. Data Visualization with Xarray#

18.9.1. Plotting 2D Spatial Data#

# Create a map of long-term average temperature
fig, ax = plt.subplots(figsize=(12, 6))
mean_temperature.plot(ax=ax, cmap="RdYlBu_r", add_colorbar=True)
plt.title("Long-term Average Air Temperature", fontsize=14, fontweight="bold")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.tight_layout()
plt.show()

18.9.2. Customizing Spatial Plots#

# Create a more customized visualization
fig, ax = plt.subplots(figsize=(12, 6))
plot = mean_temperature.plot(
    ax=ax,
    cmap="RdYlBu_r",
    levels=20,  # Number of contour levels
    add_colorbar=True,
    cbar_kwargs={"label": "Temperature (K)", "shrink": 0.8, "pad": 0.02},
)
plt.title("Mean Air Temperature (2013)", fontsize=16, fontweight="bold")
plt.xlabel("Longitude (°E)", fontsize=12)
plt.ylabel("Latitude (°N)", fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
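
Xarray can also lay out small-multiple (faceted) plots by itself when you name a dimension to spread across columns; here the first four time steps are compared side by side:

# Facet the first four time steps across two columns; xarray handles the
# subplot grid and shared colorbar automatically
temperature.isel(time=slice(0, 4)).plot(col="time", col_wrap=2, cmap="RdYlBu_r")
plt.show()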

18.9.3. Time Series Visualization#

# Select and plot time series for a specific location
location_ts = temperature.sel(lat=40.0, lon=260.0)

fig, ax = plt.subplots(figsize=(12, 6))
location_ts.plot(ax=ax, linewidth=1.5, color="darkblue")
plt.title("Temperature Time Series at 40°N, 260°E", fontsize=14, fontweight="bold")
plt.xlabel("Time")
plt.ylabel("Temperature (K)")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

18.10. Working with Datasets: Multiple Variables#

18.10.1. Exploring Dataset Structure#

# Examine all variables in the dataset
print("Data variables in the dataset:")
for var_name, var_info in ds.data_vars.items():
    print(f"  {var_name}: {var_info.dims}, shape {var_info.shape}")

print(f"\nShared coordinates: {list(ds.coords.keys())}")
print(f"Global attributes: {len(ds.attrs)} metadata items")

18.10.2. Dataset-Level Operations#

# Calculate temporal statistics for all variables in the dataset
dataset_means = ds.mean(dim="time")
dataset_means
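
Selection works at the Dataset level too: an operation such as sel() is applied to every variable that uses the relevant dimension, so a whole dataset can be subset in one call:

# Subset every variable in the dataset to January 2013
january_ds = ds.sel(time="2013-01")
january_ds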

18.11. The Power of Label-Based Operations#

18.11.1. The NumPy Approach: Index-Based Selection#

# Extract raw arrays and coordinates
lat_values = ds.air.lat.values
lon_values = ds.air.lon.values
temp_values = ds.air.values

print(f"Data shape: {temp_values.shape}")
print("To plot the first time step, you need to remember:")
print("- Time is dimension 0")
print("- Latitude is dimension 1")
print("- Longitude is dimension 2")
# Plot using NumPy approach - requires careful index management
fig, ax = plt.subplots(figsize=(12, 6))
im = ax.pcolormesh(lon_values, lat_values, temp_values[0, :, :], cmap="RdYlBu_r")
plt.colorbar(im, ax=ax, label="Temperature (K)")
plt.title("First Time Step (NumPy approach)")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.show()

18.11.2. The Xarray Approach: Label-Based Selection#

# Same result with Xarray - much more readable and less error-prone
ds.air.isel(time=0).plot(figsize=(12, 6), cmap="RdYlBu_r")
plt.title("First Time Step (Xarray approach)")
plt.show()
# Select by actual date rather than array index
ds.air.sel(time="2013-01-01T00:00:00").plot(figsize=(12, 6), cmap="RdYlBu_r")
plt.title("Temperature on January 1, 2013")
plt.show()

18.12. Advanced Indexing Techniques#

18.12.1. Position-Based vs. Label-Based Indexing#

# Position-based indexing using isel() - useful for systematic sampling
first_last_times = ds.air.isel(time=[0, -1])  # First and last time steps
print(f"Selected time steps: {first_last_times.time.values}")

# Label-based indexing using sel() - useful for specific values
specific_months = ds.air.sel(time=slice("2013-05", "2013-07"))
print(f"May-July 2013 contains {len(specific_months.time)} time steps")

18.12.2. Boolean Indexing and Conditional Selection#

# Find locations where average temperature exceeds a threshold
warm_locations = mean_temperature.where(mean_temperature > 280)  # 280 K ≈ 7°C
warm_count = warm_locations.count()
print(f"Number of grid points with mean temperature > 280 K: {warm_count.values}")

# Find time periods when spatial average temperature was unusually high
temp_threshold = spatial_mean.quantile(0.9)  # 90th percentile
warm_periods = spatial_mean.where(spatial_mean > temp_threshold, drop=True)
print(f"Number of exceptionally warm time periods: {len(warm_periods)}")

18.13. High-Level Computational Operations#

18.13.1. GroupBy Operations for Temporal Analysis#

# Calculate seasonal climatology
seasonal_means = ds.air.groupby("time.season").mean()
print("Seasonal temperature patterns:")
seasonal_means
# Visualize seasonal patterns
fig, axes = plt.subplots(2, 2, figsize=(12, 6))
seasons = ["DJF", "MAM", "JJA", "SON"]
season_names = ["Winter", "Spring", "Summer", "Fall"]

for i, (season, name) in enumerate(zip(seasons, season_names)):
    ax = axes[i // 2, i % 2]
    seasonal_means.sel(season=season).plot(ax=ax, cmap="RdYlBu_r", add_colorbar=False)
    ax.set_title(f"{name} ({season})")
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")

plt.tight_layout()
plt.show()
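
Closely related to groupby is resample(), which bins a datetime dimension by calendar period rather than by group label; monthly means are a common example:

# Monthly mean temperatures (one value per calendar month)
monthly_means = ds.air.resample(time="MS").mean()
print(f"Number of monthly means: {monthly_means.sizes['time']}")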

18.13.2. Rolling Window Operations#

# Create a smoothed time series using a rolling window
location_data = temperature.sel(lat=40.0, lon=260.0)

fig, ax = plt.subplots(figsize=(12, 6))

# Plot original data
location_data.plot(ax=ax, alpha=0.5, label="Original", color="lightblue")

# Plot smoothed data using a 30-day rolling window
# (the data are 6-hourly, so a 30-day window spans 120 time steps)
smoothed_data = location_data.rolling(time=120, center=True).mean()
smoothed_data.plot(ax=ax, label="30-day smoothed", color="darkblue", linewidth=2)

plt.title("Temperature Time Series: Original vs Smoothed")
plt.xlabel("Time")
plt.ylabel("Temperature (K)")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

18.13.3. Weighted Operations#

# Weight each grid cell by the cosine of its latitude, which is proportional
# to the cell's area on a regular latitude-longitude grid
lat_weights = np.cos(np.radians(ds.air.lat))
area_weighted_mean = ds.air.weighted(lat_weights).mean(dim=["lat", "lon"])

# Compare simple vs area-weighted spatial averages
fig, ax = plt.subplots(figsize=(12, 6))
spatial_mean.plot(ax=ax, label="Simple average", alpha=0.7)
area_weighted_mean.plot(ax=ax, label="Area-weighted average", linewidth=2)
plt.title("Spatial Temperature Averages: Simple vs Area-Weighted")
plt.xlabel("Time")
plt.ylabel("Temperature (K)")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

18.14. Data Input and Output#

18.14.1. Understanding NetCDF Format#
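
A quick way to see how a Dataset maps onto the NetCDF model of dimensions, variables, and attributes is ds.info(), which prints a summary similar in spirit to the ncdump -h command-line tool:

# Print an ncdump-style summary of dimensions, variables, and attributes
ds.info()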

18.14.2. Writing Data to NetCDF#

# Prepare data for saving (ensure proper data types)
output_ds = ds.copy()
output_ds["air"] = output_ds["air"].astype("float32")  # Reduce file size

# Add processing metadata
output_ds.attrs["processing_date"] = str(np.datetime64("now"))
output_ds.attrs["created_by"] = "GIS Pro"

# Save to NetCDF file
output_ds.to_netcdf("processed_air_temperature.nc")
print("Dataset saved to processed_air_temperature.nc")

18.14.3. Reading Data from NetCDF#

# Load the saved dataset
reloaded_ds = xr.open_dataset("processed_air_temperature.nc")
print("Successfully reloaded dataset:")
print(f"Variables: {list(reloaded_ds.data_vars.keys())}")
print(f"Processing date: {reloaded_ds.attrs.get('processing_date', 'Not specified')}")
print(f"Data matches original: {reloaded_ds.air.equals(ds.air.astype('float32'))}")

18.15. Key Takeaways#

18.16. Further Reading#

18.17. Exercises#

18.17.1. Exercise 1: Exploring a New Dataset#

18.17.2. Exercise 2: Data Selection and Indexing#

18.17.3. Exercise 3: Performing Arithmetic Operations#

18.17.4. Exercise 4: GroupBy and Temporal Analysis#

18.17.5. Exercise 5: Data Storage and Retrieval#