from __future__ import absolute_import, division, print_function
from six.moves import range
pdb1o3u_dr = [0.0061325298784202153, 0.014879413715546607, 0.017457194557825101, 0.018347391914744521, 0.018398855219007079, 0.019202048799521983, 0.019517965718312932, 0.019562016426123482, 0.021047905154718643, 0.021671092208930164, 0.022904601288927438, 0.023111327339836105, 0.023439296198896946, 0.023782588884198611, 0.024360809492051883, 0.024515929405888277, 0.025274728388404489, 0.026095051574393579, 0.026837723677846946, 0.026917759438200881, 0.027349736270477133, 0.028351625791702958, 0.028528074238257928, 0.030577840717824302, 0.03100089275285316, 0.03112568087237404, 0.031196276876889948, 0.032134801342764877, 0.032287519340335968, 0.033113837731673708, 0.034930340777672486, 0.034975886476864164, 0.035510572565445493, 0.035782851028263528, 0.036108670669189075, 0.037486391145234693, 0.037547692967114507, 0.038138641134861782, 0.03872782876077327, 0.039360484863600227, 0.039392956660847744, 0.039455515940405715, 0.039781340664745701, 0.039938376505919616, 0.040100076091755511, 0.040972595249253431, 0.041595718450016216, 0.043045650559721503, 0.043067786492559622, 0.044158755723266538, 0.044365949459101738, 0.044761432129686696, 0.044875549224036214, 0.045151115013577436, 0.045553422325070245, 0.046561947659451541, 0.047586479988876579, 0.047736310212856022, 0.04826721968382771, 0.048742433776767104, 0.048776209457733317, 0.048889210699902784, 0.049179408268897516, 0.049254075015550812, 0.049747978073519786, 0.049851218522625043, 0.050121003957410849, 0.052370390519363957, 0.052607188001786293, 0.053015180575850883, 0.053076746263787999, 0.053107003914104216, 0.053120740618883784, 0.053355109544567625, 0.05347100039410891, 0.053958175899696689, 0.055056745742942476, 0.05535265209920992, 0.055611797361610649, 0.055897102373536435, 0.057285637081057168, 0.05861290495276298, 0.058966726767379185, 0.05896707411020375, 0.059118830002659092, 0.060151738524048516, 0.061263676861175377, 0.062484285503780213, 0.063026146839991193, 0.063080620883205679, 
0.064076431743227372, 0.064520739855443876, 0.064693458332623049, 0.065436796068400041, 0.06630418395152228, 0.067053940499318274, 0.06741498540997197, 0.067556229914308583, 0.06766272650165267, 0.06834237769771867, 0.06854205246447688, 0.068818579020083856, 0.06899158844304025, 0.069096558635811142, 0.069495653724970935, 0.069696237823830623, 0.070465116470640313, 0.070997583025903963, 0.071526766022421015, 0.071939964958455199, 0.072184015317283315, 0.072484635696002539, 0.072621183167513828, 0.072647446438194749, 0.073082327027999869, 0.073263757528515855, 0.073998575317184595, 0.074134930438007396, 0.074320443516861567, 0.074851796923213601, 0.074860021988163994, 0.075295972295828892, 0.075815162319241416, 0.076328817847950675, 0.076863862369519684, 0.077184062694443972, 0.077387640284410392, 0.077388727427102857, 0.078316572059171058, 0.078556150787683812, 0.07904456839288844, 0.079588664991238886, 0.079594528560279321, 0.080459806553386831, 0.080529272315186542, 0.080620965861698576, 0.080734244488689988, 0.081518243573305019, 0.081651406035016083, 0.081699588116050381, 0.08193771621332567, 0.082357231258151628, 0.082461002361030517, 0.082797678394278554, 0.083219115369456517, 0.083512604279187136, 0.084602127297362287, 0.084676209486416201, 0.084748709650872805, 0.08510872157320673, 0.085262550098189546, 0.085435420755590893, 0.086032322538394268, 0.08655383328671569, 0.086889683177808236, 0.087447772217258218, 0.087700553990416874, 0.088051498299336617, 0.088433231854574232, 0.088641099812851454, 0.088893689307506588, 0.090186749735785984, 0.090212911859378186, 0.090287226163894896, 0.090310933303622207, 0.090496334369620596, 0.090742713947269196, 0.090749638333221236, 0.091054230497617586, 0.091653335822703977, 0.09175750868748847, 0.092008602703168546, 0.092161350886393253, 0.092290938021188026, 0.092678904783118271, 0.09277847008352666, 0.093032083806438612, 0.093360114960586701, 0.093608366534905479, 0.094271897183381276, 0.094312652235312738, 
0.094463487948746747, 0.094510868662994268, 0.094757009596033739, 0.094858155629457194, 0.095555305037797675, 0.095744739463440809, 0.09586623913758141, 0.09586757295157021, 0.096093491329116196, 0.096341625470585585, 0.0963818759959537, 0.096680856160791528, 0.096701606896714754, 0.097049331859364846, 0.097127150387850747, 0.097272787333873206, 0.097371053069005925, 0.097512880119137169, 0.097923664225673487, 0.097935507981030248, 0.10013700550345515, 0.10053554165235783, 0.10175168442236039, 0.10182349518849756, 0.10206277186398476, 0.10209228689867172, 0.10326540711891594, 0.10394675679116201, 0.10431761585030459, 0.10455783326212861, 0.10487753762127902, 0.10511350432930949, 0.10596546142816184, 0.10617026269990119, 0.10622989741174485, 0.1065210327334459, 0.1078172497582031, 0.10831569140983882, 0.10881388219187864, 0.10936528470186996, 0.10950344195935545, 0.10954260851171445, 0.11016138889328521, 0.11133134601984981, 0.11295731552665068, 0.11344612632766847, 0.11401400986286635, 0.11443466621313529, 0.11445184617692998, 0.11452139720652732, 0.11515184957406122, 0.11610828981284019, 0.1162649952387203, 0.11628569052244492, 0.1175366955228144, 0.11784807602321434, 0.11851650639603323, 0.11979168238827603, 0.1205379204722317, 0.12069997196103147, 0.1214473205594099, 0.12236952233978633, 0.12237918364507622, 0.12418876678805034, 0.12532625655206739, 0.12552593681565608, 0.12564586923813539, 0.12620903831041427, 0.12763521098691358, 0.12780071513211344, 0.12887291851414365, 0.12896700463830343, 0.12912804088389132, 0.13003906746536537, 0.13296011405540931, 0.13368447245242243, 0.13477101192228166, 0.13524785410428858, 0.13747647235566562, 0.14004875556781965, 0.14279278714994253, 0.14283729491890418, 0.1448699411048997, 0.14757804896475299, 0.1486198628089869, 0.15257339793955169, 0.15382095296350662, 0.1562340831616145, 0.15655366418276007, 0.15904854573662416, 0.16411997033733039, 0.16686950141907164, 0.16804090492827642, 0.17136050410641243, 
0.17209586332786453, 0.17669151165248831, 0.17747728936627932, 0.18384286105262251, 0.18421038573449164, 0.20103588821356649, 0.20321445476002511, 0.21101721288277717, 0.22285904159750272, 0.27168157452964675, 0.3367010537431549, 0.4139294542484368, 0.41467358635888341, 1.1185901341574491, 1.5236360562420754]
import math
from scitbx.array_family import flex
from scitbx import lbfgs,lbfgsb

