# Copyright 2015 datawire. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from .dispatch import dispatch
from .ast import *

# Sentinel meaning "no default was supplied" for get_field below. A fresh
# object() is used so that None remains a legitimate default value; when a
# lookup misses and the default is still this object, get_field raises KeyError.
DEFAULT = object()

@dispatch(Class, Field)
def get_field(cls, field, default=DEFAULT):
    """Look up ``field`` in ``cls`` by its ``name`` attribute.

    Delegates to the overload matching the type of ``field.name``.
    """
    return get_field(cls, field.name, default)

@dispatch(Class, Name)
def get_field(cls, name, default=DEFAULT):
    """Look up the identifier ``name`` in ``cls`` by its ``text``."""
    return get_field(cls, name.text, default)

@dispatch(Class, basestring)
def get_field(cls, name, default=DEFAULT):
    """Return the definition bound to ``name`` in ``cls`` or one of its bases.

    ``cls.env`` is consulted first. Otherwise each base whose type is resolved
    is searched recursively, in declaration order, and the first hit wins.

    :param cls: class to search.
    :param name: the name to look up, as a string.
    :param default: value returned when nothing is found; if omitted,
        a miss raises instead.
    :raises KeyError: if ``name`` is not found and no ``default`` was given.
    """
    if name in cls.env:
        return cls.env[name]
    for base in cls.bases:
        base_cls = base.resolved.type
        # Bases whose type did not resolve are skipped, not treated as errors.
        if not base_cls:
            continue
        # Search bases with default=None so a miss yields None instead of
        # raising; only the outermost call decides whether to raise.
        result = get_field(base_cls, name, None)
        # Compare against None explicitly: None is the "not found" marker,
        # and a found definition must not be dropped just for being falsy.
        if result is not None:
            return result
    if default is DEFAULT:
        raise KeyError(name)
    return default

def get_fields_r(cls, result):
    """Append the Field definitions of ``cls`` and its ancestors to ``result``.

    Bases are visited first, depth-first in declaration order, so inherited
    fields come before the fields declared directly on ``cls``.
    Returns the same ``result`` list that was passed in.
    """
    for parent in cls.bases:
        get_fields_r(parent.resolved.type, result)
    result.extend(dfn for dfn in cls.definitions if isinstance(dfn, Field))
    return result

def get_fields(cls):
    """Return a new list of all Field definitions of ``cls``, inherited ones first."""
    return get_fields_r(cls, [])

def has_super(fun):
    """Return True if any statement directly in ``fun.body`` is a super call.

    Only the body's own statement list is scanned; see ``is_super`` for what
    counts as a super call.
    """
    return any(is_super(stmt) for stmt in fun.body.statements)

@dispatch(Statement)
def is_super(stmt):
    """Default for any Statement without a more specific overload here: False."""
    return False

@dispatch(ExprStmt)
def is_super(stmt):
    """An expression statement is a super call when its expression is one."""
    wrapped = stmt.expr
    return is_super(wrapped)

@dispatch(Expression)
def is_super(expr):
    """Default for any Expression without a more specific overload here: False."""
    return False

@dispatch(Call)
def is_super(call):
    """A call is a super call when the callee expression is a ``Super`` node."""
    callee = call.expr
    return isinstance(callee, Super)

def constructor(cls):
    """Return the single constructor for ``cls``, or None if there is none.

    Constructors declared on ``cls`` take precedence. Otherwise the nearest
    extendable ancestor's constructors are used (see ``base_constructors``).
    Asserts that exactly one constructor was found.
    """
    found = constructors(cls) or base_constructors(cls)
    if not found:
        return None
    assert len(found) == 1
    return found[0]

def constructors(cls):
    """Return the Callable definitions of ``cls`` whose ``type`` is None.

    This module treats such callables as constructors. Only direct
    definitions of ``cls`` are examined; bases are not searched.
    """
    def is_constructor(dfn):
        return isinstance(dfn, Callable) and dfn.type is None
    return [dfn for dfn in cls.definitions if is_constructor(dfn)]

def base_type(cls):
    """Return the first base of ``cls`` whose type is extendable, else None."""
    extendable = (parent for parent in cls.bases
                  if is_extendable(parent.resolved.type))
    return next(extendable, None)

