#!/bin/bash

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause

# Granularity at which GPUs are handed out: "device" unless changed later on.
assignment="device"

#######################################
# Determine this process's node-local rank from the launcher that spawned it.
# Globals:
#   PPID, SLURM_LOCALID, MPI_LOCALRANKID, PALS_LOCAL_RANKID (read)
#   parent_process (written; command name of the parent process)
#   local_id (written only when the parent is a recognised launcher)
# Outputs:
#   A warning on stderr when the parent is not a recognised launcher.
#######################################
detect_local_id() {
    # read -r strips surrounding whitespace, so the quoted match below is not
    # sensitive to padding in the ps output and an empty result stays harmless.
    read -r parent_process <<< "$(ps -p "$PPID" -o comm=)"
    case "$parent_process" in
        slurmstepd)
            local_id=$SLURM_LOCALID
            ;;
        hydra_pmi_proxy)
            local_id=$MPI_LOCALRANKID
            ;;
        palsd)
            local_id=$PALS_LOCAL_RANKID
            ;;
        *)
            echo -e "\033[33mWARNING: Process not launched with a supported job launcher.\033[0m" >&2
            ;;
    esac
}
detect_local_id

# Possible device hierarchy options are described at
# https://spec.oneapi.io/level-zero/latest/core/PROG.html#device-hierarchy
# An unset/empty or COMPOSITE hierarchy switches the assignment to tiles.
case "${ZE_FLAT_DEVICE_HIERARCHY}" in
    "" | "COMPOSITE")
        assignment="tile"
        echo -e "\033[33mWARNING: Assigning a single tile per PE due to ZE_FLAT_DEVICE_HIERARCHY=${ZE_FLAT_DEVICE_HIERARCHY}.\033[0m" >&2
        ;;
esac

#Drop any device-selection variables the user set beforehand, announcing each
#removal on stderr. ZE_AFFINITY_MASK and SYCL_DEVICE_FILTER are cleared so they
#do not impact the results returned from sycl-ls; ONEAPI_DEVICE_SELECTOR is
#cleared as well.
for preset_var in ZE_AFFINITY_MASK SYCL_DEVICE_FILTER ONEAPI_DEVICE_SELECTOR; do
    # Indirect expansion: the value of the variable whose name is in preset_var
    preset_value=${!preset_var}
    if [ -n "$preset_value" ]; then
        echo -e "\033[33mWARNING: Previous assignment of ${preset_var}=${preset_value} being unset.\033[0m" >&2
        unset "$preset_var"
    fi
done
unset preset_var preset_value

#Determine which utilities are available for device/sub-device detection.
#The `command -v` builtin is used instead of the external `which` program:
#`which` is not guaranteed to be installed, and when it is missing every
#lookup fails, making both utilities appear unavailable.
#An exit code of 0 means the utility was found.
command -v clinfo > /dev/null 2>&1
which_clinfo_exit_code=$?
command -v sycl-ls > /dev/null 2>&1
which_sycl_ls_exit_code=$?

