#!python

import os
import json
import glob
import random
import argparse
from os.path import join as pjoin

import textworld
from textworld.agents import HumanAgent

import textworld.gym

from alfworld.info import ALFWORLD_DATA
from alfworld.agents.utils.misc import add_task_to_grammar
from alfworld.agents.environment.alfred_tw_env import AlfredExpert, AlfredDemangler, AlfredExpertType


def _read_text(path):
    """Return the whole text content of the file at `path`, closing it afterwards."""
    with open(path) as f:
        return f.read()


def main(args):
    """Build a TextWorld game from an ALFRED problem folder and play it interactively.

    Args:
        args: Parsed command-line namespace providing `problem` (folder holding
            `initial_state.pddl` and `traj_data.json`), `domain` (path to the
            PDDL domain file) and `grammar` (path to the TWL2 grammar file).

    Raises:
        FileNotFoundError: If the domain, grammar, trajectory or PDDL problem
            file cannot be opened.

    Side effects:
        Writes `game.tw-pddl` next to `initial_state.pddl`, prints the
        observations to stdout and reads commands through a `HumanAgent`
        until the environment reports `done`.
    """
    print(f"Playing '{args.problem}'.")
    game_logic = {
        "pddl_domain": _read_text(args.domain),
        "grammar": _read_text(args.grammar),
    }

    # load state and trajectory files
    pddl_file = os.path.join(args.problem, 'initial_state.pddl')
    json_file = os.path.join(args.problem, 'traj_data.json')
    with open(json_file, 'r') as f:
        traj_data = json.load(f)
    game_logic['grammar'] = add_task_to_grammar(game_logic['grammar'], traj_data)

    # dump game file
    gamedata = dict(**game_logic, pddl_problem=_read_text(pddl_file))
    gamefile = os.path.join(os.path.dirname(pddl_file), 'game.tw-pddl')
    # Close (and therefore flush) the game file explicitly: it is read back
    # right below when the environment is registered.
    with open(gamefile, "w") as f:
        json.dump(gamedata, f)

    expert = AlfredExpert(expert_type=AlfredExpertType.PLANNER)
    # expert = AlfredExpert(expert_type=AlfredExpertType.HANDCODED)

    # register a new Gym environment.
    request_infos = textworld.EnvInfos(won=True, admissible_commands=True, score=True, max_score=True, intermediate_reward=True, extras=["expert_plan"])
    env_id = textworld.gym.register_game(gamefile, request_infos,
                                         max_episode_steps=1000000,
                                         wrappers=[AlfredDemangler(), expert])

    # reset env
    env = textworld.gym.make(env_id)
    obs, infos = env.reset()

    # human agent; the oracle is only enabled when the planner expert is used.
    agent = HumanAgent(autocompletion=True, oracle=(expert.expert_type == AlfredExpertType.PLANNER))
    agent.reset(env)

    while True:
        print(obs)
        cmd = agent.act(infos, 0, False)

        # Typing "ipdb" drops into the debugger instead of sending a command.
        if cmd == "ipdb":
            from ipdb import set_trace
            set_trace()
            continue

        obs, score, done, infos = env.step(cmd)

        if done:
            # NOTE(review): `done` presumably also becomes true when
            # `max_episode_steps` is reached, so rely on the requested `won`
            # info; fall back to the win message if it is not reported.
            if infos.get("won", True):
                print("You won!")
            else:
                print("Game over.")
            break


if __name__ == "__main__":
    # Command-line entry point: parse the arguments, pick a problem folder
    # (at random when none is given) and start the interactive game.
    description = "Play the abstract text version of an ALFRED environment."
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("problem", nargs="?", default=None,
                        help="Path to a folder containing PDDL and traj_data files."
                             f" Default: pick one at random found in {ALFWORLD_DATA}")
    parser.add_argument("--domain",
                        default=pjoin(ALFWORLD_DATA, "logic", "alfred.pddl"),
                        help="Path to a PDDL file describing the domain."
                             " Default: `%(default)s`.")
    parser.add_argument("--grammar",
                        default=pjoin(ALFWORLD_DATA, "logic", "alfred.twl2"),
                        help="Path to a TWL2 file defining the grammar used to generate text feedbacks."
                             " Default: `%(default)s`.")
    args = parser.parse_args()

    if args.problem is None:
        # No problem given: search the data folder recursively for candidates.
        problems = glob.glob(pjoin(ALFWORLD_DATA, "**", "initial_state.pddl"), recursive=True)

        # Remove problem which contains movable receptacles.
        problems = [p for p in problems if "movable_recep" not in p]

        if not problems:
            raise ValueError(f"Can't find problem files in {ALFWORLD_DATA}. Did you run alfworld-data?")

        # The problem is identified by the folder holding `initial_state.pddl`.
        args.problem = os.path.dirname(random.choice(problems))

    # Also reject an explicitly given problem that has movable receptacles.
    if "movable_recep" in args.problem:
        raise ValueError("This problem contains movable receptacles, which is not supported by ALFWorld.")

    main(args)
