from ewokscore import Task
from ewoksxrpd.tasks.utils.data_utils import create_hdf5_link
from blissdata.redis_engine.store import DataStore
from blissdata.redis_engine.scan import Scan
from blissdata.beacon.data import BeaconData
from silx.io import h5py_utils
import h5py
import logging
import numpy as np
import json
from pathlib import Path
from pyFAI.io.integration_config import WorkerConfig
from scipy.ndimage import zoom
import matplotlib.image
logger = logging.getLogger(__name__)
class CreateDiffMapFile(
Task,
input_names=[
"nxdata_url",
],
optional_input_names=[
"lima_name",
"master_filename",
"scan_nb",
"external_output_filename",
"output_filename",
"nxprocess_name",
"integration_options",
"x_nb_points",
"y_nb_points",
"scan_memory_url",
"normalization_counter",
"dark_value",
"save_img",
"do_diffmap",
],
output_names=[
"nxdata_url",
"filename_img",
],
):
"""
Creates a DiffMap NexusProcess compliant with pyFAI-diffmap/view API after an already integrated dataset
Input names:
nxdata_url : url to stored integrated data. Normally it's stored in the PROCESSED_DATA scan file, example scan_file.h5::16.1/eiger_integrate/integrated
Optional input names:
lima_name: name of the detector device. For example: "eiger"
master_filename: RAW_DATA master filename, kmap information may be stored here
scan_nb: number of the scan
external_output_filename: external file where to write the diffmap nexus process (usually it's the PROCESSED_DATA scan file)
output_filename: file where to write a link to the diffmap nexus process (usually it's the PROCESSED_DATA master file of collection)
nxprocess_name: name of the nexus process. If not provided, will be set to "<lima_name>_diffmap"
integration_options: dictionary to write as a configuration group
x_nb_points: number of points along the x direction. If not provided, will be searched as kmap paramaters in master_filename
y_nb_points: number of points along the y direction. If not provided, will be searched as kmap paramaters in master_filename
do_diffmap: if False, skip this process
* nxdata_url: where to read (integrated_data)
* external_output_filename: where to save the diffmap
* output_filename: where to link the diffmap
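
    Usage sketch, assuming the standard ewokscore task execution API; the file
    path, detector name and grid sizes below are illustrative, not values taken
    from an actual dataset::

        task = CreateDiffMapFile(
            inputs={
                "nxdata_url": "sample_0001.h5::16.1/eiger_integrate/integrated",
                "lima_name": "eiger",
                "x_nb_points": 51,
                "y_nb_points": 41,
            }
        )
        task.execute()
        print(task.outputs.nxdata_url, task.outputs.filename_img)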
"""
def run(self):
if not self.get_input_value("do_diffmap", True):
self.outputs.nxdata_url = None
return
nxdata_url: str = self.inputs.nxdata_url
lima_name: str = self.inputs.lima_name
master_filename: str = self.get_input_value("master_filename", None)
scan_number: int = self.get_input_value("scan_nb", None)
external_output_filename: str = self.get_input_value(
"external_output_filename", None
)
output_filename: str = self.get_input_value("output_filename", None)
integration_options: dict = self.get_input_value("integration_options", {})
        entry_name = f"{scan_number}.1" if scan_number else "entry_0000"
        nxdata_diffmap_name = "diffmap"
        input_filename, nxdata_integrate_path = nxdata_url.split("::")
        h5_path_parts = nxdata_integrate_path.split("/")
        if not lima_name:
            lima_name = h5_path_parts[2].replace("_integrate", "")
        # Resolve lima_name before using it in the default NXprocess name
        nxprocess_name = self.get_input_value("nxprocess_name", f"{lima_name}_diffmap")
if not external_output_filename and not output_filename:
external_output_filename = input_filename
output_filename = input_filename
elif not external_output_filename:
external_output_filename = output_filename
elif not output_filename:
output_filename = external_output_filename
# Open the integrated dataset from nxdata_url
mode = "r+" if external_output_filename == input_filename else "r"
with h5py_utils.open_item(
filename=input_filename,
name="/",
mode=mode,
) as file_input:
nxdata_integrate = file_input[nxdata_integrate_path]
dset_intensity_integrated = nxdata_integrate["intensity"]
if dset_intensity_integrated.ndim != 2:
logger.error(
"Multidimensional diffmap is not implemented. Skipping diffmap."
)
self.outputs.nxdata_url = None
return
# Shape of dataset
nb_frames, nbpt_rad = dset_intensity_integrated.shape
# Get the shape of the diffraction map
x_nb_points, y_nb_points = self.get_kmap_points()
if x_nb_points is None or y_nb_points is None:
logger.error(
"x_nb_points and y_nb_points must be provided either as input parameters or found in the master file."
)
self.outputs.nxdata_url = None
return
if nb_frames != (x_nb_points * y_nb_points):
raise ValueError(
f"Number of frames {nb_frames} does not match with kmap points {x_nb_points}*{y_nb_points}={x_nb_points*y_nb_points}"
)
            # Open the external output file and write the result into external_output_filename
with h5py_utils.open_item(
filename=external_output_filename, name="/", mode="r+"
) as file_output:
scan_entry = file_output[entry_name]
nxprocess_diffmap = scan_entry.create_group(nxprocess_name)
nxprocess_diffmap.attrs.update(
{"NX_class": "NXprocess", "default": nxdata_diffmap_name}
)
nxprocess_diffmap["dim0"] = y_nb_points
nxprocess_diffmap["dim0"].attrs["axis"] = "motory"
nxprocess_diffmap["dim1"] = x_nb_points
nxprocess_diffmap["dim1"].attrs["axis"] = "motorx"
nxprocess_diffmap["offset"] = 0
if integration_options.get(
"mask_file", None
) and integration_options.get("do_mask", True):
mask_file = integration_options.get("mask_file")
else:
mask_file = ""
nxprocess_diffmap["mask_file"] = mask_file
# NxNote: configuration
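                # The configuration group is an NXnote containing the pyFAI
                # WorkerConfig serialized as JSON, so the integration parameters
                # are stored alongside the map they produced.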
config = nxprocess_diffmap.create_group("configuration")
config.attrs["NX_class"] = "NXnote"
pyfai_config = WorkerConfig.from_dict(integration_options)
config["data"] = json.dumps(
pyfai_config.as_dict(), indent=2, separators=(",\r\n", ": ")
)
config["type"] = "text/json"
# NxData: result
nxdata_diffmap = nxprocess_diffmap.create_group(nxdata_diffmap_name)
nxdata_diffmap.attrs["NX_class"] = "NXdata"
nxdata_diffmap["ypoints"] = np.arange(y_nb_points)
nxdata_diffmap["xpoints"] = np.arange(x_nb_points)
                # Reference the radial axis of the integrated data through a virtual dataset
axes_integrate = nxdata_integrate.attrs.get("axes", [""])
nxprocess_diffmap["dim2"] = nbpt_rad
radial_name = axes_integrate[-1]
nxprocess_diffmap["dim2"].attrs.update(
{
"axis": "diffraction",
"name": radial_name,
}
)
nxdata_integrated_axis = nxdata_integrate[radial_name]
virtual_source = h5py.VirtualSource(nxdata_integrated_axis)
virtual_layout = h5py.VirtualLayout(
shape=nxdata_integrated_axis.shape,
dtype=nxdata_integrated_axis.dtype,
)
virtual_layout[:] = virtual_source[:]
nxdata_diffmap.create_virtual_dataset(radial_name, virtual_layout)
# Reshape and write the diffmap
intensity_integrated = dset_intensity_integrated[:]
dset_intensity_diffmap = nxdata_diffmap.create_dataset(
name="intensity",
data=np.reshape(
intensity_integrated,
(y_nb_points, x_nb_points, nbpt_rad),
),
dtype="float32",
chunks=(1, 1, nbpt_rad),
fillvalue=np.nan,
)
# NxData attrs
nxdata_diffmap.attrs.update(
{
"interpretation": "image",
"signal": "intensity",
"axes": [radial_name, "ypoints", "xpoints"],
}
)
# Map with shifted dimensions
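                # Expose the (y, x, radial) intensities as (radial, y, x) through
                # a virtual dataset: no data is duplicated, each (i, j) map point
                # becomes one column of "map".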
layout = h5py.VirtualLayout(
shape=(nbpt_rad, y_nb_points, x_nb_points),
dtype=dset_intensity_diffmap.dtype,
)
source = h5py.VirtualSource(dset_intensity_diffmap)
for i in range(y_nb_points):
for j in range(x_nb_points):
layout[:, i, j] = source[i, j]
nxdata_diffmap.create_virtual_dataset(
"map", layout, fillvalue=np.nan
).attrs["interpretation"] = "image"
# Check if normalization of integrated data is possible
norm_values = self.get_normalization_values()
if norm_values is not None and len(norm_values) != nb_frames:
logger.error(
f"Normalization values array has different number of frames ({len(norm_values)}) than the integrated data ({nb_frames})"
)
norm_values = None
if norm_values is not None:
# Save normalization values
dset_normalization_values = nxdata_diffmap.create_dataset(
name=self.get_input_value("normalization_counter"),
data=norm_values,
dtype="float32",
chunks=True,
fillvalue=np.nan,
)
dset_normalization_values.attrs["interpretation"] = "scalar"
# Dark value is the normalization counter value when the beam is off
dark_value = self.get_input_value("dark_value", 0.0)
nxdata_diffmap["dark_value"] = dark_value
norm_values -= dark_value
integrate_intensity_norm = np.where(
norm_values[:, np.newaxis] == 0,
0.0,
dset_intensity_integrated[:] / norm_values[:, np.newaxis],
)
dset_diffmap_intensity_norm = nxdata_diffmap.create_dataset(
name="intensity_norm",
data=np.reshape(
integrate_intensity_norm,
(y_nb_points, x_nb_points, nbpt_rad),
),
dtype="float32",
chunks=(1, 1, nbpt_rad),
fillvalue=np.nan,
)
# Map-norm with shifted dimensions
layout = h5py.VirtualLayout(
shape=(nbpt_rad, y_nb_points, x_nb_points),
dtype=dset_diffmap_intensity_norm.dtype,
)
source = h5py.VirtualSource(dset_diffmap_intensity_norm)
for i in range(y_nb_points):
for j in range(x_nb_points):
layout[:, i, j] = source[i, j]
nxdata_diffmap.create_virtual_dataset(
"map_norm", layout, fillvalue=np.nan
).attrs["interpretation"] = "image"
nxdata_diffmap.attrs["signal"] = (
"map" if norm_values is None else "map_norm"
)
                # If master_filename and scan_nb are provided, link the lima measurement from the raw data master file
if master_filename and scan_number:
if lima_name not in scan_entry["measurement"]:
scan_entry["measurement"][lima_name] = h5py.ExternalLink(
filename=master_filename,
path=f"{scan_number}.1/measurement/{lima_name}",
)
# Link the diffmap nexus process to output_filename
if output_filename != external_output_filename:
with h5py_utils.open_item(
filename=output_filename,
name=entry_name,
mode="r+",
) as scan_entry:
create_hdf5_link(
parent=scan_entry,
link_name=nxprocess_name,
target=nxprocess_diffmap,
relative=True,
)
# Save a .png image
dset_to_save = (
nxdata_diffmap["map"]
if norm_values is None
else nxdata_diffmap["map_norm"]
)
filename_img = self.save_image(
external_output_filename=external_output_filename,
intensity_dset=dset_to_save[:],
)
self.outputs.nxdata_url = f"{external_output_filename}::{entry_name}/{nxprocess_name}/{nxdata_diffmap_name}"
self.outputs.filename_img = filename_img
def get_kmap_points(self) -> tuple:
"""Get a tuple of kmap points from kmap parameters from the master file."""
x_nb_points = self.get_input_value("x_nb_points", None)
y_nb_points = self.get_input_value("y_nb_points", None)
if x_nb_points is not None and y_nb_points is not None:
return (x_nb_points, y_nb_points)
master_filename: str = self.get_input_value("master_filename", None)
scan_number: int = self.get_input_value("scan_nb", None)
if not master_filename or not scan_number:
logger.error(
"kmap parameters must be provided either as x_nb_points and y_nb_points or as master_filename and scan_number."
)
return (None, None)
with h5py_utils.open_item(
filename=master_filename, name=f"{scan_number}.1"
) as scan_entry:
instrument = scan_entry["instrument"]
if "kmap_parameters" not in instrument:
return (None, None)
kmap_parameters = instrument["kmap_parameters"]
x_nb_points = kmap_parameters["x_nb_points"][()]
y_nb_points = kmap_parameters["y_nb_points"][()]
return (x_nb_points, y_nb_points)
def get_normalization_values(self) -> np.ndarray:
"""Provides a way to read the monitor array either from a blissdata stream or from the file."""
normalization_counter = self.get_input_value("normalization_counter", None)
if normalization_counter is None:
return
# Retrieve counter data from blissdata stream or from the file
scan_memory_url = self.get_input_value("scan_memory_url", None)
if scan_memory_url:
try:
datastore = DataStore(url=BeaconData().get_redis_data_db())
except Exception as e:
logger.error(
f"Failed to connect to Redis datastore: {e}. Needs a BEACON_HOST environment variable."
)
return
scan = datastore.load_scan(key=scan_memory_url, scan_cls=Scan)
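            # Block until the scan state reaches 4 (finished) so that the full
            # monitor stream is available before reading it.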
while int(scan.state) != 4:
scan.update()
stream = next(
(
stream
for stream_name, stream in scan.streams.items()
if normalization_counter in stream_name
),
None,
)
if stream is None:
logger.error(
f"Normalization counter {normalization_counter} not found in scan.streams."
)
return None
return stream[:]
else:
master_filename: str = self.get_input_value("master_filename", None)
scan_number: int = self.get_input_value("scan_nb", None)
if not master_filename or not scan_number:
logger.error(
"No master filename or scan number provided for normalization counter."
)
return
with h5py_utils.open_item(
filename=master_filename, name=f"{scan_number}.1"
) as scan_entry:
measurement = scan_entry["measurement"]
if normalization_counter not in measurement:
logger.error(
f"Normalization counter {normalization_counter} not found in {measurement}."
)
return
return measurement[normalization_counter][:]
def save_image(
self,
external_output_filename: str,
intensity_dset: np.ndarray,
flip: bool = True,
) -> str:
if self.get_input_value("save_img", True):
average_intensity = np.nanmean(intensity_dset, axis=0)
sh0, sh1 = average_intensity.shape
maxsize = 100
if sh0 >= sh1:
new_shape = (maxsize, int(maxsize * sh1 / sh0))
else:
new_shape = (int(maxsize * sh0 / sh1), maxsize)
            # scipy.ndimage.zoom expects per-axis zoom factors, not a target shape
            zoom_factors = (new_shape[0] / sh0, new_shape[1] / sh1)
            average_intensity = zoom(average_intensity, zoom_factors)
nxprocess_name = self.get_input_value(
"nxprocess_name", f"{self.inputs.lima_name}_diffmap"
)
# Save the image as a .png file
subdir_gallery = Path(external_output_filename).parent.joinpath("gallery")
subdir_gallery.mkdir(parents=True, exist_ok=True)
name = Path(external_output_filename).name.replace(
".h5", f"_{nxprocess_name}.png"
)
filename_img = str(subdir_gallery.joinpath(name))
arr = average_intensity.astype(np.float32)
if flip:
arr = np.flipud(arr)
try:
matplotlib.image.imsave(
fname=filename_img,
arr=arr,
cmap="viridis",
dpi=100,
origin="lower",
)
return filename_img
except Exception as e:
logger.error(f"Failed to save image: {e}")
return ""