
# Autogenerated by mlir-tblgen; don't manually edit.

from jaxlib.mlir.dialects._ods_common import _cext as _ods_cext
from jaxlib.mlir.dialects._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_result_or_op_results as _get_op_result_or_op_results,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union


@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  """Dialect class registered under the `mosaic_gpu` namespace.

  The op classes in this module pass this class to
  `_ods_cext.register_operation` to attach themselves to the dialect.
  """

  DIALECT_NAMESPACE = "mosaic_gpu"

@_ods_cext.register_operation(_Dialect)
class ArriveExpectTxOp(_ods_ir.OpView):
  """Python view of the `mosaic_gpu.arrive_expect_tx` operation.

  The op is built from a single operand (`barrier`) and a single mandatory
  attribute (`expect_tx`); an empty result list is passed at construction.
  """

  OPERATION_NAME = "mosaic_gpu.arrive_expect_tx"

  # Region spec handed unchanged to the `OpView` constructor.
  _ODS_REGIONS = (0, True)

  def __init__(self, barrier, expect_tx, *, loc=None, ip=None):
    """Builds the operation.

    Args:
      barrier: Value used as the op's only operand.
      expect_tx: Either a ready-made `Attribute` (used as is) or a raw value
        that is converted with the registered `I32Attr` builder. If no such
        builder is registered the raw value is stored unchanged.
      loc: Forwarded to `_ods_get_default_loc_context` and to the base
        constructor.
      ip: Forwarded to the base constructor.
    """
    operands = [barrier]
    _ods_context = _ods_get_default_loc_context(loc)
    # Only raw (non-Attribute) values go through the builder, and only when a
    # builder with that name has been registered.
    if (not isinstance(expect_tx, _ods_ir.Attribute)
        and _ods_ir.AttrBuilder.contains("I32Attr")):
      make_attr = _ods_ir.AttrBuilder.get("I32Attr")
      expect_tx = make_attr(expect_tx, context=_ods_context)
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes={"expect_tx": expect_tx},
        results=[],
        operands=operands,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def barrier(self):
    """Operand 0 of the underlying operation."""
    all_operands = self.operation.operands
    return all_operands[0]

  @builtins.property
  def expect_tx(self):
    """The `expect_tx` attribute of the underlying operation."""
    op_attributes = self.operation.attributes
    return op_attributes["expect_tx"]

  @expect_tx.setter
  def expect_tx(self, value):
    """Replaces the `expect_tx` attribute.

    Raises:
      ValueError: If `value` is None; the attribute is mandatory.
    """
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["expect_tx"] = value

def arrive_expect_tx(barrier, expect_tx, *, loc=None, ip=None) -> ArriveExpectTxOp:
  """Creates an `ArriveExpectTxOp` and returns the op view itself.

  All arguments are forwarded by keyword to the `ArriveExpectTxOp`
  constructor. The return annotation names the op view class because that is
  what is actually returned (the previous `_ods_ir.Operation` annotation did
  not match the returned object).
  """
  return ArriveExpectTxOp(barrier=barrier, expect_tx=expect_tx, loc=loc, ip=ip)

