# It could be represented more compactly using matrix operations, but I figured it might be easier to read/port this way.
# The er parameter holds the 6 relative exchange rates; for HKY this vector looks like [1,kappa,1,1,kappa,1].

# With matrix code in R, all those `ac = ` lines become a single line in the assembleQ function, and in the AtSite function they get wrapped into a nested loop.

# Build the 12 off-diagonal substitution rates from exchange rates and base
# frequencies, rescaled by the frequency-weighted total rate `mu`.
#
# er : the 6 relative exchange rates, indexed as ac, ag, at, cg, ct, gt.
# pi : the 4 base frequencies, indexed as A, C, G, T.
#
# Returns a length-12 vector in the order
#   ac, ag, at, cg, ct, gt, ca, ga, ta, gc, tc, tg
# where each entry is (exchange rate of the pair) * (frequency of the
# destination base) / mu.
assembleQ = function(er,pi) {
  # For each of the 12 ordered pairs: which exchange rate it uses and which
  # base it ends in (1 = A, 2 = C, 3 = G, 4 = T).
  rate_of_pair = c(1,2,3,4,5,6, 1,2,3,4,5,6)
  dest_base    = c(2,3,4,3,4,4, 1,1,1,2,2,3)
  raw = er[rate_of_pair] * pi[dest_base]

  # Total rate of leaving each base (the off-diagonal row sums).
  leave_a = raw[1] + raw[2]  + raw[3]
  leave_c = raw[7] + raw[4]  + raw[5]
  leave_g = raw[8] + raw[10] + raw[6]
  leave_t = raw[9] + raw[11] + raw[12]

  # Frequency-weighted total rate; dividing by it makes the weighted sum of
  # the returned rates equal to 1.
  mu = pi[1]*leave_a + pi[2]*leave_c + pi[3]*leave_g + pi[4]*leave_t
  return(raw/mu)
}

# Expected number of mutations at one site over a given amount of tree length.
#
# Q           : 12 rates in the order ac, ag, at, cg, ct, gt, ca, ga, ta, gc,
#               tc, tg (the layout produced by assembleQ).
# pi          : weight placed on each starting base, indexed as A, C, G, T.
#               This may be a 0/1 indicator vector for a known starting base.
# tree_length : amount of tree length the site evolves over (used as a scalar).
#
# Returns the sum over all 12 ordered pairs of
#   tree_length * rate * weight of the starting base.
expectedNumberMutationsAtSite = function(Q,pi,tree_length) {
  # Starting base of each of the 12 entries of Q (1 = A, 2 = C, 3 = G, 4 = T).
  start_base = c(1,1,1,2,2,3, 2,3,4,3,4,4)
  flux = tree_length * Q[1:12] * pi[start_base]

  # Entries 1-6 and 7-12 are the two directions of the same six base pairs,
  # so adding them gives one total per unordered pair.
  per_pair = flux[1:6] + flux[7:12]
  return(sum(per_pair))
}

# Expected total number of mutations across all sites, given the state of the
# ancestral sequence at each site.
#
# er                 : the 6 relative exchange rates (passed to assembleQ).
# pi                 : the 4 base frequencies, indexed as A, C, G, T (passed
#                      to assembleQ).
# tree_length        : tree length shared by all sites.
# site_rates         : one rate multiplier per site; its length defines the
#                      number of sites.
# ancestral_sequence : vector with one element per site, each compared against
#                      "A", "C", "G", "T".
#
# Returns the sum of the per-site expectations. An empty `site_rates` gives 0.
# A site whose ancestral state matches none of "A","C","G","T" keeps an
# all-zero indicator vector and therefore contributes 0.
expectedNumberMutationsAllSites = function(er,pi,tree_length,site_rates,ancestral_sequence) {
  Q = assembleQ(er,pi)
  nsites = length(site_rates)
  states = c("A","C","G","T")
  per_site = numeric(nsites)
  # seq_len() rather than 1:nsites: with zero sites 1:0 would iterate over
  # c(1, 0) and turn the result into NA.
  for (i in seq_len(nsites)) {
    # Put all the starting weight on the ancestral base at this site.
    indicator_vec = rep(0,4)
    ancestral_state_at_site = which(states == ancestral_sequence[i])
    indicator_vec[ancestral_state_at_site] = 1
    # The site-specific rate stretches the tree length for this site.
    per_site[i] = expectedNumberMutationsAtSite(Q,indicator_vec,tree_length*site_rates[i])
  }
  return(sum(per_site))
}