def base_constructors(cls):
    """Return the constructors of the nearest extendable ancestor that has any.

    Walks up the chain given by ``base_type``. Returns an empty list if no
    ancestor on that chain declares a constructor.
    """
    ancestor = base_type(cls)
    while ancestor:
        found = constructors(ancestor.resolved.type)
        if found:
            return found
        ancestor = base_type(ancestor.resolved.type)
    return []

@dispatch(String)
def literal_to_str(lit):
    """Return ``str(lit)`` without its first and last characters.

    Presumably these are the surrounding quotes of the string literal.
    Escape sequences are not processed.
    """
    quoted = str(lit)
    return quoted[1:-1]

@dispatch(Number)
def literal_to_str(lit):
    """Return the textual form of a numeric literal, i.e. ``str(lit)``."""
    text = str(lit)
    return text

def get_package_version(pkg):
    """Return the version declared by ``pkg``'s ``version`` annotation.

    The first annotation named ``version`` must have exactly one argument
    (asserted); that argument is converted with ``literal_to_str``.
    Returns "0.0" when ``pkg`` has no such annotation.
    """
    version_ann = next(
        (ann for ann in pkg.annotations if ann.name.text == "version"), None)
    if version_ann is None:
        return "0.0"
    assert len(version_ann.arguments) == 1
    return literal_to_str(version_ann.arguments[0])

def namever(packages):
    """Return ``(name, version)`` for the first entry of ``packages``.

    ``packages`` is a mapping whose first key is returned as the name. Its
    value must be indexable, and ``value[0].version`` is returned as the
    version. "First" follows the mapping's iteration order, so callers
    presumably pass an ordered mapping (verify at call sites).
    An empty or None ``packages`` yields ``("", "0.0")``.
    """
    if not packages:
        return "", "0.0"
    # next(iter(...)) works on both the Python 2 list returned by items() and
    # the Python 3 view; indexing with items()[0] fails on Python 3 because
    # dictionary views are not subscriptable.
    first_name, first_list = next(iter(packages.items()))
    return first_name, first_list[0].version

def is_extendable(node):
    """Return True if ``node.resolved.type`` is a Class other than a Primitive or Interface."""
    target = node.resolved.type
    if not isinstance(target, Class):
        return False
    return not isinstance(target, (Primitive, Interface))

@dispatch(Class, dict)
def get_methods(cls, result, predicate):
    """Add the methods declared directly on ``cls`` to ``result``, keyed by name.

    A method is a Callable definition with a truthy ``type`` for which
    ``predicate`` returns a truthy value. Names already present in ``result``
    are never overwritten. Bases are not visited. Returns None.
    """
    candidates = (dfn for dfn in cls.definitions
                  if isinstance(dfn, Callable) and dfn.type and predicate(dfn))
    for method in candidates:
        result.setdefault(method.name.text, method)

@dispatch(Class)
def get_methods(cls, predicate):
    """Return an OrderedDict of ``cls``'s own methods that satisfy ``predicate``."""
    collected = OrderedDict()
    get_methods(cls, collected, predicate)
    return collected

@dispatch(Class)
def get_methods(cls):
    """Return an OrderedDict of all methods declared directly on ``cls``."""
    def accept_all(dfn):
        return True
    return get_methods(cls, accept_all)

@dispatch(Class, dict, dict)
def get_defaulted_methods(cls, result, derived):
    """Collect inherited default method bodies into ``result``.

    If ``cls`` is an Interface or Primitive, its own methods that have a body
    and whose names are not in ``derived`` are added; the first name found
    wins. Bases are then visited recursively, depth-first in declaration order.
    """
    if isinstance(cls, (Interface, Primitive)):
        def is_inherited_default(dfn):
            return dfn.body and dfn.name.text not in derived
        get_methods(cls, result, is_inherited_default)
    for parent in cls.bases:
        get_defaulted_methods(parent.resolved.type, result, derived)

@dispatch(Class)
def get_defaulted_methods(cls):
    """Return an OrderedDict of default methods ``cls`` inherits without defining.

    Only methods declared directly on ``cls`` (via ``get_methods``) count
    as overrides.
    """
    defaults = OrderedDict()
    own_methods = get_methods(cls)
    get_defaulted_methods(cls, defaults, own_methods)
    return defaults
