// Copyright © 2025 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: MIT

// clang-format off
#pragma once

#include <aotriton/config.h>
#include <aotriton/_internal/triton_kernel.h>
#include <aotriton/dtypes.h>
#include <aotriton/runtime.h>
#include <aotriton/util.h>
#include <aotriton/_internal/lazy_tensor_internal.h>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <string_view>
#include <tuple>
#include <vector>
#include "aotriton/_internal/flash/aiter.h"

#if 1
// Forward declaration only: this header refers to OpAttnBwdParams solely
// through a pointer, so the complete type is not needed here.
namespace AOTRITON_NS::v3::flash { struct OpAttnBwdParams; }
#endif

namespace AOTRITON_NS::v3::flash {

#if 1
// Active branch of the `#if 1`: names the forward-declared OpAttnBwdParams
// instead of defining the struct locally (which the `#else` branch would do).
using AOTRITON_NS::v3::flash::OpAttnBwdParams;
#else
// The parameter class must be defined here when
// there is no common operator for bwd_dq_dk_dv_v3.
// Local fallback definition of the attention-backward parameter bundle.
// Currently compiled out: the enclosing `#if 1` selects the forward-declared
// type instead, so this definition only takes effect if that switch is flipped.
//
// Tensor-like arguments are held by pointer (TensorView<N> or
// LazyTensorInternal<N>); scalar arguments are stored by value.
// NOTE(review): the trailing upper-case fields (BLOCK_DMODEL .. BIAS_TYPE)
// look like kernel-selection switches rather than runtime data — confirm
// against the generator that emits this header.
struct OpAttnBwdParams {
    const TensorView<4>*   Q;
    const TensorView<4>*   K;
    const TensorView<4>*   V;
    const TensorView<4>*   B;
    float                  sm_scale;
    const TensorView<4>*   Out;
    const TensorView<4>*   DO;
    const TensorView<4>*   DK;
    const TensorView<4>*   DV;
    const TensorView<4>*   DQ;
    const TensorView<4>*   DB;
    LazyTensorInternal<4>* DQ_ACC;
    const TensorView<2>*   L;
    LazyTensorInternal<2>* D;
    int32_t                num_head_q;
    int32_t                num_head_k;
    const TensorView<1>*   cu_seqlens_q;
    const TensorView<1>*   cu_seqlens_k;
    int32_t                num_seqlens;
    int32_t                max_seqlen_q;
    int32_t                max_seqlen_k;
    int32_t                head_dim;
    float                  dropout_p;
    const TensorView<0>*   philox_seed_ptr;
    const TensorView<0>*   philox_offset1;
    uint64_t               philox_offset2;
    int32_t                Window_left;
    int32_t                Window_right;
    int16_t                BLOCK_DMODEL;
    int8_t                 CAUSAL_TYPE;
    bool                   ENABLE_DROPOUT;
    bool                   PADDED_HEAD;
    int8_t                 BIAS_TYPE;
};
#endif

struct BwdDqDkDvV3Context {
    const OpAttnBwdParams *params = nullptr;
    struct {
        bool   kIsUniformStride;
        int8_t MaskType;
        bool   kIsSEQPad;
        bool   kIsAtomic32;
        int8_t BF16Cvt;
        bool   kIsGroupMode;
    } residual_args;
    struct {
        int32_t ts_qo;
        int32_t ts_kv;
    } perf_args;
    const char* check_inputs_are_supported();
    void calculate_residual_func_fields();

    // Re-use TritonKernel class
    TritonKernel* kernel_on_device = nullptr;

    // Kernel arguments
    union DirectKernelArguments {
        AOTRITON_NS::v3::flash::aiter::fmha_bwd_v3_args fmha_bwd_v3_args;
        AOTRITON_NS::v3::flash::aiter::fmha_bwd_v3_gen_args fmha_bwd_v3_gen_args;
        AOTRITON_NS::v3::flash::aiter::fmha_bwd_v3_genl_args fmha_bwd_v3_genl_args;
        AOTRITON_NS::v3::flash::aiter::fmha_bwd_v3_group_args fmha_bwd_v3_group_args;
        AOTRITON_NS::v3::flash::aiter::fmha_bwd_v3_swa_genl_args fmha_bwd_v3_swa_genl_args;

    };
    typedef std::tuple<dim3, dim3>(BwdDqDkDvV3Context::*PP_FUNC)(DirectKernelArguments&) const;
    // These functions will be defined in
    // v3src/<family>/affine_<kernel_name>.cc
    std::tuple<dim3, dim3> pp_direct_kernel_args_for_fmha_bwd_v3_args(DirectKernelArguments&) const;
    std::tuple<dim3, dim3> pp_direct_kernel_args_for_fmha_bwd_v3_gen_args(DirectKernelArguments&) const;
    std::tuple<dim3, dim3> pp_direct_kernel_args_for_fmha_bwd_v3_genl_args(DirectKernelArguments&) const;
    std::tuple<dim3, dim3> pp_direct_kernel_args_for_fmha_bwd_v3_group_args(DirectKernelArguments&) const;
    std::tuple<dim3, dim3> pp_direct_kernel_args_for_fmha_bwd_v3_swa_genl_args(DirectKernelArguments&) const;
;
    PP_FUNC selected_pp_args;
    size_t sizeof_selected_args;

    // Kernel locator
    std::string_view affine_kernel_function_name;
    pstring_view package_path;
    std::string_view arch_name;
    // Note to save ELF space, this object is constructed on the fly.
    const char* _debug_kernel_name = nullptr;
#if AOTRITON_BUILD_FOR_TUNING
#endif

    hipError_t lookup_optimal(Gpu gpu);
    hipError_t launch(hipStream_t stream) const;

    std::function<dim3(const BwdDqDkDvV3Context&)> custom_grid_calculator;
    dim3 grid;

    int64_t godel_number() const;
    static std::tuple<int, int> get_archmod_number(Gpu gpu);
    static constexpr int kMaxGodelNumber = 576;

    typedef hipError_t (*CapabilityTableEntry)(BwdDqDkDvV3Context& context, int mod_number);
    static CapabilityTableEntry capability_table[][ kMaxGodelNumber ];
};

}

// vim: set fileencoding=utf-8


