#!/usr/bin/env python

"""Various utilities to manipulate N-body data and files containing such data"""

#print 'In Nbody_Utils.py\n'

from numpy import *
from os import environ as osenv
import sys, os.path, tempfile, subprocess
from pylab import load as pylab_load
from pylab import save as pylab_save

# Neighbour count handed to hackdens when Nbody2Spherical estimates densities.
dens_n_neib = 6

# Root of the NEMO installation; run_nemo_pipe prefixes every program name
# with nemo_bin_path.
nemo_path = os.path.expanduser("~") + "/AMAS/Nemo"
nemo_bin_path, nemo_obj_path = nemo_path + "/bin", nemo_path + "/obj"

# Exported into this process's environment, hence inherited by every NEMO
# subprocess started from this module (presumably read by NEMO for its
# object / bodytrans files -- confirm against the NEMO docs).
osenv.update({"NEMOOBJ": nemo_obj_path,
              "BTRPATH": nemo_obj_path + "/bodytrans"})

#----------------------------------------------------------------------
def merge_same_pos(m,x,y,z,vx,vy,vz,eps_merg=0.0,verb=0) :
#----------------------------------------------------------------------
    """
    Merge particles that are at the same position or very close to each other
    (within a distance eps_merg).

    Particles are sorted by distance to the origin and each one is compared
    only with its successor in that order. A run of successive particles whose
    consecutive separations are all <= eps_merg is merged into one particle.
    Two coincident particles that are separated in the sorted order by a third
    particle at (nearly) the same radius but elsewhere are therefore NOT merged.

    A merged particle carries the total mass of its group and the
    mass-weighted (centre-of-mass) position and velocity of its members, so
    mass, centre of mass and momentum are conserved. If a group has zero total
    mass, an unweighted mean is used instead.

    Parameters
    ----------
    m, x, y, z, vx, vy, vz : 1-D sequences of equal length
        Masses, positions and velocities. The caller's arrays are not modified.
    eps_merg : float
        Merging distance, in the units of x, y, z. With the default 0.0 only
        particles at exactly the same position are merged.
    verb : int
        If > 0, progress messages are written to stderr.

    Returns
    -------
    (m, x, y, z, vx, vy, vz, iorigin)
        The merged particle arrays, in order of increasing distance to the
        origin, and iorigin: one sorted list of 1-based input indices per
        output particle. For instance, if iorigin[10] is [2,7,34], output
        particle 10 is the merger of input particles 2, 7 and 34 (counting
        from 1).
    """
    (m,x,y,z,vx,vy,vz) = [asarray(a) for a in (m,x,y,z,vx,vy,vz)]
    Npart = len(m)
    iPart = arange(1, Npart+1)   # 1-based ID of every input particle

    # Sort by distance to the origin so that coincident particles are neighbours
    if verb > 0 :
        sys.stderr.write('>> Sorting particles ...')
    isort = ((x*x+y*y+z*z)**0.5).argsort()
    # Fancy indexing returns copies: the in-place compaction below never
    # touches the caller's arrays.
    (m,x,y,z,vx,vy,vz,iPart) = [a[isort] for a in (m,x,y,z,vx,vy,vz,iPart)]
    if verb > 0 :
        sys.stderr.write(' done.\n')

    # Squared distance from each particle to the next one (the last has none)
    d2_to_next = concatenate(((x[1:]-x[:-1])**2 + (y[1:]-y[:-1])**2 +
                              (z[1:]-z[:-1])**2, [inf]))
    # d2_to_next holds SQUARED distances, so compare with a squared threshold
    eps2_merg = eps_merg*eps_merg

    if verb > 0 :
        sys.stderr.write('>> Checking for close neighbours and merging them ...')

    pos_vel = (x,y,z,vx,vy,vz)
    iorigin = []    # for each output particle, the IDs of the input particles
    Npost_mrg = 0   # output particles produced so far = next write slot
    ipre_mrg_dn = 0
    while ipre_mrg_dn < Npart :
        # Grow the group [ipre_mrg_dn, ipre_mrg_up] while the next particle is close
        ipre_mrg_up = ipre_mrg_dn
        while ipre_mrg_up < Npart-1 and d2_to_next[ipre_mrg_up] <= eps2_merg :
            ipre_mrg_up += 1
        grp = slice(ipre_mrg_dn, ipre_mrg_up+1)

        if ipre_mrg_up == ipre_mrg_dn :
            # Isolated particle: move it down to its output slot
            m[Npost_mrg] = m[ipre_mrg_dn]
            for a in pos_vel :
                a[Npost_mrg] = a[ipre_mrg_dn]
            iorigin.append([int(iPart[ipre_mrg_dn])])
        else :
            # Every merged value is computed BEFORE any write, because the
            # output slot Npost_mrg may lie inside the group being read.
            m_grp = m[grp]
            mtot = m_grp.sum()
            if mtot != 0 :
                weight = m_grp/float(mtot)
            else :
                weight = ones(len(m_grp))/float(len(m_grp))
            merged = [(weight*a[grp]).sum() for a in pos_vel]
            m[Npost_mrg] = mtot
            for a, val in zip(pos_vel, merged) :
                a[Npost_mrg] = val
            iorigin.append(sorted([int(i) for i in iPart[grp]]))
        Npost_mrg += 1
        ipre_mrg_dn = ipre_mrg_up + 1

    if verb > 0 :
        sys.stderr.write(' done. %i particles left\n' % Npost_mrg)

    return ( m[0:Npost_mrg],x[0:Npost_mrg],y[0:Npost_mrg],z[0:Npost_mrg],
             vx[0:Npost_mrg],vy[0:Npost_mrg],vz[0:Npost_mrg], iorigin )


