import sys
sys.path.append("..")
import emoo
import numpy as np
import pylab as pl
from neuron  import h

# Spike-count targets for the optimization: no spike before the stimulus
# (target_n0) and exactly three spikes after it (target_n1).
target_n0, target_n1 = 0, 3

# Load NEURON's standard run system (stdrun.hoc); it provides h.run(),
# h.init(), h.tstop and h.v_init, which this script uses.
h.load_file("stdrun.hoc")

# Switch on CVODE, NEURON's variable-time-step integrator. Because the step
# size varies, time is recorded into its own vector (vec_t) next to voltage.
h("""
    cvode.active(1)
""")

# Simulate for 450 ms
h.tstop = 450

# Create the soma as a single section with the given length and diameter
soma = h.Section()
soma.L = 3.18
soma.diam = 10

# Insert the standard Hodgkin-Huxley mechanism (Na, K and leak channels);
# its gnabar_hh, gkbar_hh and gl_hh are the conductances being optimized
soma.insert('hh')

# Put a current clamp (IClamp) at the middle of the soma
stim = h.IClamp(0.5, sec=soma)
stim.delay = 100 # Stimulus start (ms)
stim.dur = 10 # Stimulus duration (ms)
stim.amp = 0.01 # Amplitude of the injected current (NEURON's IClamp uses nA)

# Create the recording vectors for time and voltage; they are refilled on
# every simulation run
vec_t = h.Vector()
vec_v = h.Vector()
vec_t.record(h._ref_t) # Time
vec_v.record(soma(0.5)._ref_v) # Voltage at the middle of the soma
          
# Search space for emoo: one [name, lower bound, upper bound] entry per free
# parameter. The names are the keys func_to_optimize reads ('na', 'k', 'leak').
variables = [
    ["na", 0, 2],
    ["k", 0, 0.3],
    ["leak", 3e-6, 3e-5],
]

# Objective names; func_to_optimize returns a dict with exactly these keys.
objectives = ["spikes_error1", "spikes_error2"]

# This is the function which is going to be minimized
def func_to_optimize(parameters):
    """Simulate the soma with candidate conductances and score its spiking.

    parameters: mapping with keys 'na', 'k' and 'leak', written to the soma's
        gnabar_hh, gkbar_hh and gl_hh respectively.

    Returns a dict with the two objectives (both to be minimized):
        "spikes_error1": (target_n0 - n0)**2, n0 = spikes up to stimulus onset
        "spikes_error2": (target_n1 - n1)**2, n1 = spikes after stimulus onset
    """
    # Send the parameters to NEURON
    soma.gnabar_hh = parameters['na']
    soma.gkbar_hh = parameters['k']
    soma.gl_hh = parameters['leak']

    # Run the simulation; vec_t and vec_v (set up with record()) then hold
    # the resulting time and voltage traces.
    h.v_init = -65
    h.init()
    h.run()

    t = np.array(vec_t)
    v = np.array(vec_v)

    # A spike is an upward crossing of the threshold: a sample at or below it
    # followed by one above it. Counting crossings (instead of peaks) registers
    # every spike exactly once, even when only a single variable-step sample
    # lies above threshold or the trace has several maxima above threshold.
    threshold = 20.0  # mV
    above = v > threshold
    crossings = np.flatnonzero(~above[:-1] & above[1:]) + 1
    spike_times = t[crossings]

    # Split at the stimulus onset so that every spike lands in exactly one bin.
    onset = stim.delay
    n0 = int(np.count_nonzero(spike_times <= onset))  # up to stimulus onset
    n1 = int(np.count_nonzero(spike_times > onset))   # after stimulus onset

    # return the two errors
    return {"spikes_error1": (target_n0 - n0) ** 2,
            "spikes_error2": (target_n1 - n1) ** 2}

# After each generation this function is called
def checkpopulation(population, columns, gen):
    
    # we want both spike errors to be minimal at the end
    i = np.argmin(population[:, columns["spikes_error1"]]**2 + 
                  population[:, columns["spikes_error2"]]**2)
    
    print "Generation %d:"%gen, \
        "Best spikes_error1; ", population[i, columns["spikes_error1"]],\
        "Best spikes_error2", population[i, columns["spikes_error2"]]
    
    
# Initiate the Evlutionary Multiobjective Optimization
emoo = emoo.Emoo(N = 100, C = 200, variables = variables, objectives = objectives)
# Parameters:
# N: size of population
# C: size of capacity 

emoo.setup(eta_m_0 = 20, eta_c_0 = 20, p_m = 0.5)
# Parameters:
# eta_m_0, eta_c_0: defines the initial strength of the mution and crossover parameter (large values mean weak effect)
# p_m: probabily of mutation of a parameter (holds for each parameter independently)

emoo.get_objectives_error = func_to_optimize
emoo.checkpopulation = checkpopulation

emoo.evolution(generations = 20)

# Look at the Result!
# this should only be done by the master
if emoo.master_mode:
    population = emoo.getpopulation_unnormed() # get the unnormed population
    columns = emoo.columns # get the columns vector
    
    # we want both spike errors to be minimal at the end
    i = np.argmin(population[:, columns["spikes_error1"]]**2 + 
                  population[:, columns["spikes_error2"]]**2)
    
    print "Best spikes_error1; ", population[i, columns["spikes_error1"]],\
          "Best spikes_error2", population[i, columns["spikes_error2"]]
        
    print "Na:", population[i, columns['na']]
    print "K:", population[i, columns['k']]
    print "leak:", population[i, columns['leak']]
    
    soma.gnabar_hh = population[i, columns['na']]
    soma.gkbar_hh = population[i, columns['k']]
    soma.gl_hh = population[i, columns['leak']]

    h.v_init = -65
    h.init()
    h.run()

    pl.plot(vec_t, vec_v, linewidth=3, color='r')
    pl.xlabel("Time (ms)")
    pl.ylabel("V (mV)")
    pl.show()