# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Shim for systems that need to load both SessionBundle and SavedModel.

This is intended to be used during migration to SavedModel.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.contrib.session_bundle import constants as legacy_constants
from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.contrib.session_bundle import session_bundle
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import meta_graph
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants


def _add_input_to_signature_def(tensor_name, map_key, signature_def):
  """Register a tensor as an input of signature_def.

  Args:
    tensor_name: string name of the tensor to register as an input
    map_key: string key under which the tensor is stored in the signature_def
      inputs map
    signature_def: object of type meta_graph_pb2.SignatureDef(), modified in
      place

  Side effect:
    signature_def.inputs[map_key] is overwritten with a TensorInfo whose name
    is tensor_name.
  """
  # Build the TensorInfo before touching the map, then copy it into the entry
  # stored under map_key.
  input_info = meta_graph_pb2.TensorInfo(name=tensor_name)
  inputs_map = signature_def.inputs
  inputs_map[map_key].CopyFrom(input_info)


def _add_output_to_signature_def(tensor_name, map_key, signature_def):
  """Register a tensor as an output of signature_def.

  Args:
    tensor_name: string name of the tensor to register as an output
    map_key: string key under which the tensor is stored in the signature_def
      outputs map
    signature_def: object of type meta_graph_pb2.SignatureDef(), modified in
      place

  Side effect:
    signature_def.outputs[map_key] is overwritten with a TensorInfo whose name
    is tensor_name.
  """
  # Build the TensorInfo before touching the map, then copy it into the entry
  # stored under map_key.
  output_info = meta_graph_pb2.TensorInfo(name=tensor_name)
  outputs_map = signature_def.outputs
  outputs_map[map_key].CopyFrom(output_info)


def _convert_default_signature_to_signature_def(signatures):
  """Convert default signature to object of type SignatureDef.

  A regression signature is mapped to a SignatureDef with the regress method
  name, one input and one output. A classification signature is mapped to a
  SignatureDef with the classify method name, one input and two outputs
  (classes and scores).

  Args:
    signatures: object of type manifest_pb2.Signatures()

  Returns:
    object of type SignatureDef which contains a converted version of default
    signature from input signatures object

    Returns None, after logging an error, if the default signature is neither
    a regression nor a classification signature (for example a generic
    signature), because it cannot be converted to SignatureDef.
  """
  default_signature = signatures.default_signature
  # Resolve the oneof once; it drives both the dispatch and the error message.
  signature_type = default_signature.WhichOneof("type")
  signature_def = meta_graph_pb2.SignatureDef()
  if signature_type == "regression_signature":
    regression_signature = default_signature.regression_signature
    signature_def.method_name = signature_constants.REGRESS_METHOD_NAME
    _add_input_to_signature_def(regression_signature.input.tensor_name,
                                signature_constants.REGRESS_INPUTS,
                                signature_def)
    _add_output_to_signature_def(regression_signature.output.tensor_name,
                                 signature_constants.REGRESS_OUTPUTS,
                                 signature_def)
  elif signature_type == "classification_signature":
    classification_signature = default_signature.classification_signature
    signature_def.method_name = signature_constants.CLASSIFY_METHOD_NAME
    _add_input_to_signature_def(classification_signature.input.tensor_name,
                                signature_constants.CLASSIFY_INPUTS,
                                signature_def)
    _add_output_to_signature_def(classification_signature.classes.tensor_name,
                                 signature_constants.CLASSIFY_OUTPUT_CLASSES,
                                 signature_def)
    _add_output_to_signature_def(classification_signature.scores.tensor_name,
                                 signature_constants.CLASSIFY_OUTPUT_SCORES,
                                 signature_def)
  else:
    # Pass the type as a logger argument so formatting is left to the logging
    # machinery instead of being done eagerly with the % operator.
    logging.error("Only classification and regression default signatures "
                  "are supported for up-conversion. %s is not "
                  "supported", signature_type)
    return None
  return signature_def


