from __future__ import absolute_import, division, print_function
from six.moves import range
import os
import h5py
import numpy as np
from xfel.euxfel.read_geom import read_geom
from libtbx.phil import parse
import six
from libtbx.utils import Sorry
import datetime

# Master phil scope describing every parameter this script accepts. User
# supplied phil files and command line assignments are merged against it in
# jf4m_geom2nexus.params_from_phil. The values under nexus_details are written
# to the master file by jf4m_geom2nexus.create_nexus_master_file.
phil_scope = parse("""
  unassembled_file = None
    .type = path
    .help = hdf5 file used to read in image data.
  geom_file = None
    .type = path
    .help = geometry file to be read in for detector (.geom).
  output_file = None
    .type = path
    .help = output file path
  detector_distance = None
    .type = float
    .help = Detector distance
  wavelength = None
    .type = float
    .help = If not provided, try to find wavelength in unassembled file.
  trusted_range = None
    .type = floats(size=2)
    .help = Set the trusted range
  raw = False
    .type = bool
    .help = Whether the data being analyzed is raw data from the JF4M or has \
            been corrected and padded.
  nexus_details {
    instrument_name = PAL
      .type = str
      .help = Name of instrument
    instrument_short_name = PAL
      .type = str
      .help = short name for instrument, perhaps the acronym
    source_name = PAL
      .type = str
      .help = Name of the neutron or x-ray storage ring/facility
    source_short_name = PAL
      .type = str
      .help = short name for source, perhaps the acronym
    start_time = None
      .type = str
      .help = ISO 8601 time/date of the first data point collected in UTC, \
              using the Z suffix to avoid confusion with local time
    end_time = None
      .type = str
      .help = ISO 8601 time/date of the last data point collected in UTC, \
              using the Z suffix to avoid confusion with local time. \
              This field should only be filled when the value is accurately \
              observed. If the data collection aborts or otherwise prevents \
              accurate recording of the end_time, this field should be omitted
    end_time_estimated = None
      .type = str
      .help = ISO 8601 time/date of the last data point collected in UTC, \
              using the Z suffix to avoid confusion with local time. \
              This field may be filled with a value estimated before an \
              observed value is avilable.
    sample_name = None
      .type = str
      .help = Descriptive name of sample
    total_flux = None
      .type = float
      .help = flux incident on beam plane in photons per second
  }
""")


'''

This script creates a master nexus file by taking as input a) an hdf5 file and b) a .geom file.
The hdf5 file is generated by the JF4M after processing the raw images and applying the appropriate gain corrections.
The assumed parameters for the detector are hard-coded in create_nexus_master_file and should be changed
there if they are modified in the future.

'''