@_ods_cext.register_operation(_Dialect)
class AsyncLoadOp(_ods_ir.OpView):
  """Python view of the `mosaic_gpu.async_load` operation.

  Operands are grouped into five segments, in the order `source`,
  `destination`, `barrier`, `indices`, `predicate`. The op also carries the
  mandatory attributes `slice_lengths` and `collective`, and is constructed
  with an empty result list.
  """

  OPERATION_NAME = "mosaic_gpu.async_load"

  # One entry per operand group, in the order the groups are listed in
  # `__init__`; handed unchanged to the `OpView` constructor.
  _ODS_OPERAND_SEGMENTS = [1, 1, 1, -1, 0]

  # Region spec handed unchanged to the `OpView` constructor.
  _ODS_REGIONS = (0, True)

  def __init__(self, source, destination, barrier, indices, slice_lengths, collective, *, predicate=None, loc=None, ip=None):
    """Builds the operation.

    Args:
      source: Operand group 0.
      destination: Operand group 1.
      barrier: Operand group 2.
      indices: Operand group 3; normalised with `_get_op_results_or_values`.
      slice_lengths: An `Attribute`, or a raw value converted with the
        registered `DenseI64ArrayAttr` builder when one exists.
      collective: An `Attribute`, or a raw value converted with the
        registered `anonymous_652` builder when one exists.
      predicate: Operand group 4; may be left as None.
      loc: Forwarded to `_ods_get_default_loc_context` and to the base
        constructor.
      ip: Forwarded to the base constructor.
    """
    # Operand groups are assembled before the context lookup, in segment order.
    operands = [
        source,
        destination,
        barrier,
        _get_op_results_or_values(indices),
        predicate,
    ]
    _ods_context = _ods_get_default_loc_context(loc)

    # Raw values are converted only when a builder of that name is registered;
    # otherwise they are stored as given.
    if (not isinstance(slice_lengths, _ods_ir.Attribute)
        and _ods_ir.AttrBuilder.contains("DenseI64ArrayAttr")):
      make_lengths = _ods_ir.AttrBuilder.get("DenseI64ArrayAttr")
      slice_lengths = make_lengths(slice_lengths, context=_ods_context)
    if (not isinstance(collective, _ods_ir.Attribute)
        and _ods_ir.AttrBuilder.contains("anonymous_652")):
      make_collective = _ods_ir.AttrBuilder.get("anonymous_652")
      collective = make_collective(collective, context=_ods_context)

    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes={"slice_lengths": slice_lengths, "collective": collective},
        results=[],
        operands=operands,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def source(self):
    """First element of operand segment 0."""
    segment = _ods_segmented_accessor(
        self.operation.operands,
        self.operation.attributes["operandSegmentSizes"],
        0)
    return segment[0]

  @builtins.property
  def destination(self):
    """First element of operand segment 1."""
    segment = _ods_segmented_accessor(
        self.operation.operands,
        self.operation.attributes["operandSegmentSizes"],
        1)
    return segment[0]

  @builtins.property
  def barrier(self):
    """First element of operand segment 2."""
    segment = _ods_segmented_accessor(
        self.operation.operands,
        self.operation.attributes["operandSegmentSizes"],
        2)
    return segment[0]

  @builtins.property
  def indices(self):
    """The whole of operand segment 3."""
    return _ods_segmented_accessor(
        self.operation.operands,
        self.operation.attributes["operandSegmentSizes"],
        3)

  @builtins.property
  def predicate(self):
    """First element of operand segment 4, or None when it is empty."""
    segment = _ods_segmented_accessor(
        self.operation.operands,
        self.operation.attributes["operandSegmentSizes"],
        4)
    if len(segment) > 0:
      return segment[0]
    return None

  @builtins.property
  def slice_lengths(self):
    """The `slice_lengths` attribute of the underlying operation."""
    op_attributes = self.operation.attributes
    return op_attributes["slice_lengths"]

  @slice_lengths.setter
  def slice_lengths(self, value):
    """Replaces the `slice_lengths` attribute.

    Raises:
      ValueError: If `value` is None; the attribute is mandatory.
    """
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["slice_lengths"] = value

  @builtins.property
  def collective(self):
    """The `collective` attribute of the underlying operation."""
    op_attributes = self.operation.attributes
    return op_attributes["collective"]

  @collective.setter
  def collective(self, value):
    """Replaces the `collective` attribute.

    Raises:
      ValueError: If `value` is None; the attribute is mandatory.
    """
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["collective"] = value

def async_load(source, destination, barrier, indices, slice_lengths, collective, *, predicate=None, loc=None, ip=None) -> AsyncLoadOp:
  """Creates an `AsyncLoadOp` and returns the op view itself.

  All arguments are forwarded by keyword to the `AsyncLoadOp` constructor.
  The return annotation names the op view class because that is what is
  actually returned (the previous `_ods_ir.Operation` annotation did not
  match the returned object).
  """
  return AsyncLoadOp(
      source=source,
      destination=destination,
      barrier=barrier,
      indices=indices,
      slice_lengths=slice_lengths,
      collective=collective,
      predicate=predicate,
      loc=loc,
      ip=ip,
  )