#----------------------------------------------------------------------
def run_nemo_pipe(pipe_str, infile=None, outfile=None, verb=0) :
#----------------------------------------------------------------------
    """
    Run a shell-like pipeline of NEMO programs without going through a shell.

    pipe_str : commands separated by '|'. Each program name is prefixed with
        nemo_bin_path + '/'. Arguments are split on whitespace; no shell
        quoting is interpreted. Empty segments are ignored.
    infile, outfile : anything subprocess.Popen accepts as stdin/stdout
        (typically open file objects). They are connected to the first
        command's stdin and the last command's stdout. None means inherit
        from this process.
    verb : if > 0, the pipeline, the files and each command are echoed on
        stderr.

    Waits for every command and returns the list of their exit codes in
    pipeline order (an empty list if pipe_str contains no command).
    Raises OSError if a program cannot be started; already started commands
    are reaped first.
    """
    def bla(command) :
        if verb>0 :
            sys.stderr.write('>> run_nemo_pipe: running %s\n' % command)

    if verb > 0 :
        sys.stderr.write('>> command : ' + pipe_str + '\n')
        sys.stderr.write('>> files : %s %s \n' % (infile,outfile))

    # A list (not a lazy map) so that len() also works on Python 3
    commands = [nemo_bin_path + '/' + c.strip()
                for c in pipe_str.split('|') if c.strip()]

    procs = []
    stdin = infile
    try :
        for i, command in enumerate(commands) :
            if i == len(commands)-1 :
                stdout = outfile
            else :
                stdout = subprocess.PIPE
            bla(command)
            p = subprocess.Popen(command.split(), stdin=stdin, stdout=stdout)
            if procs :
                # p now holds the read end of the previous pipe. Closing our
                # copy lets the upstream command get SIGPIPE if p exits early.
                procs[-1].stdout.close()
            procs.append(p)
            stdin = p.stdout
    except OSError :
        # A program could not be started: don't leave the others orphaned
        for p in procs :
            if p.stdout is not None :
                p.stdout.close()
            p.wait()
        raise

    return [p.wait() for p in procs]

