#!python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import sys
import argparse
import warnings

from prv_accountant import Accountant, other_accountants, privacy_random_variables

# tensorflow_privacy is an optional dependency; TF_PRIVACY_AVAILABLE records
# whether its accountant modules could be imported.
try:
    # NOTE(review): presumably a marker attribute that tensorflow_privacy
    # checks at import time to skip part of its own imports -- confirm
    # against that package. It must be set before the import below.
    sys.skip_tf_privacy_import = True
    from tensorflow_privacy.privacy.analysis import rdp_accountant, gdp_accountant
    TF_PRIVACY_AVAILABLE = True
except ImportError:
    TF_PRIVACY_AVAILABLE = False


def arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the epsilon computation script.

    Four options are mandatory: the sampling probability (``-p``), the noise
    multiplier (``-s``), the number of compositions (``-i``) and the target
    delta (``-d``). The optional ``-v`` flag is ``None`` when absent and
    ``True`` when given.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(
        description="Compute DP epsilon for a set of training hyper-params")

    # One row per mandatory option: (short flag, long flag, type, help text).
    # Rows are registered in order so the generated help keeps this ordering.
    mandatory_options = (
        (
            "-p", "--sampling-probability", float,
            "Probability of a user being sampled into a batch. "
            "(This is very often batch_size*max_samples_per_user/num_samples)",
        ),
        (
            "-s", "--noise-multiplier", float,
            "A parameter which governs how much noise is added.",
        ),
        (
            "-i", "--num-compositions", int,
            "The number of compositions at which epsilon is computed.",
        ),
        (
            "-d", "--delta", float,
            "The target delta in the eps-delta DP framework",
        ),
    )
    for short_flag, long_flag, value_type, help_text in mandatory_options:
        cli.add_argument(
            short_flag, long_flag,
            type=value_type, required=True, help=help_text,
        )

    cli.add_argument(
        "-v", "--verbose",
        action="store_true", default=None, help="Increase verbosity",
    )
    return cli


def main() -> int:
    """Compute and print epsilon for every available accountant.

    Parses the command line, builds the PRV and RDP accountants (plus the GDP
    accountant when tensorflow_privacy could be imported), evaluates each at
    ``--num-compositions`` and prints one line per accountant with the lower
    bound, the estimate and the upper bound of epsilon.

    An accountant that raises is reported as ``n/a`` on stdout and the reason
    is emitted as a warning, so one failure does not prevent the remaining
    accountants from being reported.

    Returns:
        Always 0; the caller uses it as the process exit status.
    """
    args = arg_parser().parse_args()

    prv = privacy_random_variables.PoissonSubsampledGaussianMechanism(
        sampling_probability=args.sampling_probability,
        noise_multiplier=args.noise_multiplier
    )

    # Display name -> callable taking the number of compositions and
    # returning (eps_lower, eps_estimate, eps_upper).
    methods = {}

    prv_acc = Accountant(
        noise_multiplier=args.noise_multiplier,
        sampling_probability=args.sampling_probability,
        delta=args.delta,
        max_compositions=args.num_compositions,
        verbose=args.verbose,
        eps_error=0.1
    )
    methods["PRV Accountant"] = prv_acc.compute_epsilon

    rdp_acc = other_accountants.RDP(prv=prv, delta=args.delta)
    methods["RDP Accountant"] = rdp_acc.compute_epsilon

    if TF_PRIVACY_AVAILABLE:
        def compute_eps_gdp(steps):
            # batch_size is a placeholder: n and epoch are derived from it so
            # that batch_size/n equals the sampling probability and
            # epoch*n/batch_size equals the number of steps.
            batch_size = 1000
            n = batch_size/args.sampling_probability
            epoch = steps*args.sampling_probability
            eps = gdp_accountant.compute_eps_poisson(epoch, args.noise_multiplier, n, batch_size, args.delta)
            # Only a point estimate is available here, so the bounds are
            # reported as the trivial interval [0, inf).
            return 0.0, eps, float('inf')
        methods["GDP Accountant"] = compute_eps_gdp
    else:
        warnings.warn("Install TF privacy for more accountants.")

    for name, compute_eps in methods.items():
        try:
            eps_lower, eps_est, eps_upper = compute_eps(args.num_compositions)
            print(f"{name}:\t\teps_lower = {eps_lower:6.3} eps_estimate = {eps_est:6.3}, eps_upper = {eps_upper:6.3} ")
        except Exception as e:
            # Best effort: keep reporting the other accountants, but surface
            # the reason instead of discarding it.
            print(f"{name}:\t\tn/a")
            warnings.warn(f"{name} failed: {e!r}")
    return 0


# When executed as a script, run main() and use its return value as the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())
