#!python

import os
import json
import random
import shutil
import argparse
from glob import glob

from os.path import join as pjoin

from tqdm import tqdm
from termcolor import colored

import textworld
import textworld.agents
import textworld.gym
from textworld.envs import PddlEnv

from alfworld.agents.environment.alfred_tw_env import AlfredInfos, AlfredDemangler, AlfredExpert, AlfredExpertType

from alfworld.info import ALFWORLD_DATA
from alfworld.agents.utils.misc import add_task_to_grammar


# Task types handled by this script, keyed by their numeric id (starting at 1).
# `gen_game_files` skips any trajectory whose `task_type` is not one of these names.
TASK_TYPES = dict(enumerate(
    (
        "pick_and_place_simple",
        "look_at_obj_in_light",
        "pick_clean_then_place_in_recep",
        "pick_heat_then_place_in_recep",
        "pick_cool_then_place_in_recep",
        "pick_two_obj_and_place",
    ),
    start=1,
))


def _read_text(path):
    """Return the full content of the text file at `path`, closing it afterwards."""
    with open(path) as f:
        return f.read()


def _write_json(data, path):
    """Dump `data` as JSON into `path`, closing (and thus flushing) the file afterwards."""
    with open(path, "w") as f:
        json.dump(data, f)


def gen_game_files(args):
    """Create or refresh a `game.tw-pddl` file for every `traj_data.json` under `args.data_path`.

    For each trajectory folder that has an `initial_state.pddl`, is not a
    "movable"/"Sliced" trajectory and has a supported task type, this:

    1. makes sure a game file exists in the matching folder under
       `args.save_path` (copied from the reference folder when one exists there,
       otherwise built from the PDDL domain, the grammar and the PDDL problem);
    2. re-saves the game file with a freshly generated grammar;
    3. if the game file has no `solvable` verdict yet, runs the expert on it
       (see `is_solvable`) and stores the resulting `walkthrough`/`solvable`.

    A summary of the number of solvable games per split ("/train/",
    "/valid_seen/", "/valid_unseen/", matched against the output path) is
    printed at the end.

    Args:
        args: Parsed command-line namespace. The attributes read are `domain`,
            `grammar`, `seed`, `expert_type`, `data_path`, `save_path`, `start`,
            `end` (slice bounds applied to the sorted list of trajectory files),
            `verbose` and `goal_desc_human_anns_prob`.
    """
    game_logic = {
        "pddl_domain": _read_text(os.path.expandvars(args.domain)),
        "grammar": _read_text(os.path.expandvars(args.grammar)),
    }

    rng = random.Random(args.seed)

    print("Checking for solvable games...")

    task_types = list(TASK_TYPES.values())

    # Setup environment
    request_infos = textworld.EnvInfos(admissible_commands=True, extras=["gamefile"])
    env = PddlEnv(request_infos)
    env = AlfredDemangler(env, shuffle=False)
    env = AlfredInfos(env)
    env = AlfredExpert(env, expert_type=args.expert_type)

    all_traj_data = sorted(glob(args.data_path + '/**/traj_data.json', recursive=True))[args.start:args.end]
    pbar = tqdm(all_traj_data)

    counters = {
        "/train/": 0,
        "/valid_seen/": 0,
        "/valid_unseen/": 0,
    }

    def log(info):
        # Go through the progress bar so messages do not garble it.
        if args.verbose:
            pbar.write(info)

    def count_solvable(path):
        # Credit the game to every split whose marker appears in its output path.
        for split in counters:
            if split in path:
                counters[split] += 1

    for json_path in pbar:
        root_ref = os.path.dirname(json_path)
        pbar.set_description(root_ref)
        pbar.set_postfix_str("")
        pddl_path = pjoin(root_ref, 'initial_state.pddl')

        # Skip if no PDDL file
        if not os.path.exists(pddl_path):
            log(colored(f"Skipping {root_ref}, PDDL file is missing", "yellow"))
            continue

        if 'movable' in root_ref or 'Sliced' in root_ref:
            log(colored(f"Skipping {root_ref}, Movable & slice trajs not supported", "yellow"))
            continue

        # Get goal description
        with open(json_path, 'r') as f:
            traj_data = json.load(f)

        # Check for any task_type constraints
        if traj_data['task_type'] not in task_types:
            log(colored(f"Skipping {root_ref}, Skipping task type: {traj_data['task_type']}", "yellow"))
            continue

        # Prepare output directory. Map the folder through its path relative to
        # data_path rather than with str.replace(): replace() rewrites *every*
        # occurrence of data_path (e.g. `--data_path data` would also mangle
        # "traj_data.json") and mishandles a trailing slash on data_path.
        root_out = pjoin(args.save_path, os.path.relpath(root_ref, args.data_path))
        os.makedirs(root_out, exist_ok=True)

        game_file_path_ref = pjoin(root_ref, "game.tw-pddl")
        game_file_path_out = pjoin(root_out, "game.tw-pddl")

        # Generated for every retained trajectory (even when the game file
        # already exists) so the draws from `rng` do not depend on what is on disk.
        grammar = add_task_to_grammar(game_logic['grammar'], traj_data, goal_desc_human_anns_prob=args.goal_desc_human_anns_prob, rng=rng)

        # 1. Check if game file exists
        if not os.path.exists(game_file_path_out):
            if args.save_path != args.data_path:
                shutil.copy(json_path, pjoin(root_out, os.path.basename(json_path)))

            if os.path.exists(game_file_path_ref):
                log(colored(f"Copying {game_file_path_ref} to {game_file_path_out}", "cyan"))
                shutil.copy(game_file_path_ref, game_file_path_out)
            else:
                gamedata = {
                    "pddl_domain": game_logic['pddl_domain'],
                    "grammar": grammar,
                    "pddl_problem": _read_text(pddl_path),
                    "solvable": None,
                    "walkthrough": None,
                }
                _write_json(gamedata, game_file_path_out)

        # 2. Check if game is solvable and has a walkthrough
        with open(game_file_path_out, 'r') as f:
            gamedata = json.load(f)

        # 3. Update grammar and re-save game file.
        gamedata['grammar'] = grammar
        _write_json(gamedata, game_file_path_out)

        if gamedata.get('solvable') is not None:
            # A verdict is already stored: nothing to solve, just account for it.
            log(colored(f"Updating: {root_out}", "green"))
            if gamedata['solvable']:
                count_solvable(root_out)
            continue

        # Check if game is solvable (expensive) and save it in the gamedata
        # ("magenta": termcolor has no "purple" colour).
        log(colored(f"Solving: {root_out}", "magenta"))
        trajectory, error_msg = is_solvable(env, game_file_path_out)

        gamedata['walkthrough'] = trajectory
        gamedata['solvable'] = gamedata['walkthrough'] is not None
        _write_json(gamedata, game_file_path_out)

        if gamedata['solvable']:
            # Freshly solved games belong in the summary too.
            count_solvable(root_out)
            log(colored(f"Solved: {root_out}", "green"))
        else:
            log(colored(f"Unsolvable: {root_out}. {error_msg}", "red"))

    # Print counters
    print(colored("Dataset:", "cyan"))
    for k, v in counters.items():
        print(colored(f"{k}: {v} games", "cyan"))