@_ods_cext.register_operation(_Dialect)
class AsyncStoreOp(_ods_ir.OpView):
  """Python view of the `mosaic_gpu.async_store` operation.

  Operands are grouped into four segments, in the order `source`,
  `destination`, `indices`, `predicate`. The op also carries the mandatory
  attribute `slice_lengths`, and is constructed with an empty result list.
  """

  OPERATION_NAME = "mosaic_gpu.async_store"

  # One entry per operand group, in the order the groups are listed in
  # `__init__`; handed unchanged to the `OpView` constructor.
  _ODS_OPERAND_SEGMENTS = [1, 1, -1, 0]

  # Region spec handed unchanged to the `OpView` constructor.
  _ODS_REGIONS = (0, True)

  def __init__(self, source, destination, indices, slice_lengths, *, predicate=None, loc=None, ip=None):
    """Builds the operation.

    Args:
      source: Operand group 0.
      destination: Operand group 1.
      indices: Operand group 2; normalised with `_get_op_results_or_values`.
      slice_lengths: An `Attribute`, or a raw value converted with the
        registered `DenseI64ArrayAttr` builder when one exists.
      predicate: Operand group 3; may be left as None.
      loc: Forwarded to `_ods_get_default_loc_context` and to the base
        constructor.
      ip: Forwarded to the base constructor.
    """
    # Operand groups are assembled before the context lookup, in segment order.
    operands = [
        source,
        destination,
        _get_op_results_or_values(indices),
        predicate,
    ]
    _ods_context = _ods_get_default_loc_context(loc)

    # A raw value is converted only when a builder of that name is registered;
    # otherwise it is stored as given.
    if (not isinstance(slice_lengths, _ods_ir.Attribute)
        and _ods_ir.AttrBuilder.contains("DenseI64ArrayAttr")):
      make_lengths = _ods_ir.AttrBuilder.get("DenseI64ArrayAttr")
      slice_lengths = make_lengths(slice_lengths, context=_ods_context)

    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes={"slice_lengths": slice_lengths},
        results=[],
        operands=operands,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def source(self):
    """First element of operand segment 0."""
    segment = _ods_segmented_accessor(
        self.operation.operands,
        self.operation.attributes["operandSegmentSizes"],
        0)
    return segment[0]

  @builtins.property
  def destination(self):
    """First element of operand segment 1."""
    segment = _ods_segmented_accessor(
        self.operation.operands,
        self.operation.attributes["operandSegmentSizes"],
        1)
    return segment[0]

  @builtins.property
  def indices(self):
    """The whole of operand segment 2."""
    return _ods_segmented_accessor(
        self.operation.operands,
        self.operation.attributes["operandSegmentSizes"],
        2)

  @builtins.property
  def predicate(self):
    """First element of operand segment 3, or None when it is empty."""
    segment = _ods_segmented_accessor(
        self.operation.operands,
        self.operation.attributes["operandSegmentSizes"],
        3)
    if len(segment) > 0:
      return segment[0]
    return None

  @builtins.property
  def slice_lengths(self):
    """The `slice_lengths` attribute of the underlying operation."""
    op_attributes = self.operation.attributes
    return op_attributes["slice_lengths"]

  @slice_lengths.setter
  def slice_lengths(self, value):
    """Replaces the `slice_lengths` attribute.

    Raises:
      ValueError: If `value` is None; the attribute is mandatory.
    """
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["slice_lengths"] = value

def async_store(source, destination, indices, slice_lengths, *, predicate=None, loc=None, ip=None) -> AsyncStoreOp:
  """Creates an `AsyncStoreOp` and returns the op view itself.

  All arguments are forwarded by keyword to the `AsyncStoreOp` constructor.
  The return annotation names the op view class because that is what is
  actually returned (the previous `_ods_ir.Operation` annotation did not
  match the returned object).
  """
  return AsyncStoreOp(
      source=source,
      destination=destination,
      indices=indices,
      slice_lengths=slice_lengths,
      predicate=predicate,
      loc=loc,
      ip=ip,
  )

@_ods_cext.register_operation(_Dialect)
class InitializeBarrierOp(_ods_ir.OpView):
  """Python view of the `mosaic_gpu.initialize_barrier` operation.

  The op is built from one operand (`base_pointer`), one mandatory attribute
  (`arrival_count`) and one entry in the result list (`barriers_ref`).
  """

  OPERATION_NAME = "mosaic_gpu.initialize_barrier"

  # Region spec handed unchanged to the `OpView` constructor.
  _ODS_REGIONS = (0, True)

  def __init__(self, barriers_ref, base_pointer, arrival_count, *, loc=None, ip=None):
    """Builds the operation.

    Args:
      barriers_ref: Sole entry of the `results` list given to the base
        constructor (presumably the result type -- confirm against the
        dialect definition).
      base_pointer: Value used as the op's only operand.
      arrival_count: An `Attribute`, or a raw value converted with the
        registered `I64Attr` builder when one exists.
      loc: Forwarded to `_ods_get_default_loc_context` and to the base
        constructor.
      ip: Forwarded to the base constructor.
    """
    operands = [base_pointer]
    _ods_context = _ods_get_default_loc_context(loc)
    # Only raw (non-Attribute) values go through the builder, and only when a
    # builder with that name has been registered.
    if (not isinstance(arrival_count, _ods_ir.Attribute)
        and _ods_ir.AttrBuilder.contains("I64Attr")):
      make_attr = _ods_ir.AttrBuilder.get("I64Attr")
      arrival_count = make_attr(arrival_count, context=_ods_context)
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes={"arrival_count": arrival_count},
        results=[barriers_ref],
        operands=operands,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def base_pointer(self):
    """Operand 0 of the underlying operation."""
    all_operands = self.operation.operands
    return all_operands[0]

  @builtins.property
  def arrival_count(self):
    """The `arrival_count` attribute of the underlying operation."""
    op_attributes = self.operation.attributes
    return op_attributes["arrival_count"]

  @arrival_count.setter
  def arrival_count(self, value):
    """Replaces the `arrival_count` attribute.

    Raises:
      ValueError: If `value` is None; the attribute is mandatory.
    """
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["arrival_count"] = value

  @builtins.property
  def barriers_ref(self):
    """Result 0 of the underlying operation."""
    all_results = self.operation.results
    return all_results[0]