def _convert_named_signatures_to_signature_def(signatures):
  """Build a predict SignatureDef from the legacy named signatures.

  The named signatures stored under the PREDICT_INPUTS and PREDICT_OUTPUTS keys
  are read; every entry of their generic signature maps becomes an input or an
  output, respectively, of the returned SignatureDef.

  Args:
    signatures: object of type manifest_pb2.Signatures()

  Returns:
    object of type SignatureDef with the predict method name, holding the
    converted named input and output signatures

  Raises:
    RuntimeError: if the named input signature or the named output signature
    is not of type GenericSignature
  """
  signature_def = meta_graph_pb2.SignatureDef()
  named = signatures.named_signatures
  inputs_sig = named[signature_constants.PREDICT_INPUTS]
  outputs_sig = named[signature_constants.PREDICT_OUTPUTS]
  # TODO(pdudnik): what if there are other signatures? Mimic cr/140900781 once
  # it is submitted.
  inputs_are_generic = inputs_sig.WhichOneof("type") == "generic_signature"
  outputs_are_generic = outputs_sig.WhichOneof("type") == "generic_signature"
  if not (inputs_are_generic and outputs_are_generic):
    raise RuntimeError("Named input and output signatures can only be "
                       "up-converted if they are generic signature. "
                       "Input signature type is %s, output signature type is "
                       "%s" % (inputs_sig.WhichOneof("type"),
                               outputs_sig.WhichOneof("type")))

  signature_def.method_name = signature_constants.PREDICT_METHOD_NAME
  # Inputs are registered first, then outputs; each map key is reused as the
  # key in the corresponding SignatureDef map.
  for alias, binding in inputs_sig.generic_signature.map.items():
    _add_input_to_signature_def(binding.tensor_name, alias, signature_def)
  for alias, binding in outputs_sig.generic_signature.map.items():
    _add_output_to_signature_def(binding.tensor_name, alias, signature_def)
  return signature_def


def _convert_signatures_to_signature_defs(metagraph_def):
  """Produce default and named upconverted SignatureDef objects from Signatures.

  Args:
    metagraph_def: object of type meta_graph_pb2.MetaGraphDef containing legacy
    format Session Bundle signatures

  Returns:
    default_signature_def: object of type SignatureDef which contains an
        upconverted version of default signatures in metagraph_def, or None
        if there is no default signature or it cannot be converted
    named_signature_def: object of type SignatureDef which contains an
        upconverted version of named signatures in metagraph_def, or None if
        fewer than two named signatures are present

    Both values are None when metagraph_def holds no legacy signatures
    collection at all.

  Raises:
    RuntimeError: propagated from the named signature conversion when the named
    input or output signature is not a generic signature
  """
  collection_def = metagraph_def.collection_def
  # Test membership before subscripting: subscripting a protobuf map with a
  # missing key inserts an empty entry, which would modify metagraph_def, and
  # indexing the resulting empty list would raise an opaque IndexError.
  if legacy_constants.SIGNATURES_KEY not in collection_def:
    return None, None
  packed_signatures = (
      collection_def[legacy_constants.SIGNATURES_KEY].any_list.value)
  if not packed_signatures:
    return None, None

  signatures_proto = manifest_pb2.Signatures()
  packed_signatures[0].Unpack(signatures_proto)

  default_signature_def = None
  named_signature_def = None
  if signatures_proto.HasField("default_signature"):
    default_signature_def = _convert_default_signature_to_signature_def(
        signatures_proto)
  # The named conversion reads two entries (predict inputs and predict
  # outputs), so it is only attempted when more than one is present.
  if len(signatures_proto.named_signatures) > 1:
    named_signature_def = _convert_named_signatures_to_signature_def(
        signatures_proto)
  return default_signature_def, named_signature_def


