
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_result_or_op_results as _get_op_result_or_op_results,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
# Short alias for the ``ir`` module of ``_ods_cext``; the dialect and op
# classes below take their base classes and attribute types from it.
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union


@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  """Dialect class for the ``sdy`` namespace, registered via ``_ods_cext.register_dialect``."""
  DIALECT_NAMESPACE = "sdy"

@_ods_cext.register_operation(_Dialect)
class AllGatherOp(_ods_ir.OpView):
  """Python view of ``sdy.all_gather``.

  One ``tensor`` operand, mandatory ``gathering_axes`` and ``out_sharding``
  attributes, and a single ``result``.
  """

  OPERATION_NAME = "sdy.all_gather"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, gathering_axes, out_sharding, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    Attribute arguments may be ``Attribute`` instances (used as-is) or plain
    values, which are converted by the registered ``Sdy_ListOfAxisRefLists``
    / ``Sdy_TensorSharding`` attribute builders when such builders exist and
    are passed through unchanged otherwise.
    """
    operand_list = [tensor]
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {
        "gathering_axes": to_attr(gathering_axes, 'Sdy_ListOfAxisRefLists'),
        "out_sharding": to_attr(out_sharding, 'Sdy_TensorSharding'),
    }
    # Result type mirrors the operand type. It is computed (so an operand
    # without ``.type`` fails here) but not forwarded to the base constructor;
    # NOTE(review): presumably the result type is inferred natively.
    _result_types = [operand_list[0].type]
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def tensor(self):
    """The ``tensor`` operand."""
    operand_values = self.operation.operands
    return operand_values[0]

  @builtins.property
  def gathering_axes(self):
    """The mandatory ``gathering_axes`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["gathering_axes"]

  @gathering_axes.setter
  def gathering_axes(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["gathering_axes"] = value

  @builtins.property
  def out_sharding(self):
    """The mandatory ``out_sharding`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["out_sharding"] = value

  @builtins.property
  def result(self):
    """The first (and only) result of the op."""
    result_values = self.operation.results
    return result_values[0]

def all_gather(tensor, gathering_axes, out_sharding, *, loc=None, ip=None) -> _ods_ir.Value:
  """Create an ``sdy.all_gather`` op and return its result value."""
  op = AllGatherOp(
      tensor=tensor,
      gathering_axes=gathering_axes,
      out_sharding=out_sharding,
      loc=loc,
      ip=ip,
  )
  return op.result

@_ods_cext.register_operation(_Dialect)
class AllReduceOp(_ods_ir.OpView):
  """Python view of ``sdy.all_reduce``.

  One ``tensor`` operand, mandatory ``reduction_axes`` and ``out_sharding``
  attributes, and a single ``result``.
  """

  OPERATION_NAME = "sdy.all_reduce"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, reduction_axes, out_sharding, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    Attribute arguments may be ``Attribute`` instances (used as-is) or plain
    values, which are converted by the registered ``Sdy_AxisRefList`` /
    ``Sdy_TensorSharding`` attribute builders when such builders exist and
    are passed through unchanged otherwise.
    """
    operand_list = [tensor]
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {
        "reduction_axes": to_attr(reduction_axes, 'Sdy_AxisRefList'),
        "out_sharding": to_attr(out_sharding, 'Sdy_TensorSharding'),
    }
    # Result type mirrors the operand type. It is computed (so an operand
    # without ``.type`` fails here) but not forwarded to the base constructor;
    # NOTE(review): presumably the result type is inferred natively.
    _result_types = [operand_list[0].type]
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def tensor(self):
    """The ``tensor`` operand."""
    operand_values = self.operation.operands
    return operand_values[0]

  @builtins.property
  def reduction_axes(self):
    """The mandatory ``reduction_axes`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["reduction_axes"]

  @reduction_axes.setter
  def reduction_axes(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["reduction_axes"] = value

  @builtins.property
  def out_sharding(self):
    """The mandatory ``out_sharding`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["out_sharding"] = value

  @builtins.property
  def result(self):
    """The first (and only) result of the op."""
    result_values = self.operation.results
    return result_values[0]