# use expert to check the game is solvable
def is_solvable(env, game_file_path,
                random_perturb=False, random_start=10, random_prob_after_state=0.15):
    """Play `game_file_path` with the expert and return the commands that finish it.

    Args:
        env: Environment to play in. It must provide `load`, `reset`, `step`
            and an `expert_type` attribute; the game states it returns must
            expose the "extra.expert_plan" entry (and `admissible_commands`
            when `random_perturb` is set).
        game_file_path: Path of the game file passed to `env.load`.
        random_perturb: When true, some expert actions are replaced by a
            randomly chosen admissible command.
        random_start: With `random_perturb`, every step whose index is
            `<= random_start` uses a random command.
        random_prob_after_state: With `random_perturb`, probability of using a
            random command on the later steps.

    Returns:
        `(trajectory, None)` on success, where `trajectory` is the list of
        commands issued until the environment reported `done` (or, for the
        planner expert, the plan found in the initial state, returned as-is).
        `(None, error)` if anything raised while loading or playing, including
        `SystemExit`; `error` is the exception instance.
    """
    steps = 0
    trajectory = []
    try:
        env.load(game_file_path)
        game_state = env.reset()
        if env.expert_type == AlfredExpertType.PLANNER:
            # The planner already provides the full plan: no need to play it out.
            return game_state["extra.expert_plan"], None

        done = False
        while not done:
            # Follow the first action of the expert's current plan by default.
            command = game_state['extra.expert_plan'][0]

            if random_perturb:
                # Only sample when perturbation is requested: otherwise the
                # global RNG is left untouched and an empty or missing list of
                # admissible commands cannot make a solvable game look broken.
                random_action = random.choice(game_state.admissible_commands)
                if steps <= random_start or random.random() < random_prob_after_state:
                    command = random_action

            game_state, _, done = env.step(command)
            trajectory.append(command)
            steps += 1
    except (SystemExit, Exception) as e:
        # SystemExit is not an Exception subclass, so it is listed explicitly;
        # either way the failure is reported to the caller instead of raised.
        return None, e

    return trajectory, None