def _load_saved_model_from_session_bundle_path(export_dir, target, config):
  """Load a legacy TF Exporter/SessionBundle export, up-converting signatures.

  The meta graph file is read from export_dir, its legacy signatures are
  converted to SignatureDef entries on the MetaGraphDef, and the session is
  then loaded with that updated MetaGraphDef.

  Args:
    export_dir: the directory that contains files exported by exporter.
    target: The execution engine to connect to. See target in tf.Session()
    config: A ConfigProto proto with configuration options. See config in
      tf.Session()

  Returns:
    session: a tensorflow session created from the variable files.
    metagraph_def: The `MetaGraphDef` protocol buffer loaded in the provided
      session. Its signature_def field is populated with the signatures
      converted from the legacy signatures held in its collection_def. A
      converted default signature is stored under the default serving key. A
      converted named signature is stored under the same key, with
      "_from_named" appended when a default signature is present as well.

  Raises:
    RuntimeError: If metagraph already contains signature_def and cannot be
    up-converted.
  """
  meta_graph_path = os.path.join(export_dir,
                                 legacy_constants.META_GRAPH_DEF_FILENAME)
  metagraph_def = meta_graph.read_meta_graph_file(meta_graph_path)

  if metagraph_def.signature_def:
    raise RuntimeError("Legacy graph contains signature def, unable to "
                       "up-convert.")

  # Attach the up-converted SignatureDefs to the metagraph.
  default_def, named_def = _convert_signatures_to_signature_defs(metagraph_def)
  serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  if default_def:
    metagraph_def.signature_def[serving_key].CopyFrom(default_def)
  if named_def:
    # The default signature, when present, already occupies the serving key,
    # so the named one gets a suffixed key.
    named_key = serving_key + "_from_named" if default_def else serving_key
    metagraph_def.signature_def[named_key].CopyFrom(named_def)

  # The session has to be created from the up-converted metagraph definition
  # rather than from the one on disk: Session has an internal object of type
  # Graph which is populated from meta_graph_def, and it would otherwise be out
  # of sync with the meta_graph_def returned here.
  loaded_session, loaded_metagraph_def = (
      session_bundle.load_session_bundle_from_path(
          export_dir, target, config, meta_graph_def=metagraph_def))
  return loaded_session, loaded_metagraph_def


def load_session_bundle_or_saved_model_bundle_from_path(export_dir,
                                                        tags=None,
                                                        target="",
                                                        config=None):
  """Load a SavedModel or a legacy SessionBundle from the given path.

  If export_dir looks like a SavedModel directory, a new session is created and
  the SavedModel is loaded into it. Otherwise, if export_dir looks like a
  session bundle directory, the legacy bundle is loaded and its signatures are
  up-converted to SignatureDefs.

  Args:
    export_dir: the directory that contains files exported by exporter.
    tags: Set of string tags to identify the required MetaGraphDef when model is
          saved as SavedModel. These should correspond to the tags used when
          saving the variables using the SavedModel `save()` API. Only used
          for SavedModel directories.
    target: The execution engine to connect to. See target in tf.Session()
    config: A ConfigProto proto with configuration options. See config in
            tf.Session()

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if the required files are missing or contain unrecognizable
    fields, i.e. the exported model is invalid.
  """
  if loader.maybe_saved_model_directory(export_dir):
    sess = session.Session(target, graph=None, config=config)
    loaded = False
    try:
      metagraph_def = loader.load(sess, tags, export_dir)
      loaded = True
    finally:
      # The session was created here, so nobody else can release it when
      # loading fails; close it and let the original exception propagate.
      if not loaded:
        sess.close()
    return sess, metagraph_def

  if session_bundle.maybe_session_bundle_dir(export_dir):
    return _load_saved_model_from_session_bundle_path(export_dir, target,
                                                      config)

  raise RuntimeError("SessionBundle or SavedModelBundle not found at "
                     "specified export location: %s" % export_dir)
