// Copyright © 2023-2025 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: MIT

// clang-format off
#include "shim.debug_simulate_encoded_softmax.h"
#include <aotriton/util.h>
#include <tuple>
#include <iostream>
#include "iface.op_attn_fwd.h"

namespace AOTRITON_NS::v3::flash {

#if 1
using AOTRITON_NS::v3::flash::OpAttnFwdParams;
#endif

// Turn the address of a (possibly const) kernel argument into the plain
// void* stored in the packed argument vector.
#define CAST(x) const_cast<void*>(static_cast<const void*>(x))

// Signature of an argument-packing routine: builds the untyped kernel
// argument list from the operator parameters plus auxiliary scratch data.
using PP_FUNC = std::vector<void*> (*)(const OpAttnFwdParams& context,
                                       const TritonAuxiliaryArguments&);

namespace {
// Declared here so launch() can index it; the definition appears after the
// packing functions it refers to.
extern PP_FUNC prepare_arguments[ 1 ];
}

int64_t DebugSimulateEncodedSoftmaxContext::godel_number() const
{
    // Map the functional parameters of this call onto a single column index
    // of autotune_table (see lookup_optimal). The only varying parameter is
    // the dtype of encoded_softmax: fp16 -> 0, bf16 -> 1, fp32 -> 2.
    // Returns -1 when no compiled variant matches.
    int64_t sum = 0;
    const auto& args = *params;
    // encoded_softmax is accessed through `->`; without it there is no dtype
    // to dispatch on, so report "unsupported" instead of dereferencing null.
    if (!args.encoded_softmax) {
#ifndef NDEBUG
        std::cerr << __FILE__ << ":" << __LINE__ << ": Unsupported encoded_softmax, value: nullptr" << '\n';
#endif
        return -1;
    }
    {
        // Query the dtype once and reuse it for every comparison.
        const auto dtype = args.encoded_softmax->dtype();
        int64_t number = -1;
        if (dtype == DType::kFloat16) number = 0;
        else if (dtype == DType::kBFloat16) number = 1;
        else if (dtype == DType::kFloat32) number = 2;
        if (number < 0) {
#ifndef NDEBUG
            std::cerr << __FILE__ << ":" << __LINE__ << ": Unsupported encoded_softmax, value: " << dtype << '\n';
#endif
            return -1;
        }
        sum += number * 1;  // encoded_softmax axis has stride 1
    }

    return sum;
}

hipError_t
DebugSimulateEncodedSoftmaxContext::lookup_optimal(Gpu gpu) {
    // Resolve the GPU into (row of autotune_table, mod index for the tuner).
    const auto archmod = get_archmod_number(gpu);
    const int arch = std::get<0>(archmod);
    const int mod = std::get<1>(archmod);
    if (arch < 0)
        return hipErrorNoBinaryForGpu;

    // Clear any previously selected kernel so a tuner that selects nothing
    // is detected by the check at the end.
    kernel_on_device = nullptr;

    const int64_t functional = godel_number();
    if (functional < 0)
        return hipErrorNotSupported;

    const auto tuner = autotune_table[arch][functional];
    if (!tuner)
        return hipErrorProfilerNotInitialized;

    // The tuner is expected to populate kernel_on_device.
    tuner(*this, mod);
    return kernel_on_device ? hipSuccess : hipErrorSharedObjectSymbolNotFound;
}

hipError_t
DebugSimulateEncodedSoftmaxContext::launch(hipStream_t stream) const {
    constexpr std::string_view triton_kernel_name { "debug_simulate_encoded_softmax" };

    // The packed vector stores addresses of fields in `aux` (and in *params),
    // so `aux` must remain in scope until invoke() has consumed `args`.
    TritonAuxiliaryArguments aux;
    const auto pack = prepare_arguments[pp_args_index];
    auto args = pack(*this->params, aux);

    // Use the default grid unless a custom calculator has been installed.
    dim3 grid;
    if (!custom_grid_calculator)
        grid = grid_calculator();
    else
        grid = custom_grid_calculator(*this);

#if AOTRITON_BUILD_FOR_TUNING
    // Tuning builds additionally forward peek_kernel_image to invoke().
    return kernel_on_device->invoke(triton_kernel_name, package_path, func_name, arch_name,
                                    grid, args, peek_kernel_image, stream);
#else
    return kernel_on_device->invoke(triton_kernel_name, package_path, func_name, arch_name,
                                    grid, args, stream);
#endif
}

std::tuple<int, int>
DebugSimulateEncodedSoftmaxContext::get_archmod_number(Gpu gpu) {
    // A GPU's position in this list is its arch number, i.e. the row index
    // into autotune_table. Every listed GPU currently maps to mod 0.
    static const Gpu kSupportedGpus[] = {
        GPU_AMD_ARCH_GFX950_MOD0,   // arch 0
        GPU_AMD_ARCH_GFX1100_MOD0,  // arch 1
        GPU_AMD_ARCH_GFX1101_MOD0,  // arch 2
        GPU_AMD_ARCH_GFX1102_MOD0,  // arch 3
        GPU_AMD_ARCH_GFX1151_MOD0,  // arch 4
        GPU_AMD_ARCH_GFX1150_MOD0,  // arch 5
        GPU_AMD_ARCH_GFX1201_MOD0,  // arch 6
        GPU_AMD_ARCH_GFX1200_MOD0,  // arch 7
    };
    int arch = 0;
    for (const Gpu& candidate : kSupportedGpus) {
        if (gpu == candidate)
            return std::make_tuple(arch, 0);
        ++arch;
    }
    // TODO: print a warning that tuning for this GPU mod was not built.
    // Note: if a mod has no tuning info in the database at all,
    //       getGpuFromStream should not return that mod in the first place.
    return std::make_tuple(-1, 0);
}


static std::vector<void*>
debug_simulate_encoded_softmax_pp_args_0(const OpAttnFwdParams& params,
                                         const TritonAuxiliaryArguments& aux) {
  return { params.encoded_softmax->kparam_data_ptr(), // encoded_softmax
           params.encoded_softmax->kparam_stride(0), // stride_rz
           params.encoded_softmax->kparam_stride(1), // stride_rh
           params.encoded_softmax->kparam_stride(2), // stride_rm
           CAST(&params.dropout_p), // dropout_p
           CAST(&params.Num_head_q), // Num_head_q
           CAST(&params.Max_seqlen_q), // Max_seqlen_q
           CAST(&params.Max_seqlen_k), // Max_seqlen_k
           params.philox_seed_ptr->kparam_data_ptr(), // philox_seed_ptr
           params.philox_offset1->kparam_data_ptr(), // philox_offset1
           CAST(&params.philox_offset2), // philox_offset2
           CAST(&aux.global_scratch),
           CAST(&aux.profile_scratch)
         };
}

namespace {
// Argument packers, selected by the context's pp_args_index in launch().
// This kernel has a single packing layout.
PP_FUNC prepare_arguments[ 1 ] = { &debug_simulate_encoded_softmax_pp_args_0 };
}


const std::vector<std::string>& DebugSimulateEncodedSoftmaxMetadata::get_encoded_softmax_choices()
{
    // Type strings accepted for encoded_softmax, listed in the same order as
    // the indices godel_number() assigns (fp16 = 0, bf16 = 1, fp32 = 2).
    static const std::vector<std::string> kChoices {
        "*fp16:16",
        "*bf16:16",
        "*fp32:16",
    };
    return kChoices;
}

const std::vector<std::string>& DebugSimulateEncodedSoftmaxMetadata::get_dropout_p_choices()
{
    // dropout_p has a single type choice: a 32-bit float scalar.
    static const std::vector<std::string> kChoices { "fp32" };
    return kChoices;
}

const std::vector<std::string>& DebugSimulateEncodedSoftmaxMetadata::get_Num_head_q_choices()
{
    // Num_head_q has a single type choice: a 32-bit signed integer scalar.
    static const std::vector<std::string> kChoices { "i32" };
    return kChoices;
}

const std::vector<std::string>& DebugSimulateEncodedSoftmaxMetadata::get_philox_seed_ptr_choices()
{
    // philox_seed_ptr has a single type choice: a pointer to u64.
    static const std::vector<std::string> kChoices { "*u64" };
    return kChoices;
}

const std::vector<std::string>& DebugSimulateEncodedSoftmaxMetadata::get_philox_offset1_choices()
{
    // philox_offset1 has a single type choice: a pointer to u64.
    static const std::vector<std::string> kChoices { "*u64" };
    return kChoices;
}

const std::vector<std::string>& DebugSimulateEncodedSoftmaxMetadata::get_philox_offset2_choices()
{
    // philox_offset2 has a single type choice: a u64 scalar passed by value.
    static const std::vector<std::string> kChoices { "u64" };
    return kChoices;
}

namespace autotune {

// NUL-separated packed strings ("64_32", "wave2_warp4_stg1"); presumably
// referenced by offset from the generated autotune tables, so keep them
// byte-identical.
const char debug_simulate_encoded_softmax_packed_string[] =
"64_32\0"
"wave2_warp4_stg1\0";

// Return the LUT entry for `mod_number`. The attention parameters do not
// influence the choice here; the parameter is presumably kept so the
// signature matches other generated LUT lambdas, hence [[maybe_unused]].
int debug_simulate_encoded_softmax__lut_lambda__0 ([[maybe_unused]] const OpAttnFwdParams& params, int mod_number, int8_t lut[1][1]) {
    return lut[mod_number][0];
}

} // namespace autotune

// Tuning dispatch table used by lookup_optimal():
//   row    = arch number returned by get_archmod_number()
//   column = godel_number(), i.e. encoded_softmax dtype (0: fp16, 1: bf16, 2: fp32)
// Entry order is significant; rows and columns must stay aligned with those
// two functions.
DebugSimulateEncodedSoftmaxContext::AutoTuneTableEntry
DebugSimulateEncodedSoftmaxContext::autotune_table[][ 3 ] = {
    // arch 0: GPU_AMD_ARCH_GFX950_MOD0
    {
        &autotune::Autotune_debug_simulate_encoded_softmax__A0__F0,
        &autotune::Autotune_debug_simulate_encoded_softmax__A0__F1,
        &autotune::Autotune_debug_simulate_encoded_softmax__A0__F2,
    },
    // arch 1: GPU_AMD_ARCH_GFX1100_MOD0
    {
        &autotune::Autotune_debug_simulate_encoded_softmax__A1__F0,
        &autotune::Autotune_debug_simulate_encoded_softmax__A1__F1,
        &autotune::Autotune_debug_simulate_encoded_softmax__A1__F2,
    },
    // arch 2: GPU_AMD_ARCH_GFX1101_MOD0
    {
        &autotune::Autotune_debug_simulate_encoded_softmax__A2__F0,
        &autotune::Autotune_debug_simulate_encoded_softmax__A2__F1,
        &autotune::Autotune_debug_simulate_encoded_softmax__A2__F2,
    },
    // arch 3: GPU_AMD_ARCH_GFX1102_MOD0
    {
        &autotune::Autotune_debug_simulate_encoded_softmax__A3__F0,
        &autotune::Autotune_debug_simulate_encoded_softmax__A3__F1,
        &autotune::Autotune_debug_simulate_encoded_softmax__A3__F2,
    },
    // arch 4: GPU_AMD_ARCH_GFX1151_MOD0
    {
        &autotune::Autotune_debug_simulate_encoded_softmax__A4__F0,
        &autotune::Autotune_debug_simulate_encoded_softmax__A4__F1,
        &autotune::Autotune_debug_simulate_encoded_softmax__A4__F2,
    },
    // arch 5: GPU_AMD_ARCH_GFX1150_MOD0
    {
        &autotune::Autotune_debug_simulate_encoded_softmax__A5__F0,
        &autotune::Autotune_debug_simulate_encoded_softmax__A5__F1,
        &autotune::Autotune_debug_simulate_encoded_softmax__A5__F2,
    },
    // arch 6: GPU_AMD_ARCH_GFX1201_MOD0
    {
        &autotune::Autotune_debug_simulate_encoded_softmax__A6__F0,
        &autotune::Autotune_debug_simulate_encoded_softmax__A6__F1,
        &autotune::Autotune_debug_simulate_encoded_softmax__A6__F2,
    },
    // arch 7: GPU_AMD_ARCH_GFX1200_MOD0
    {
        &autotune::Autotune_debug_simulate_encoded_softmax__A7__F0,
        &autotune::Autotune_debug_simulate_encoded_softmax__A7__F1,
        &autotune::Autotune_debug_simulate_encoded_softmax__A7__F2,
    },
};

}

// vim: set fileencoding=utf-8