if __name__ == "__main__":
    # Command-line entry point: parse the options, ask for confirmation when the
    # original game files would be overwritten, then generate the game files.
    description = "Generate the .tw-pddl files for each traj_data.json file."

    parser = argparse.ArgumentParser(description=description)

    parser.add_argument('--data_path', default=pjoin(ALFWORLD_DATA, "json_2.1.1"),
                        help="Path to ALFWorld data. Default: `%(default)s`.")
    parser.add_argument('--save_path',
                        help='Output directory for the .tw-pddl files. Default: same as --data_path.')

    parser.add_argument("--domain",
                        default=pjoin(ALFWORLD_DATA, "logic", "alfred.pddl"),
                        help="Path to a PDDL file describing the domain."
                                " Default: `%(default)s`.")
    parser.add_argument("--grammar",
                        default=pjoin(ALFWORLD_DATA, "logic", "alfred.twl2"),
                        help="Path to a TWL2 file defining the grammar used to generate text feedback."
                                " Default: `%(default)s`.")
    parser.add_argument("--expert_type",
                        default="handcoded",
                        choices=["planner", "handcoded"],
                        help="Type of expert to use. Default: `%(default)s`.")
    parser.add_argument("--verbose", action="store_true", help="Print more information.")
    parser.add_argument("--start", type=int, default=0, help="Start index.")
    # The bounds are used as a slice, so a default of -1 would silently drop the
    # last trajectory; None means "up to and including the last one".
    parser.add_argument("--end", type=int, default=None,
                        help="End index (exclusive). Default: process up to the last file.")

    parser.add_argument('--goal_desc_human_anns_prob', type=float, default=0,
                        help='Probability of using human-annotated goal language instead of templated goals (1.0 indicates all human annotations from ALFRED)')
    parser.add_argument('--seed', type=int, default=20240229,
                        help='Seed for the random generator when sampling human-annotated goal language instead of templated goals')

    args = parser.parse_args()

    args.save_path = args.save_path or args.data_path
    # If output is the same as data_path, we will overwrite the original game files.
    if args.save_path == args.data_path:
        print(colored("WARNING: You are overwriting the original game files.", "red"))
        if input("Do you want to continue? (y/n) ").lower() != "y":
            # `exit` is only injected by the `site` module; SystemExit always works.
            raise SystemExit(0)

    gen_game_files(args)