def initialize_barrier(barriers_ref, base_pointer, arrival_count, *, loc=None, ip=None) -> _ods_ir.Value:
  """Creates an `InitializeBarrierOp` and returns its `result`.

  All arguments are forwarded by keyword to the `InitializeBarrierOp`
  constructor.
  """
  op = InitializeBarrierOp(
      barriers_ref=barriers_ref,
      base_pointer=base_pointer,
      arrival_count=arrival_count,
      loc=loc,
      ip=ip,
  )
  return op.result

@_ods_cext.register_operation(_Dialect)
class WGMMAOp(_ods_ir.OpView):
  """Python view of the `mosaic_gpu.wgmma` operation.

  The op is built from three operands, in the order `accumulator`, `a`, `b`,
  and carries no attributes.
  """

  OPERATION_NAME = "mosaic_gpu.wgmma"

  # Region spec handed unchanged to the `OpView` constructor.
  _ODS_REGIONS = (0, True)

  def __init__(self, accumulator, a, b, *, loc=None, ip=None):
    """Builds the operation.

    Args:
      accumulator: Operand 0.
      a: Operand 1.
      b: Operand 2.
      loc: Forwarded to `_ods_get_default_loc_context` and to the base
        constructor.
      ip: Forwarded to the base constructor.
    """
    operands = [accumulator, a, b]
    # The returned context is not needed here; the call itself is retained so
    # that anything it raises still surfaces at this point.
    _ods_get_default_loc_context(loc)
    # NOTE(review): no `results=` argument is forwarded, so the base
    # constructor's default applies. Presumably the result type is inferred
    # there -- confirm before changing; passing an empty list would not be
    # equivalent.
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes={},
        operands=operands,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def accumulator(self):
    """Operand 0 of the underlying operation."""
    all_operands = self.operation.operands
    return all_operands[0]

  @builtins.property
  def a(self):
    """Operand 1 of the underlying operation."""
    all_operands = self.operation.operands
    return all_operands[1]

  @builtins.property
  def b(self):
    """Operand 2 of the underlying operation."""
    all_operands = self.operation.operands
    return all_operands[2]

def wgmma(accumulator, a, b, *, loc=None, ip=None) -> _ods_ir.Value:
  """Creates a `WGMMAOp` and returns its `result`.

  All arguments are forwarded by keyword to the `WGMMAOp` constructor.
  """
  op = WGMMAOp(
      accumulator=accumulator,
      a=a,
      b=b,
      loc=loc,
      ip=ip,
  )
  return op.result

@_ods_cext.register_operation(_Dialect)
class WaitOp(_ods_ir.OpView):
  """Python view of the `mosaic_gpu.wait` operation.

  The op is built from two operands, in the order `barrier`, `parity`. It
  carries no attributes and is constructed with an empty result list.
  """

  OPERATION_NAME = "mosaic_gpu.wait"

  # Region spec handed unchanged to the `OpView` constructor.
  _ODS_REGIONS = (0, True)

  def __init__(self, barrier, parity, *, loc=None, ip=None):
    """Builds the operation.

    Args:
      barrier: Operand 0.
      parity: Operand 1.
      loc: Forwarded to `_ods_get_default_loc_context` and to the base
        constructor.
      ip: Forwarded to the base constructor.
    """
    operands = [barrier, parity]
    # The returned context is not needed here; the call itself is retained so
    # that anything it raises still surfaces at this point.
    _ods_get_default_loc_context(loc)
    super().__init__(
        self.OPERATION_NAME,
        self._ODS_REGIONS,
        self._ODS_OPERAND_SEGMENTS,
        self._ODS_RESULT_SEGMENTS,
        attributes={},
        results=[],
        operands=operands,
        successors=None,
        regions=None,
        loc=loc,
        ip=ip,
    )

  @builtins.property
  def barrier(self):
    """Operand 0 of the underlying operation."""
    all_operands = self.operation.operands
    return all_operands[0]

  @builtins.property
  def parity(self):
    """Operand 1 of the underlying operation."""
    all_operands = self.operation.operands
    return all_operands[1]

def wait(barrier, parity, *, loc=None, ip=None) -> WaitOp:
  """Creates a `WaitOp` and returns the op view itself.

  All arguments are forwarded by keyword to the `WaitOp` constructor. The
  return annotation names the op view class because that is what is actually
  returned (the previous `_ods_ir.Operation` annotation did not match the
  returned object).
  """
  return WaitOp(barrier=barrier, parity=parity, loc=loc, ip=ip)