def all_reduce(tensor, reduction_axes, out_sharding, *, loc=None, ip=None) -> _ods_ir.Value:
  """Create an ``sdy.all_reduce`` op and return its result value."""
  op = AllReduceOp(
      tensor=tensor,
      reduction_axes=reduction_axes,
      out_sharding=out_sharding,
      loc=loc,
      ip=ip,
  )
  return op.result

@_ods_cext.register_operation(_Dialect)
class AllSliceOp(_ods_ir.OpView):
  """Python view of ``sdy.all_slice``.

  One ``tensor`` operand, mandatory ``slicing_axes`` and ``out_sharding``
  attributes, and a single ``result``.
  """

  OPERATION_NAME = "sdy.all_slice"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, slicing_axes, out_sharding, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    Attribute arguments may be ``Attribute`` instances (used as-is) or plain
    values, which are converted by the registered ``Sdy_ListOfAxisRefLists``
    / ``Sdy_TensorSharding`` attribute builders when such builders exist and
    are passed through unchanged otherwise.
    """
    operand_list = [tensor]
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {
        "slicing_axes": to_attr(slicing_axes, 'Sdy_ListOfAxisRefLists'),
        "out_sharding": to_attr(out_sharding, 'Sdy_TensorSharding'),
    }
    # Result type mirrors the operand type. It is computed (so an operand
    # without ``.type`` fails here) but not forwarded to the base constructor;
    # NOTE(review): presumably the result type is inferred natively.
    _result_types = [operand_list[0].type]
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def tensor(self):
    """The ``tensor`` operand."""
    operand_values = self.operation.operands
    return operand_values[0]

  @builtins.property
  def slicing_axes(self):
    """The mandatory ``slicing_axes`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["slicing_axes"]

  @slicing_axes.setter
  def slicing_axes(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["slicing_axes"] = value

  @builtins.property
  def out_sharding(self):
    """The mandatory ``out_sharding`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["out_sharding"] = value

  @builtins.property
  def result(self):
    """The first (and only) result of the op."""
    result_values = self.operation.results
    return result_values[0]

def all_slice(tensor, slicing_axes, out_sharding, *, loc=None, ip=None) -> _ods_ir.Value:
  """Create an ``sdy.all_slice`` op and return its result value."""
  op = AllSliceOp(
      tensor=tensor,
      slicing_axes=slicing_axes,
      out_sharding=out_sharding,
      loc=loc,
      ip=ip,
  )
  return op.result