class rayleigh(object):
  """  Code is from Billy Poon
  =============================================================================
  Class models a 1-d Rayleigh distribution using one parameter, sigma.

              x                x^2
    pdf = --------- exp(- ------------)
           sigma^2         2 sigma^2

                        x^2
    cdf = 1 - exp(- -----------)
                     2 sigma^2

  The derivative of the cdf with respect to sigma is,

      d(cdf)          x^2               x^2             x
    ---------- = - --------- exp( - -----------) = - ------- pdf
     d(sigma)       sigma^3          2 sigma^2        sigma

  Setting cdf = 1/2 gives the median, which is related to sigma (the mode) by

    median = sigma * sqrt(2 ln 2)  ~=  1.1774 sigma

  estimate_parameters_from_cdf reads self.x_data and self.y_data, which must
  be provided by the caller (e.g. a subclass); it assumes y_data increases
  along the arrays.
  -----------------------------------------------------------------------------
  """

  def set_parameters(self,p):
    """
    Set sigma from a one-element parameter sequence p.
    """
    assert(len(p) == 1)
    self.sigma = float(p[0])

  def _estimate_sigma_from_cdf(self):
    """
    Return an initial estimate of sigma from self.x_data / self.y_data.

    The first x whose cdf value exceeds 0.5 approximates the median (the
    last x is used when no cdf value exceeds 0.5).  Dividing the median by
    sqrt(2 ln 2) converts it into sigma, the mode of the distribution.

    Raises ValueError when there is no data to estimate from.
    """
    n = len(self.x_data)
    if (n == 0):
      raise ValueError("cannot estimate sigma from an empty data set")
    median_index = next(
      (i for i in range(n) if self.y_data[i] > 0.5), n - 1)
    # median = sigma*sqrt(2 ln 2) for a Rayleigh distribution
    return self.x_data[median_index] / math.sqrt(2.0*math.log(2.0))

  def estimate_parameters_from_cdf(self):
    """
    Function estimates the parameter values based on the data (cdf)

    Sets self.sigma to the estimate and returns it as a one-element
    flex.double, suitable as a starting point for the minimizers.
    """
    self.sigma = self._estimate_sigma_from_cdf()
    return flex.double([self.sigma])

  def pdf(self,x=None):
    """
    Function returns the probability density function at x
    (x/sigma^2) * exp(-x^2 / (2 sigma^2)) for the current self.sigma.
    """
    x_sigma = x/self.sigma
    f = (x_sigma/self.sigma)*math.exp(-0.5*x_sigma*x_sigma)
    return f

  def cdf(self,x=None):
    """
    Function returns the cumulative distribution function at x
    1 - exp(-x^2 / (2 sigma^2)) for the current self.sigma.
    """
    x_sigma = x/self.sigma
    f = 1.0 - math.exp(-0.5*x_sigma*x_sigma)
    return f

  def d_cdf_d_sigma(self,x=None):
    """
    Function returns the derivative of the cdf at x with respect to sigma,
    -(x/sigma) * pdf(x).
    """
    df = -(x/self.sigma)*self.pdf(x=x)
    return df

  def cdf_gradients(self,x=None):
    """
    Function returns a flex.double containing all derivatives of the cdf at
    x with respect to the parameters (only sigma for this distribution).
    """
    return flex.double([self.d_cdf_d_sigma(x)])