class jf4m_geom2nexus(object):
  '''
  Builds a NeXus (NXmx) master file for JUNGFRAU 4M data.

  The phil parameters name an unassembled hdf5 data file and a .geom geometry
  file. The master file links to the image data in the unassembled file and
  stores the detector geometry as a hierarchy of NXtransformations axes.
  '''

  def __init__(self, args):
    '''
    Parse the phil arguments and read the detector geometry.

    args: list of command line arguments, each either the path of a phil file
          or a phil assignment such as geom_file=detector.geom

    Raises Sorry if an argument cannot be parsed or if geom_file is not set.
    '''
    self.params_from_phil(args)
    if self.params.geom_file is None:
      raise Sorry("No geometry file provided. Please set geom_file.")
    self.hierarchy = read_geom(self.params.geom_file)

  def params_from_phil(self, args):
    '''
    Merge the user supplied phil sources into phil_scope and store the
    extracted parameters as self.params.

    args: list of strings. An argument naming an existing file is parsed as a
          phil file, anything else is parsed as a phil string.

    Raises Sorry if an argument is neither a file nor parseable phil.
    '''
    user_phil = []
    for arg in args:
      if os.path.isfile(arg):
        user_phil.append(parse(file_name=arg))
      else:
        try:
          user_phil.append(parse(arg))
        except Exception:
          raise Sorry("Unrecognized argument: %s"%arg)
    self.params = phil_scope.fetch(sources=user_phil).extract()

  def _create_scalar(self, handle,path,dtype,value):
    '''
    Create a scalar (shape ()) dataset named path under handle and store
    value in it using the given dtype.
    '''
    dataset = handle.create_dataset(path, (),dtype=dtype)
    dataset[()] = value

  def create_vector(self,handle, name, value, **attributes):
    '''
    Create a one element float dataset named name under handle holding value,
    then attach every keyword argument to it as an hdf5 attribute.

    This is how the NXtransformations axes are written: value is the
    magnitude and the attributes (depends_on, vector, offset, units, ...)
    describe the axis.
    '''
    handle.create_dataset(name, (1,), data = [value], dtype='f')
    for key,attribute in six.iteritems(attributes):
      handle[name].attrs[key] = attribute

  @staticmethod
  def _asic_data_extent(asic):
    '''
    Return (data_origin, data_size) for one ASIC, each as a [slow, fast] list
    of ints, computed from its min_ss/max_ss/min_fs/max_fs entries.

    CrystFEL geometry files give min_* and max_* as inclusive pixel indices,
    so the number of pixels along an axis is max - min + 1.
    '''
    min_ss, max_ss = int(asic['min_ss']), int(asic['max_ss'])
    min_fs, max_fs = int(asic['min_fs']), int(asic['max_fs'])
    origin = [min_ss, min_fs]
    size = [max_ss - min_ss + 1, max_fs - min_fs + 1]
    return origin, size

  def create_nexus_master_file(self):

    '''
    Write the master nexus file.

    The output path is params.output_file or, if that is not set, the
    unassembled file name with its extension replaced by '_master.h5'.

    Hierarchical structure of master nexus file. Format information available here
    http://download.nexusformat.org/sphinx/classes/base_classes/NXdetector_module.html#nxdetector-module
    --> entry
      --> data
      --> definition (leaf)
      --> instrument
      --> sample

    Raises Sorry if unassembled_file is not set or the geometry has no panels,
    and NotImplementedError if params.raw is set.
    '''
    if self.params.unassembled_file is None:
      raise Sorry("No data file provided. Please set unassembled_file.")
    # Fail before touching any file: raw data is not supported
    if self.params.raw:
      raise NotImplementedError('Cannot do raw data yet')

    output_file_name = self.params.output_file
    if output_file_name is None:
      output_file_name = os.path.splitext(self.params.unassembled_file)[0]+'_master.h5'
    details = self.params.nexus_details

    # The input is opened first so that an unreadable data file does not leave
    # an empty master file behind. Both files are closed on any exit path.
    with h5py.File(self.params.unassembled_file, 'r') as uf, \
         h5py.File(output_file_name, 'w') as f:
      f.attrs['NX_class'] = 'NXroot'
      f.attrs['file_name'] = os.path.basename(output_file_name)
      f.attrs['file_time'] = datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
      f.attrs['HDF5_Version'] = h5py.version.hdf5_version
      entry = f.create_group('entry')
      entry.attrs['NX_class'] = 'NXentry'
      if details.start_time: entry['start_time'] = details.start_time
      if details.end_time: entry['end_time'] = details.end_time
      if details.end_time_estimated: entry['end_time_estimated'] = details.end_time_estimated

      # All per-run datasets in the unassembled file live under a group named
      # after the run number, e.g. R0042
      run = "R%04d"%uf['header/RunNumber'][()]

      # --> definition
      # np.bytes_ rather than np.string_: the latter was removed in NumPy 2.0
      self._create_scalar(entry, 'definition', 'S4', np.bytes_(b'NXmx'))
      # --> data
      data_group = entry.create_group('data')
      data_group.attrs['NX_class'] = 'NXdata'
      data_group['data'] = h5py.ExternalLink(self.params.unassembled_file, "%s/scan_dat/JF4M01_A_data"%run)
      # --> sample
      sample = entry.create_group('sample')
      sample.attrs['NX_class'] = 'NXsample'
      if details.sample_name: sample['name'] = details.sample_name
      sample['depends_on'] = '.' # This script does not support scans/gonios
      # --> source
      source = entry.create_group('source')
      source.attrs['NX_class'] = 'NXsource'
      source['name'] = details.source_name
      source['name'].attrs['short_name'] = details.source_short_name
      # --> instrument
      instrument = entry.create_group('instrument')
      instrument.attrs['NX_class'] = 'NXinstrument'
      instrument["name"] = details.instrument_name
      instrument["name"].attrs["short_name"] = details.instrument_short_name
      beam = instrument.create_group('beam')
      beam.attrs['NX_class'] = 'NXbeam'
      if details.total_flux:
        self._create_scalar(beam, 'total_flux', 'f', details.total_flux)
        beam['total_flux'].attrs['units'] = 'Hz'
      if self.params.wavelength is None:
        # Need to use a virtual dataset here because we need to set the units attribute
        # on it and we can't do that on an ExternalLink
        wavelengths_path = '%s/scan_dat/photon_wavelength'%run
        wavelengths_shape = uf[wavelengths_path].shape
        layout = h5py.VirtualLayout(shape = wavelengths_shape, dtype='f')
        vsource = h5py.VirtualSource(self.params.unassembled_file, wavelengths_path, wavelengths_shape)
        layout[:] = vsource
        beam.create_virtual_dataset('incident_wavelength', layout)
      else:
        beam.create_dataset('incident_wavelength', (1,), data=self.params.wavelength, dtype='f8')
      beam['incident_wavelength'].attrs['units'] = 'angstrom'
      pal_group = instrument.create_group('PAL')
      pal_group.attrs['NX_class'] = 'NXdetector_group'
      pal_group.create_dataset('group_index', data = list(range(1,3)), dtype='i')
      group_names = [np.bytes_(b'PAL'),np.bytes_(b'JF4M')]
      pal_group.create_dataset('group_names',(2,), data=group_names, dtype='S12')
      pal_group.create_dataset('group_parent',(2,), data=[-1,1], dtype='i')
      pal_group.create_dataset('group_type', (2,), data=[1,2], dtype='i')
      # Hard-coded description of the detector; change here if the detector
      # or its operating mode changes
      detector = instrument.create_group('JF4M')
      detector.attrs['NX_class']  = 'NXdetector'
      detector['description'] = 'JUNGFRAU 4M'
      detector['depends_on'] = '/entry/instrument/JF4M/transformations/AXIS_RAIL'
      detector['gain_setting'] = 'auto'
      detector['sensor_material'] = 'Si'
      self._create_scalar(detector, 'sensor_thickness', 'f', 320.)
      self._create_scalar(detector, 'bit_depth_readout', 'i', 16)
      self._create_scalar(detector, 'count_time', 'f', 10)
      self._create_scalar(detector, 'frame_time', 'f', 40)
      detector['sensor_thickness'].attrs['units'] = 'microns'
      detector['count_time'].attrs['units'] = 'us'
      detector['frame_time'].attrs['units'] = 'us'
      transformations = detector.create_group('transformations')
      transformations.attrs['NX_class'] = 'NXtransformations'
      # Create AXIS leaves for RAIL, D0 and different hierarchical levels of detector
      if self.params.detector_distance is None:
        distance = uf['%s/header/detector_0_distance'%run][()]
      else:
        distance = self.params.detector_distance
      self.create_vector(transformations, 'AXIS_RAIL', distance, depends_on='.', equipment='detector', equipment_component='detector_arm',transformation_type='translation', units='mm', vector=(0., 0., 1.), offset=(0.,0.,0.))
      # NOTE(review): the module and asic offsets below pass local_origin.elems,
      # this one passes local_origin itself -- confirm that is intended
      self.create_vector(transformations, 'AXIS_D0', 0.0, depends_on='AXIS_RAIL', equipment='detector', equipment_component='detector_arm',transformation_type='rotation', units='degrees', vector=(0., 0., -1.), offset=self.hierarchy.local_origin, offset_units = 'mm')

      # A single pixel size is written for the whole detector, so every panel
      # in the geometry has to agree on it
      panels = []
      for _, module in six.iteritems(self.hierarchy):
        panels.extend([module[key] for key in module])
      if not panels:
        raise Sorry("No panels found in geometry file %s"%self.params.geom_file)
      pixel_size = panels[0]['pixel_size']
      assert all(panel['pixel_size'] == pixel_size for panel in panels), \
        "All panels must have the same pixel size"

      if self.params.trusted_range is not None:
        underload, overload = self.params.trusted_range
        detector.create_dataset('underload_value', (1,), data=[underload], dtype='int32')
        detector.create_dataset('saturation_value', (1,), data=[overload], dtype='int32')

      detector['data'] = h5py.SoftLink('/entry/data/data')

      for m_key in sorted(self.hierarchy.keys()):
        # Module keys carry a leading 'c' before the module number
        module_num = int(m_key.lstrip('c'))
        m_name = 'AXIS_D0M%d'%module_num
        module_vector = self.hierarchy[m_key].local_origin
        self.create_vector(transformations, m_name, 0.0, depends_on='AXIS_D0', equipment='detector', equipment_component='detector_module',transformation_type='rotation', units='degrees', vector=(0., 0., -1.), offset = module_vector.elems, offset_units = 'mm')
        for a_key in sorted(self.hierarchy[m_key].keys()):
          # Asic keys are the module key followed by 'a' and the asic number
          asic_num = int(a_key.lstrip(m_key).lstrip('a'))
          a_name = 'AXIS_D0M%dA%d'%(module_num, asic_num)
          asic = self.hierarchy[m_key][a_key]
          asic_vector = asic['local_origin']
          fast = asic['local_fast']
          slow = asic['local_slow']

          self.create_vector(transformations, a_name, 0.0, depends_on=m_name, equipment='detector', equipment_component='detector_asic',
                             transformation_type='rotation', units='degrees', vector=(0., 0., -1.), offset = asic_vector.elems, offset_units = 'mm')

          asicmodule = detector.create_group('ARRAY_D0M%dA%d'%(module_num,asic_num))
          asicmodule.attrs['NX_class'] = 'NXdetector_module'
          data_origin, data_size = self._asic_data_extent(asic)
          asicmodule.create_dataset('data_origin', (2,), dtype='i', data=data_origin)
          asicmodule.create_dataset('data_size', (2,), dtype='i', data=data_size)

          self.create_vector(asicmodule, 'fast_pixel_direction',pixel_size,
                             depends_on=transformations.name+'/'+a_name,
                             transformation_type='translation', units='mm', vector=fast.elems, offset=(0. ,0., 0.))
          self.create_vector(asicmodule, 'slow_pixel_direction',pixel_size,
                             depends_on=transformations.name+'/'+a_name,
                             transformation_type='translation', units='mm', vector=slow.elems, offset=(0., 0., 0.))

if __name__ == '__main__':
  import sys
  # Every command line argument is handed to the converter, which treats it
  # as either a phil file or a phil assignment
  jf4m_geom2nexus(sys.argv[1:]).create_nexus_master_file()