@_ods_cext.register_operation(_Dialect)
class AllToAllOp(_ods_ir.OpView):
  """Python view of ``sdy.all_to_all``.

  One ``tensor`` operand, mandatory ``src_dim``, ``tgt_dim``, ``axes`` and
  ``out_sharding`` attributes, and a single ``result``.
  """

  OPERATION_NAME = "sdy.all_to_all"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, src_dim, tgt_dim, axes, out_sharding, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    Attribute arguments may be ``Attribute`` instances (used as-is) or plain
    values, which are converted by the registered attribute builders
    (``I64Attr`` for the two dims, ``Sdy_AxisRefList`` for ``axes``,
    ``Sdy_TensorSharding`` for ``out_sharding``) when such builders exist
    and are passed through unchanged otherwise.
    """
    operand_list = [tensor]
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {
        "src_dim": to_attr(src_dim, 'I64Attr'),
        "tgt_dim": to_attr(tgt_dim, 'I64Attr'),
        "axes": to_attr(axes, 'Sdy_AxisRefList'),
        "out_sharding": to_attr(out_sharding, 'Sdy_TensorSharding'),
    }
    # Result type mirrors the operand type. It is computed (so an operand
    # without ``.type`` fails here) but not forwarded to the base constructor;
    # NOTE(review): presumably the result type is inferred natively.
    _result_types = [operand_list[0].type]
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def tensor(self):
    """The ``tensor`` operand."""
    operand_values = self.operation.operands
    return operand_values[0]

  @builtins.property
  def src_dim(self):
    """The mandatory ``src_dim`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["src_dim"]

  @src_dim.setter
  def src_dim(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["src_dim"] = value

  @builtins.property
  def tgt_dim(self):
    """The mandatory ``tgt_dim`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["tgt_dim"]

  @tgt_dim.setter
  def tgt_dim(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["tgt_dim"] = value

  @builtins.property
  def axes(self):
    """The mandatory ``axes`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["axes"]

  @axes.setter
  def axes(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["axes"] = value

  @builtins.property
  def out_sharding(self):
    """The mandatory ``out_sharding`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["out_sharding"] = value

  @builtins.property
  def result(self):
    """The first (and only) result of the op."""
    result_values = self.operation.results
    return result_values[0]

def all_to_all(tensor, src_dim, tgt_dim, axes, out_sharding, *, loc=None, ip=None) -> _ods_ir.Value:
  """Create an ``sdy.all_to_all`` op and return its result value."""
  op = AllToAllOp(
      tensor=tensor,
      src_dim=src_dim,
      tgt_dim=tgt_dim,
      axes=axes,
      out_sharding=out_sharding,
      loc=loc,
      ip=ip,
  )
  return op.result

@_ods_cext.register_operation(_Dialect)
class CollectivePermuteOp(_ods_ir.OpView):
  """Python view of ``sdy.collective_permute``.

  One ``tensor`` operand, a mandatory ``out_sharding`` attribute, and a
  single ``result``.
  """

  OPERATION_NAME = "sdy.collective_permute"

  _ODS_REGIONS = (0, True)

  def __init__(self, tensor, out_sharding, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    ``out_sharding`` may be an ``Attribute`` (used as-is) or a plain value,
    converted by the registered ``Sdy_TensorSharding`` attribute builder
    when one exists and passed through unchanged otherwise.
    """
    operand_list = [tensor]
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {"out_sharding": to_attr(out_sharding, 'Sdy_TensorSharding')}
    # Result type mirrors the operand type. It is computed (so an operand
    # without ``.type`` fails here) but not forwarded to the base constructor;
    # NOTE(review): presumably the result type is inferred natively.
    _result_types = [operand_list[0].type]
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def tensor(self):
    """The ``tensor`` operand."""
    operand_values = self.operation.operands
    return operand_values[0]

  @builtins.property
  def out_sharding(self):
    """The mandatory ``out_sharding`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["out_sharding"]

  @out_sharding.setter
  def out_sharding(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["out_sharding"] = value

  @builtins.property
  def result(self):
    """The first (and only) result of the op."""
    result_values = self.operation.results
    return result_values[0]

def collective_permute(tensor, out_sharding, *, loc=None, ip=None) -> _ods_ir.Value:
  """Create an ``sdy.collective_permute`` op and return its result value."""
  op = CollectivePermuteOp(tensor=tensor, out_sharding=out_sharding, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ConstantOp(_ods_ir.OpView):
  """Python view of ``sdy.constant``.

  No operands, a mandatory ``value`` attribute, and a result exposed as
  ``output``.
  """

  OPERATION_NAME = "sdy.constant"

  _ODS_REGIONS = (0, True)

  def __init__(self, value, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    ``value`` may be an ``Attribute`` (used as-is) or a plain value,
    converted by the registered ``ElementsAttr`` attribute builder when one
    exists and passed through unchanged otherwise. No result types are
    supplied here (NOTE(review): presumably inferred natively from
    ``value``).
    """
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {"value": to_attr(value, 'ElementsAttr')}
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        operands=[],
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def value(self):
    """The mandatory ``value`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["value"]

  @value.setter
  def value(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["value"] = value

  @builtins.property
  def output(self):
    """The first result of the op."""
    result_values = self.operation.results
    return result_values[0]

def constant(value, *, loc=None, ip=None) -> _ods_ir.Value:
  """Create an ``sdy.constant`` op and return its result value."""
  # ``.result`` is the base-class shortcut, not the ``output`` accessor.
  op = ConstantOp(value=value, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class DataFlowEdgeOp(_ods_ir.OpView):
  """Python view of ``sdy.data_flow_edge``.

  One ``input`` operand, an optional ``sharding`` attribute, and a single
  ``result``.
  """

  OPERATION_NAME = "sdy.data_flow_edge"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, *, sharding=None, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    ``sharding`` is only attached when not ``None``. It may be an
    ``Attribute`` (used as-is) or a plain value, converted by the registered
    ``Sdy_TensorSharding`` attribute builder when one exists and passed
    through unchanged otherwise.
    """
    operand_list = [input]
    ctx = _ods_get_default_loc_context(loc)
    attrs = {}
    if sharding is not None:
      builder_name = 'Sdy_TensorSharding'
      needs_build = (not isinstance(sharding, _ods_ir.Attribute)
                     and _ods_ir.AttrBuilder.contains(builder_name))
      attrs["sharding"] = (
          _ods_ir.AttrBuilder.get(builder_name)(sharding, context=ctx)
          if needs_build else sharding)
    # Result type mirrors the operand type. It is computed (so an operand
    # without ``.type`` fails here) but not forwarded to the base constructor;
    # NOTE(review): presumably the result type is inferred natively.
    _result_types = [operand_list[0].type]
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def input(self):
    """The ``input`` operand."""
    operand_values = self.operation.operands
    return operand_values[0]

  @builtins.property
  def sharding(self):
    """The optional ``sharding`` attribute, or ``None`` when absent."""
    attr_map = self.operation.attributes
    return attr_map["sharding"] if "sharding" in attr_map else None

  @sharding.setter
  def sharding(self, value):
    # Assigning None removes the attribute (if present) instead of storing it.
    attr_map = self.operation.attributes
    if value is None:
      if "sharding" in attr_map:
        del attr_map["sharding"]
      return
    attr_map["sharding"] = value

  @sharding.deleter
  def sharding(self):
    # NOTE(review): unlike the setter, this does not check presence first.
    del self.operation.attributes["sharding"]

  @builtins.property
  def result(self):
    """The first (and only) result of the op."""
    result_values = self.operation.results
    return result_values[0]

def data_flow_edge(input, *, sharding=None, loc=None, ip=None) -> _ods_ir.Value:
  """Create an ``sdy.data_flow_edge`` op and return its result value."""
  op = DataFlowEdgeOp(input=input, sharding=sharding, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ManualComputationOp(_ods_ir.OpView):
  """Python view of ``sdy.manual_computation``.

  Variadic ``tensors`` operands, variadic results, mandatory
  ``in_shardings``, ``out_shardings`` and ``manual_axes`` attributes, and
  one region exposed as ``body``.
  """

  OPERATION_NAME = "sdy.manual_computation"

  _ODS_REGIONS = (1, True)

  def __init__(self, results_, tensors, in_shardings, out_shardings, manual_axes, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    ``results_`` lists the result types; ``tensors`` is normalised with
    ``_get_op_results_or_values``. Attribute arguments may be ``Attribute``
    instances (used as-is) or plain values, converted by the registered
    ``Sdy_TensorShardingPerValue`` / ``Sdy_ManualAxes`` attribute builders
    when such builders exist and passed through unchanged otherwise.
    """
    operand_list = list(_get_op_results_or_values(tensors))
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {
        "in_shardings": to_attr(in_shardings, 'Sdy_TensorShardingPerValue'),
        "out_shardings": to_attr(out_shardings, 'Sdy_TensorShardingPerValue'),
        "manual_axes": to_attr(manual_axes, 'Sdy_ManualAxes'),
    }
    result_types = list(results_)
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        results=result_types,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def tensors(self):
    """All operands (the single variadic ``tensors`` group)."""
    operand_values = self.operation.operands
    return operand_values[0:len(operand_values)]

  @builtins.property
  def in_shardings(self):
    """The mandatory ``in_shardings`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["in_shardings"]

  @in_shardings.setter
  def in_shardings(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["in_shardings"] = value

  @builtins.property
  def out_shardings(self):
    """The mandatory ``out_shardings`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["out_shardings"]

  @out_shardings.setter
  def out_shardings(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["out_shardings"] = value

  @builtins.property
  def manual_axes(self):
    """The mandatory ``manual_axes`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["manual_axes"]

  @manual_axes.setter
  def manual_axes(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["manual_axes"] = value

  @builtins.property
  def results_(self):
    """All results (the single variadic result group)."""
    result_values = self.operation.results
    return result_values[0:len(result_values)]

  @builtins.property
  def body(self):
    """The op's only region."""
    region_list = self.regions
    return region_list[0]

def manual_computation(results_, tensors, in_shardings, out_shardings, manual_axes, *, loc=None, ip=None) -> _Union[_ods_ir.Operation, _ods_ir.OpResult, _Sequence[_ods_ir.OpResult]]:
  """Create an ``sdy.manual_computation`` op and return what it produces.

  ``results_`` lists the result types, so the op may have zero, one or many
  results; the return value is whatever ``_get_op_result_or_op_results``
  yields for the new op. NOTE(review): that helper is expected to return the
  single result, the sequence of results, or the operation itself depending
  on the result count — confirm against ``_ods_common``. The annotation
  mirrors that helper's contract rather than claiming a single ``Value``.
  """
  op = ManualComputationOp(
      results_=results_,
      tensors=tensors,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      manual_axes=manual_axes,
      loc=loc,
      ip=ip,
  )
  return _get_op_result_or_op_results(op)

@_ods_cext.register_operation(_Dialect)
class MeshOp(_ods_ir.OpView):
  """Python view of ``sdy.mesh``.

  No operands and no results; mandatory ``sym_name`` and ``mesh``
  attributes.
  """

  OPERATION_NAME = "sdy.mesh"

  _ODS_REGIONS = (0, True)

  def __init__(self, sym_name, mesh, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    Attribute arguments may be ``Attribute`` instances (used as-is) or plain
    values, converted by the registered ``SymbolNameAttr`` / ``Sdy_Mesh``
    attribute builders when such builders exist and passed through
    unchanged otherwise.
    """
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {
        "sym_name": to_attr(sym_name, 'SymbolNameAttr'),
        "mesh": to_attr(mesh, 'Sdy_Mesh'),
    }
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        results=[],
        operands=[],
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def sym_name(self):
    """The mandatory ``sym_name`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["sym_name"]

  @sym_name.setter
  def sym_name(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["sym_name"] = value

  @builtins.property
  def mesh(self):
    """The mandatory ``mesh`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["mesh"]

  @mesh.setter
  def mesh(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["mesh"] = value

def mesh(sym_name, mesh, *, loc=None, ip=None) -> MeshOp:
  """Create an ``sdy.mesh`` op and return the op view itself.

  The op has no results, so the ``MeshOp`` instance is returned directly.
  It is annotated as ``MeshOp`` because that is what is returned (an
  ``OpView``, not an ``Operation``).
  """
  return MeshOp(sym_name=sym_name, mesh=mesh, loc=loc, ip=ip)

@_ods_cext.register_operation(_Dialect)
class NamedComputationOp(_ods_ir.OpView):
  """Python view of ``sdy.named_computation``.

  Variadic ``operands_``, variadic results, a mandatory ``name`` attribute,
  optional ``in_shardings`` / ``out_shardings`` attributes, and one region
  exposed as ``body``.
  """

  OPERATION_NAME = "sdy.named_computation"

  _ODS_REGIONS = (1, True)

  def __init__(self, result, name, operands_, *, in_shardings=None, out_shardings=None, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    ``result`` lists the result types; ``operands_`` is normalised with
    ``_get_op_results_or_values``. The optional shardings are only attached
    when not ``None``. Attribute arguments may be ``Attribute`` instances
    (used as-is) or plain values, converted by the registered ``StrAttr`` /
    ``Sdy_TensorShardingPerValue`` attribute builders when such builders
    exist and passed through unchanged otherwise.
    """
    operand_list = list(_get_op_results_or_values(operands_))
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {"name": to_attr(name, 'StrAttr')}
    if in_shardings is not None:
      attrs["in_shardings"] = to_attr(in_shardings, 'Sdy_TensorShardingPerValue')
    if out_shardings is not None:
      attrs["out_shardings"] = to_attr(out_shardings, 'Sdy_TensorShardingPerValue')
    result_types = list(result)
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        results=result_types,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def operands_(self):
    """All operands (the single variadic operand group)."""
    operand_values = self.operation.operands
    return operand_values[0:len(operand_values)]

  @builtins.property
  def name(self):
    """The mandatory ``name`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["name"]

  @name.setter
  def name(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["name"] = value

  @builtins.property
  def in_shardings(self):
    """The optional ``in_shardings`` attribute, or ``None`` when absent."""
    attr_map = self.operation.attributes
    return attr_map["in_shardings"] if "in_shardings" in attr_map else None

  @in_shardings.setter
  def in_shardings(self, value):
    # Assigning None removes the attribute (if present) instead of storing it.
    attr_map = self.operation.attributes
    if value is None:
      if "in_shardings" in attr_map:
        del attr_map["in_shardings"]
      return
    attr_map["in_shardings"] = value

  @in_shardings.deleter
  def in_shardings(self):
    # NOTE(review): unlike the setter, this does not check presence first.
    del self.operation.attributes["in_shardings"]

  @builtins.property
  def out_shardings(self):
    """The optional ``out_shardings`` attribute, or ``None`` when absent."""
    attr_map = self.operation.attributes
    return attr_map["out_shardings"] if "out_shardings" in attr_map else None

  @out_shardings.setter
  def out_shardings(self, value):
    # Assigning None removes the attribute (if present) instead of storing it.
    attr_map = self.operation.attributes
    if value is None:
      if "out_shardings" in attr_map:
        del attr_map["out_shardings"]
      return
    attr_map["out_shardings"] = value

  @out_shardings.deleter
  def out_shardings(self):
    # NOTE(review): unlike the setter, this does not check presence first.
    del self.operation.attributes["out_shardings"]

  @builtins.property
  def body(self):
    """The op's only region."""
    region_list = self.regions
    return region_list[0]

def named_computation(result, name, operands_, *, in_shardings=None, out_shardings=None, loc=None, ip=None) -> _Union[_ods_ir.Operation, _ods_ir.OpResult, _Sequence[_ods_ir.OpResult]]:
  """Create an ``sdy.named_computation`` op and return what it produces.

  ``result`` lists the result types, so the op may have zero, one or many
  results; the return value is whatever ``_get_op_result_or_op_results``
  yields for the new op. NOTE(review): that helper is expected to return the
  single result, the sequence of results, or the operation itself depending
  on the result count — confirm against ``_ods_common``. The annotation
  mirrors that helper's contract rather than claiming a single ``Value``.
  """
  op = NamedComputationOp(
      result=result,
      name=name,
      operands_=operands_,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      loc=loc,
      ip=ip,
  )
  return _get_op_result_or_op_results(op)

@_ods_cext.register_operation(_Dialect)
class PropagationBarrierOp(_ods_ir.OpView):
  """Python view of ``sdy.propagation_barrier``.

  One ``input`` operand, a mandatory ``allowed_direction`` attribute, and a
  single ``result``.
  """

  OPERATION_NAME = "sdy.propagation_barrier"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, allowed_direction, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    ``allowed_direction`` may be an ``Attribute`` (used as-is) or a plain
    value, converted by the registered ``Sdy_PropagationDirection``
    attribute builder when one exists and passed through unchanged
    otherwise.
    """
    operand_list = [input]
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {"allowed_direction": to_attr(allowed_direction, 'Sdy_PropagationDirection')}
    # Result type mirrors the operand type. It is computed (so an operand
    # without ``.type`` fails here) but not forwarded to the base constructor;
    # NOTE(review): presumably the result type is inferred natively.
    _result_types = [operand_list[0].type]
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def input(self):
    """The ``input`` operand."""
    operand_values = self.operation.operands
    return operand_values[0]

  @builtins.property
  def allowed_direction(self):
    """The mandatory ``allowed_direction`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["allowed_direction"]

  @allowed_direction.setter
  def allowed_direction(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["allowed_direction"] = value

  @builtins.property
  def result(self):
    """The first (and only) result of the op."""
    result_values = self.operation.results
    return result_values[0]

def propagation_barrier(input, allowed_direction, *, loc=None, ip=None) -> _ods_ir.Value:
  """Create an ``sdy.propagation_barrier`` op and return its result value."""
  op = PropagationBarrierOp(input=input, allowed_direction=allowed_direction, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ReshardOp(_ods_ir.OpView):
  """Python view of ``sdy.reshard``.

  One ``input`` operand, a mandatory ``sharding`` attribute, and a single
  ``result``.
  """

  OPERATION_NAME = "sdy.reshard"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, sharding, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    ``sharding`` may be an ``Attribute`` (used as-is) or a plain value,
    converted by the registered ``Sdy_TensorSharding`` attribute builder
    when one exists and passed through unchanged otherwise.
    """
    operand_list = [input]
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {"sharding": to_attr(sharding, 'Sdy_TensorSharding')}
    # Result type mirrors the operand type. It is computed (so an operand
    # without ``.type`` fails here) but not forwarded to the base constructor;
    # NOTE(review): presumably the result type is inferred natively.
    _result_types = [operand_list[0].type]
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def input(self):
    """The ``input`` operand."""
    operand_values = self.operation.operands
    return operand_values[0]

  @builtins.property
  def sharding(self):
    """The mandatory ``sharding`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["sharding"]

  @sharding.setter
  def sharding(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["sharding"] = value

  @builtins.property
  def result(self):
    """The first (and only) result of the op."""
    result_values = self.operation.results
    return result_values[0]

def reshard(input, sharding, *, loc=None, ip=None) -> _ods_ir.Value:
  """Create an ``sdy.reshard`` op and return its result value."""
  op = ReshardOp(input=input, sharding=sharding, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ReturnOp(_ods_ir.OpView):
  """Python view of ``sdy.return``.

  Variadic operands (exposed as ``results_``), no attributes, no results.
  """

  OPERATION_NAME = "sdy.return"

  _ODS_REGIONS = (0, True)

  def __init__(self, results_, *, loc=None, ip=None):
    """Build the op; ``results_`` is normalised with ``_get_op_results_or_values``."""
    operand_list = list(_get_op_results_or_values(results_))
    # The context is resolved for parity with the other builders; the value
    # itself is unused because this op has no attributes to build.
    _ods_get_default_loc_context(loc)
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes={},
        results=[],
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def results_(self):
    """All operands (the single variadic operand group)."""
    operand_values = self.operation.operands
    return operand_values[0:len(operand_values)]

def return_(results_, *, loc=None, ip=None) -> ReturnOp:
  """Create an ``sdy.return`` op and return the op view itself.

  The op has no results, so the ``ReturnOp`` instance is returned directly.
  It is annotated as ``ReturnOp`` because that is what is returned (an
  ``OpView``, not an ``Operation``).
  """
  return ReturnOp(results_=results_, loc=loc, ip=ip)

@_ods_cext.register_operation(_Dialect)
class ShardingConstraintOp(_ods_ir.OpView):
  """Python view of ``sdy.sharding_constraint``.

  One ``input`` operand, a mandatory ``sharding`` attribute, and a single
  ``result``.
  """

  OPERATION_NAME = "sdy.sharding_constraint"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, sharding, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    ``sharding`` may be an ``Attribute`` (used as-is) or a plain value,
    converted by the registered ``Sdy_TensorSharding`` attribute builder
    when one exists and passed through unchanged otherwise.
    """
    operand_list = [input]
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {"sharding": to_attr(sharding, 'Sdy_TensorSharding')}
    # Result type mirrors the operand type. It is computed (so an operand
    # without ``.type`` fails here) but not forwarded to the base constructor;
    # NOTE(review): presumably the result type is inferred natively.
    _result_types = [operand_list[0].type]
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def input(self):
    """The ``input`` operand."""
    operand_values = self.operation.operands
    return operand_values[0]

  @builtins.property
  def sharding(self):
    """The mandatory ``sharding`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["sharding"]

  @sharding.setter
  def sharding(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["sharding"] = value

  @builtins.property
  def result(self):
    """The first (and only) result of the op."""
    result_values = self.operation.results
    return result_values[0]

def sharding_constraint(input, sharding, *, loc=None, ip=None) -> _ods_ir.Value:
  """Create an ``sdy.sharding_constraint`` op and return its result value."""
  op = ShardingConstraintOp(input=input, sharding=sharding, loc=loc, ip=ip)
  return op.result

@_ods_cext.register_operation(_Dialect)
class ShardingGroupOp(_ods_ir.OpView):
  """Python view of ``sdy.sharding_group``.

  One ``input`` operand and a mandatory ``group_id`` attribute; no result
  accessor is generated and no result types are supplied by the builder.
  """

  OPERATION_NAME = "sdy.sharding_group"

  _ODS_REGIONS = (0, True)

  def __init__(self, input, group_id, *, loc=None, ip=None):
    """Build the op in the context derived from ``loc``.

    ``group_id`` may be an ``Attribute`` (used as-is) or a plain value,
    converted by the registered ``I64Attr`` attribute builder when one
    exists and passed through unchanged otherwise.
    """
    operand_list = [input]
    ctx = _ods_get_default_loc_context(loc)

    def to_attr(raw, builder_name):
      # Attributes pass through; other values use the named AttrBuilder if registered.
      if isinstance(raw, _ods_ir.Attribute):
        return raw
      if not _ods_ir.AttrBuilder.contains(builder_name):
        return raw
      return _ods_ir.AttrBuilder.get(builder_name)(raw, context=ctx)

    attrs = {"group_id": to_attr(group_id, 'I64Attr')}
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes=attrs,
        operands=operand_list,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def input(self):
    """The ``input`` operand."""
    operand_values = self.operation.operands
    return operand_values[0]

  @builtins.property
  def group_id(self):
    """The mandatory ``group_id`` attribute."""
    attr_map = self.operation.attributes
    return attr_map["group_id"]

  @group_id.setter
  def group_id(self, value):
    # Mandatory attribute: it may be replaced but never cleared.
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    attr_map = self.operation.attributes
    attr_map["group_id"] = value

def sharding_group(input, group_id, *, loc=None, ip=None) -> ShardingGroupOp:
  """Create an ``sdy.sharding_group`` op and return the op view itself.

  The ``ShardingGroupOp`` instance is returned directly. It is annotated as
  ``ShardingGroupOp`` because that is what is returned (an ``OpView``, not an
  ``Operation``).
  """
  return ShardingGroupOp(input=input, group_id=group_id, loc=loc, ip=ip)