#----------------------------------------------------------------------
def write_arrays(file_name,arrays,fmt='%12.5e') :
#----------------------------------------------------------------------
    """
    Write a collection of equal-length sequences to an ASCII file, one
    sequence per column.

    Line i of the file holds arrays[0][i], arrays[1][i], ..., each formatted
    with `fmt` and followed by a single space, and ends with a newline.
    The number of lines is len(arrays[0]). An existing file is overwritten.
    An empty `arrays` produces an empty file. The file is closed even if
    formatting fails.
    """
    item_fmt = fmt + ' '
    if len(arrays) > 0 :
        n_elem = len(arrays[0])
    else :
        n_elem = 0

    f = open(file_name, mode="w")
    try :
        for i in range(n_elem) :
            # Build the whole line in memory, then write it in one call
            f.write(''.join([item_fmt % col[i] for col in arrays]) + '\n')
    finally :
        f.close()
    
#----------------------------------------------------------------------
def nemo_hackdens(m,x,y,z,neib=6,sort=None,verb=0) :
#----------------------------------------------------------------------

    """Wrapper around Nemo's hackdens.
    Local density estimator using tree algorithm.

    The particles are written to a temporary ASCII file and piped through
    tabtos -> [snapcenter -> snapsort] -> hackdens -> snapprint with
    run_nemo_pipe. The first column of snapprint's output is then read back.
    Both temporary files are removed even if a step fails.

    m, x, y, z : masses and positions (sequences of equal length).
    neib       : value passed to hackdens as neib=.
    sort       : None (default) to keep the input order, or "r"/"R" to pass
                 the snapshot through snapcenter and snapsort
                 (rank=x*x+y*y+z*z) before hackdens.
                 NOTE(review): snapsort presumably reorders the particles;
                 in that case the returned values no longer follow the input
                 order -- confirm.
                 Any other string writes an error to stderr and calls
                 sys.exit(1) before any temporary file is created.
    verb       : > 1 sets hackdens verbose=t. > 0 makes run_nemo_pipe echo
                 the commands it runs.

    Returns a 1-D numpy array with one value per particle. It is the phi
    column printed by snapprint, where hackdens stored its result
    (write_at_phi=t), presumably the local density.

    Raises RuntimeError if the pipe did not produce one value per particle
    (e.g. a NEMO program failed).
    """
    verb_nemo = "f"
    if verb > 1 :
        verb_nemo = "t"

    Npart = len(m)

    # Build the whole NEMO command first, so an invalid `sort` is rejected
    # before anything is written to disk.

    # Command to convert to nemo format
    nemo_command = 'tabtos in=- out=- block1=mass,pos nbody=%s' % Npart

    # Command to sort (if required)
    if isinstance(sort, str) :
        if sort in ("r", "R") :
            nemo_command = nemo_command + \
                " | snapcenter  in=- out=- | snapsort  in=- out=- rank=x*x+y*y+z*z"
        else :
            sys.stderr.write("!!! nemo_hackdens : don't know how to sort according to %s !!!\n" % sort )
            sys.exit(1)

    # Command to compute density
    nemo_command = nemo_command  + " | hackdens   in=- out=- write_at_phi=t neib=%s verbose=%s" % (neib, verb_nemo)

    # Command to convert back to ascii
    nemo_command = nemo_command  + r" | snapprint in=- options=phi format=%12.5e"

    (fhandle,name_ascii_in_file) = tempfile.mkstemp()
    os.close(fhandle)
    name_ascii_out_file = None
    try :
        # Write data into the (temp) ascii input file
        write_arrays(name_ascii_in_file,(m,x,y,z),fmt='%17.10e')

        (fhandle,name_ascii_out_file) = tempfile.mkstemp()
        os.close(fhandle)

        # Run the whole pipe and cross fingers
        ascii_in_file = open(name_ascii_in_file,mode="r")
        ascii_out_file = None
        try :
            ascii_out_file = open(name_ascii_out_file,mode="w")
            run_nemo_pipe(nemo_command,ascii_in_file,ascii_out_file,verb=verb)
        finally :
            ascii_in_file.close()
            if ascii_out_file is not None :
                ascii_out_file.close()

        # Read data back from the ascii output file. loadtxt squeezes a
        # single value to a 0-d array; atleast_1d keeps the result 1-D.
        dens = atleast_1d(loadtxt(name_ascii_out_file, usecols=[0], unpack=True))
    finally :
        os.remove(name_ascii_in_file)
        if name_ascii_out_file is not None :
            os.remove(name_ascii_out_file)

    if len(dens) != Npart :
        raise RuntimeError("nemo_hackdens : expected %i values from the NEMO pipe, got %i"
                           % (Npart, len(dens)))

    # return density array
    return dens