#######################################
# Identify the GPU root devices and, for tile assignment, their sub-devices.
# sycl-ls is tried first; clinfo is used only when sycl-ls is unavailable.
# Globals:
#   which_sycl_ls_exit_code, which_clinfo_exit_code (read; 0 = tool found,
#       an unset value is treated as "not found")
#   assignment (read; sub-devices are enumerated only when it is "tile")
#   root_devices (written; indices of the detected root devices)
#   all_sub_devices (written; "<root>.<sub>" entries, or "<root>" alone for a
#       root device that reports no sub-devices)
# Outputs:
#   A warning on stderr when neither utility is available.
#######################################
detect_gpu_devices() {
    local listing sub_device_field
    local num_root_devices num_sub_devices root_device sub_device

    root_devices=()
    all_sub_devices=()

    if [ "${which_sycl_ls_exit_code:-1}" -eq 0 ]; then
        listing=$(ONEAPI_DEVICE_SELECTOR='level_zero:*' sycl-ls 2>/dev/null)
        # Count the "[level_zero:gpu]" tags; root devices are then numbered 0..N-1.
        num_root_devices=$(grep -F -o "[level_zero:gpu]" <<< "${listing}" | wc -l)
        if (( num_root_devices > 0 )); then
            for (( root_device = 0; root_device < num_root_devices; root_device++ )); do
                root_devices+=("${root_device}")
            done
        else
            # No such tag in the listing (presumably a different sycl-ls output
            # format): take the numeric indices that follow "gpu:" instead.
            # Word splitting is intentional here: one array element per index.
            # shellcheck disable=SC2207
            root_devices=($(grep -P -o "(?<=gpu:)[0-9]+" <<< "${listing}"))
        fi
        if [[ "${assignment}" == "tile" ]]; then
            for root_device in "${root_devices[@]}"; do
                listing=$(ONEAPI_DEVICE_SELECTOR="level_zero:${root_device}.*" sycl-ls 2>/dev/null)
                num_sub_devices=$(grep -F -o "level_zero:gpu" <<< "${listing}" | wc -l)
                if (( num_sub_devices > 0 )); then
                    for (( sub_device = 0; sub_device < num_sub_devices; sub_device++ )); do
                        all_sub_devices+=("${root_device}.${sub_device}")
                    done
                else
                    #Root device is not partitionable / has no sub-devices
                    all_sub_devices+=("${root_device}")
                fi
            done
        fi
    elif [ "${which_clinfo_exit_code:-1}" -eq 0 ]; then
        num_root_devices=$(clinfo | grep -P -o "(?<=Device Type)[ \t]*GPU" | wc -l)
        for (( root_device = 0; root_device < num_root_devices; root_device++ )); do
            root_devices+=("${root_device}")
        done
        if [[ "${assignment}" == "tile" ]]; then
            for root_device in "${root_devices[@]}"; do
                sub_device_field=$(ZE_AFFINITY_MASK=${root_device} clinfo \
                    | grep -P -A 10 "(?<=Device Type)[ \t]*GPU" \
                    | grep -P -o "(?<=Max number of sub-devices)[ \t]*[0-9]+")
                # Keep only the first reported number and fall back to 0 when
                # the field is missing, so the numeric tests below never see
                # an empty or multi-word value.
                read -r num_sub_devices _ <<< "${sub_device_field}"
                if [[ "${num_sub_devices}" =~ ^[0-9]+$ ]]; then
                    # Force base 10 so a leading zero is not read as octal.
                    num_sub_devices=$(( 10#${num_sub_devices} ))
                else
                    num_sub_devices=0
                fi
                if (( num_sub_devices == 0 )); then
                    #Root device is not partitionable / has no sub-devices
                    all_sub_devices+=("${root_device}")
                else
                    for (( sub_device = 0; sub_device < num_sub_devices; sub_device++ )); do
                        all_sub_devices+=("${root_device}.${sub_device}")
                    done
                fi
            done
        fi
    else
        echo -e "\033[33mWARNING: Could not detect devices on system. Ensure either clinfo or sycl-ls is available.\033[0m" >&2
    fi
}
detect_gpu_devices

#Set visibility of devices for Level Zero: pick the entry indexed by this
#process's local rank, from the sub-device list for "tile" assignment and
#from the root-device list otherwise.
#NOTE(review): when local_id has no matching entry the mask is exported empty;
#confirm that this is the intended behaviour.
if [ "$assignment" != "tile" ]; then
    ZE_AFFINITY_MASK=${root_devices[${local_id}]}
else
    ZE_AFFINITY_MASK=${all_sub_devices[${local_id}]}
fi
export ZE_AFFINITY_MASK

#Set visibility of devices for SYCL
#From SYCL's perspective, device 0 (level_zero:0) will now correspond to the device previously assigned to ZE_AFFINITY_MASK
ONEAPI_DEVICE_SELECTOR="level_zero:0"
export ONEAPI_DEVICE_SELECTOR

#######################################
# Run the given command line under `numactl --cpunodebind=all`, with
# NEOReadDebugKeys=1 and UseKmdMigration=1 set in its environment.
# Arguments:
#   The program to run followed by its arguments, passed through unchanged.
# Returns:
#   The exit status of numactl.
#######################################
launch_program() {
    NEOReadDebugKeys=1 UseKmdMigration=1 numactl --cpunodebind=all "$@"
}

# Invoke the main program. "$@" (rather than an unquoted $*) keeps every
# argument intact, including empty ones and ones containing whitespace or
# glob characters.
launch_program "$@"