class fit_cdf(rayleigh):
  """
  =============================================================================
  Class fits a distribution according to its cumulative distribution function

  The refinement target is the least-squares residual
    f = sum_i (cdf(x_i) - y_i)^2
  with gradient
    df/dp = 2 sum_i (cdf(x_i) - y_i) d(cdf(x_i))/dp

  Arguments:
    x_data - measured property (list)
    y_data - fraction of points with measured property (cdf) (list)
    minimizer_type - "lbfgs", "lbfgsb", or None; with None no refinement is
                     done and the initial estimate from the data is kept

  Useful accessible attributes:
    self.x - the final parameters (list)
  -----------------------------------------------------------------------------
  """

  # Accepted values for the minimizer_type argument of __init__.
  _minimizer_types = (None, "lbfgs", "lbfgsb")

  def __init__(self,x_data=None,y_data=None,minimizer_type=None):
    # reject typos up front instead of silently skipping the refinement
    if (minimizer_type not in self._minimizer_types):
      raise ValueError("unknown minimizer_type %r; expected one of %r"
        % (minimizer_type, self._minimizer_types))

    # setup data
    assert(len(x_data) == len(y_data))
    self.n = len(x_data)
    self.x_data = flex.double(x_data)
    self.y_data = flex.double(y_data)

    # initialize distribution with guess
    self.x = self.estimate_parameters_from_cdf()

    # lbfgsb bounds: nbd=1 means "lower bound only", so sigma >= l[0] and
    # u[0] is not used.  The small positive lower bound keeps sigma away from
    # zero, where pdf/cdf divide by sigma.
    self.l = flex.double([1.E-8])
    self.u = flex.double([0.])
    self.nbd = flex.int([1])
    # NOTE(review): both counters start at -1 and are incremented at the
    # start of each call, so after k calls they read k-1.
    self.number_of_lbfgs_iterations = -1
    self.number_of_function_evaluations = -1

    # optimize parameters
    if minimizer_type=="lbfgs":
      self.minimizer = lbfgs.run(target_evaluator=self)
    elif minimizer_type=="lbfgsb":
      self.minimizer = self.run_lbfgsb()

    # set optimized parameters
    self.set_parameters(self.x)

  def run_lbfgsb(self, iprint=-1):
    """
    Drive the scitbx L-BFGS-B minimizer on self.x using the bounds set in
    __init__ and return the minimizer object.

    minimizer.process() is handed self.x; compute_functional_and_gradients
    re-reads self.x, so the minimizer presumably updates it in place.
    """
    n = self.x.size()
    minimizer = lbfgsb.minimizer(
      n=n,
      m=5,
      l=self.l,
      u=self.u,
      nbd=self.nbd,
      enable_stp_init=True,
      factr=1.0e+7,
      pgtol=1.0e-5,
      iprint=iprint)

    f,g = self.compute_functional_and_gradients()
    while True:
      if (minimizer.process(self.x, f, g)):
        f, g = self.compute_functional_and_gradients()
      elif (minimizer.requests_stp_init()):
        new_stp = self.adjust_stp(
          stp=minimizer.relative_step_length_line_search(),
          csd=minimizer.current_search_direction())
        minimizer.set_relative_step_length_line_search(value=new_stp)
      elif (minimizer.is_terminated()):
        self.callback_after_step_no_counting(suffix=" FINAL")
        break
      else:
        self.callback_after_step(minimizer=None)
    return minimizer

  def adjust_stp(self,stp,csd):
    """
    Return the (possibly reduced) initial line-search step length.

    NOTE(review): as written, max_stp evaluates to 0.9*stp whenever csd[0] is
    nonzero, so this simply damps the step by 10%.  Perhaps the shift was
    meant to be bounded relative to self.x[0]; confirm before relying on it.
    """
    assert stp > 0.  #tests my understanding of the lbfgsb implementation
    assert csd.size() == self.x.size() == 1
    if (csd[0] == 0):
      # zero search direction: nothing to limit (and avoid dividing by zero)
      return stp
    large_shift = 0.9*abs(stp*csd[0])
    max_stp = abs(large_shift/csd[0])
    return min(max_stp,stp)

  def callback_after_step(self, minimizer, suffix=""):
    """
    Count one completed iteration, then run the non-counting callback.
    """
    self.number_of_lbfgs_iterations += 1
    self.callback_after_step_no_counting(suffix=suffix)

  def callback_after_step_no_counting(self, suffix=""):
    """
    Hook called after each step and at termination; does nothing here.
    """
    return

  def compute_functional_and_gradients(self):
    """
    Return (f, g) for the current parameters self.x, where
      f = sum_i (cdf(x_i) - y_i)^2
      g = 2 sum_i (cdf(x_i) - y_i) d(cdf(x_i))/dp
    """
    self.number_of_function_evaluations += 1
    self.set_parameters(self.x)
    predicted = flex.double(self.n)
    for i in range(self.n):
      predicted[i] = self.cdf(x=self.x_data[i])
    difference = predicted - self.y_data

    # Least-squares target; it must agree with the gradient computed below
    # (a sum of |difference| would have a different gradient).
    f = flex.sum(difference*difference)
    gradients = flex.double(len(self.x))
    for i in range(self.n):
      g_i = self.cdf_gradients(x=self.x_data[i])
      for j in range(len(self.x)):
        gradients[j] = gradients[j] + difference[i]*g_i[j]
    gradients = 2.0*gradients
    print("cycle sigma", self.sigma)
    return f,gradients

if __name__=="__main__":
  # Demonstration: fit the same truncated empirical cdf with both minimizers.
  print("""Test program for LBFGS options.  Fit a function to the test data using two methods: lbfgs & lbfgsb.""")
  n_fitted_points = 159
  n_all_points = float(len(pdb1o3u_dr))
  xdata = flex.double(pdb1o3u_dr[:n_fitted_points])
  # empirical cdf: the i-th point has rank i within the FULL data set
  ydata = flex.double(
    [float(rank)/n_all_points for rank in range(len(xdata))])
  print("The dataset consists of %d (x,y) pairs to be fitted."%xdata.size())
  print("The functional has a single variable parameter to be optimized.  Expected value is about 0.072277")
  for method, trying_message, answer_message in (
      ("lbfgs", "Trying lbfgs...", "Answer from lbfgs..."),
      ("lbfgsb", "Trying lbfgsb...", "Answer from lbfgsb...")):
    print(trying_message)
    sigma = fit_cdf(xdata,ydata,method).sigma
    print(answer_message,sigma)
