# Copyright 2021 The JAX Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import inspect
from typing import Any, Callable, Sequence, Iterable, Tuple

from . import pytree

# Element type of ShardingSpec's `sharding` constructor argument and
# property. Deliberately left as Any in this stub.
# NOTE(review): presumably one of NoSharding / Chunked / Unstacked —
# confirm against the C++ binding before narrowing.
_AvalDimSharding = Any
# Element type of ShardingSpec's `mesh_mapping` constructor argument and
# property. Deliberately left as Any in this stub.
# NOTE(review): presumably ShardedAxis or Replicated — confirm before
# narrowing.
_MeshDimAssignment = Any

class NoSharding:
  # Payload-free sharding marker: the constructor takes no arguments.
  # NOTE(review): presumably means "this array dimension is not split
  # across devices" — confirm against the C++ binding.
  def __init__(self) -> None: ...

  # __eq__ is declared without __hash__, so under Python's data model this
  # stub presents instances as unhashable.
  def __eq__(self, __other: Any) -> bool: ...

  def __repr__(self) -> str: ...

class Chunked:
  # Sharding entry that carries a sequence of ints.
  # NOTE(review): presumably the chunk counts for the dimension it
  # describes — verify against the C++ binding.
  def __init__(self, __chunks: Sequence[int]) -> None: ...

  # Read-only view of the ints; presumably the value given to __init__.
  @property
  def chunks(self) -> Sequence[int]: ...

  # Equality only; no __hash__ is declared in this stub.
  def __eq__(self, __other: Any) -> bool: ...
  def __repr__(self) -> str: ...

class Unstacked:
  # Sharding entry carrying a single integer size.
  # NOTE(review): presumably the extent of a dimension that is unstacked
  # across devices — confirm against the C++ binding.
  def __init__(self, __sz: int) -> None: ...

  # Read-only accessor for the integer; presumably the constructor value.
  @property
  def size(self) -> int: ...

  # Equality only; no __hash__ is declared in this stub.
  def __eq__(self, __other: Any) -> bool: ...
  def __repr__(self) -> str: ...

class ShardedAxis:
  # Entry carrying a single integer axis index.
  # NOTE(review): presumably used as a mesh-mapping entry (see the
  # _MeshDimAssignment alias) naming which array axis is sharded — confirm.
  @property
  def axis(self) -> int: ...
  def __init__(self, __axis: int) -> None: ...
  def __repr__(self) -> str: ...
  # The parameter is typed Any, matching NoSharding/Chunked/Unstacked.
  # Narrowing it to `ShardedAxis` makes this an incompatible override of
  # object.__eq__ (type checkers reject it). Referencing the class's own
  # name inside its body would also raise NameError if this declaration
  # were ever executed as a regular module.
  def __eq__(self, __other: Any) -> bool: ...

class Replicated:
  # Entry carrying an integer replica count.
  # NOTE(review): presumably used as a mesh-mapping entry (see the
  # _MeshDimAssignment alias) meaning "replicate `replicas` times" — confirm.
  @property
  def replicas(self) -> int: ...
  def __init__(self, __replicas: int) -> None: ...
  def __repr__(self) -> str: ...
  # The parameter is typed Any, matching NoSharding/Chunked/Unstacked.
  # Narrowing it to `Replicated` makes this an incompatible override of
  # object.__eq__ (type checkers reject it). Referencing the class's own
  # name inside its body would also raise NameError if this declaration
  # were ever executed as a regular module.
  def __eq__(self, __other: Any) -> bool: ...

class ShardingSpec:
  # Holds two sequences:
  #   - per-dimension sharding entries (_AvalDimSharding);
  #   - mesh-dimension assignments (_MeshDimAssignment).
  # Both are accepted as arbitrary iterables and exposed as tuples.
  def __init__(self,
               sharding: Iterable[_AvalDimSharding],
               mesh_mapping: Iterable[_MeshDimAssignment]) -> None: ...
  @property
  def sharding(self) -> Tuple[_AvalDimSharding, ...]: ...
  # This is a variable-length tuple, like `sharding`: the constructor
  # accepts any number of assignments. `Tuple[X]` (no `...`) would mean
  # exactly one element.
  @property
  def mesh_mapping(self) -> Tuple[_MeshDimAssignment, ...]: ...
  # The parameter is typed Any so this stays a compatible override of
  # object.__eq__. A `ShardingSpec`-typed parameter is rejected by type
  # checkers.
  def __eq__(self, __other: Any) -> bool: ...
  # Declared alongside __eq__, so specs are hashable (e.g. usable as dict
  # keys). The other sharding classes in this stub declare only __eq__.
  def __hash__(self) -> int: ...

  # pytype convention: instances may carry attributes not declared in this
  # stub, so the checker should not flag unknown attribute access.
  _HAS_DYNAMIC_ATTRIBUTES = True

class PmapFunction:
  # Callable object produced by `pmap`. Calls accept arbitrary positional
  # and keyword arguments and return Any.
  def __call__(self, *args, **kwargs) -> Any: ...
  # Pickle protocol: __getstate__ produces the state object, and
  # __setstate__ receives it back when unpickling. Previously the parameter
  # was literally named `Any` with no annotation, which shadowed the
  # typing name and left the argument untyped.
  def __getstate__(self) -> Any: ...
  def __setstate__(self, __state: Any) -> None: ...
  # Declared so inspect.signature() can report a signature for instances.
  __signature__: inspect.Signature
  # Private hooks for inspecting and resetting the object's internal cache.
  def _cache_size(self) -> int: ...
  def _cache_clear(self) -> None: ...
  def _debug_cache_keys(self) -> str: ...

# Builds a PmapFunction around `fun`.
# NOTE(review): parameter roles are inferred from names only:
#   - `cache_miss` presumably runs when no cached executable matches;
#   - `static_argnums` lists positional arguments treated as static;
#   - `shard_arg_fallback` presumably shards argument types the fast path
#     does not handle.
# Confirm against the C++ implementation.
def pmap(
    fun: Callable[..., Any],
    cache_miss: Callable[..., Any],
    static_argnums: Sequence[int],
    shard_arg_fallback: Callable[..., Any],
    pytree_registry: pytree.PyTreeRegistry,
) -> PmapFunction: ...