#----------------------------------------------------------------------
def Nbody2Spherical(m,x,y,z,vx,vy,vz,R_centre=None,V_centre=None,expo_dens=1.0,verb=0) :
#----------------------------------------------------------------------
    """
    Convert Cartesian N-body coordinates to spherical quantities about a centre.

    Computes, relative to (R_centre, V_centre):
      R    : distance to the centre,
      vrad : radial velocity (signed, > 0 when moving away from the centre,
             0 for particles sitting on the centre),
      vtan : tangential speed (always >= 0).

    If both R_centre and V_centre are given (3-component sequences or
    arrays), returns (R, vrad, vtan).

    Otherwise the missing centre(s) are computed as density-weighted means of
    the positions / velocities (see von Hoerner 1963, Hut & Casertano 1985).
    The weights are dens**expo_dens, with dens from nemo_hackdens using
    dens_n_neib neighbours. A centre that was given is kept as-is. The
    function then returns ((R, vrad, vtan), (R_centre, V_centre)).

    m, x, y, z, vx, vy, vz : masses, positions and velocities (arrays).
    verb : passed on to nemo_hackdens.
    """
    tiny=1e-20
    # `is None` rather than `== None`: a numpy array compared with == None
    # gives an element-wise array, whose truth value is ambiguous.
    centre_given = (R_centre is not None) and (V_centre is not None)
    if not centre_given :
        dens = nemo_hackdens(m,x,y,z,neib=dens_n_neib,verb=verb)

        # We use the density-averaged position (and velocity) as effective centre
        # See von Hoerner 1963, Hut & Casertano 1985
        if expo_dens != 1.0 :
            weight = dens**expo_dens
        else :
            weight = dens

        norm = 1.0/weight.sum()
        if R_centre is None :
            R_centre = norm*array([(weight* x).sum(),(weight* y).sum(),(weight* z).sum()])
        if V_centre is None :
            V_centre = norm*array([(weight*vx).sum(),(weight*vy).sum(),(weight*vz).sum()])

    x  =  x - R_centre[0]
    y  =  y - R_centre[1]
    z  =  z - R_centre[2]
    vx = vx - V_centre[0]
    vy = vy - V_centre[1]
    vz = vz - V_centre[2]

    R = sqrt(x*x+y*y+z*z)
    V2 = vx*vx+vy*vy+vz*vz

    # Particles on the centre have no radial direction: give them vrad = 0.
    # where() evaluates both branches, so divide by a safe R to avoid 0/0.
    off_centre = R > tiny
    R_safe = where(off_centre, R, 1.0)
    vrad = where(off_centre, (x*vx+y*vy+z*vz)/R_safe, 0.0)
    # Round-off can make V2 - vrad**2 slightly negative; clip to avoid NaN
    vtan = sqrt(maximum(V2-vrad*vrad, 0.0))

    if centre_given :
        return (R, vrad, vtan)
    else :
        return ((R, vrad, vtan), (R_centre, V_centre))
    
if __name__=="__main__" :
    # Function-call form: identical output on Python 2, and unlike the
    # `print 'HELLO'` statement it also parses on Python 3.
    print('HELLO')
    
