/*
 * Copyright (c) 2015-2016, 2020-2024 The Khronos Group Inc.
 * Copyright (c) 2015-2016, 2020-2024 Valve Corporation
 * Copyright (c) 2015-2016, 2020-2024 LunarG, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <algorithm>
#include <cassert>
#include <iterator>
#include <memory>
#include <vector>

#include "generated/vk_function_pointers.h"
#include "generated/vk_extension_helper.h"
#include "test_common.h"

namespace vkt {

// Collects the raw handle of every wrapper object in `v`, in order.
// Each element's handle() result is converted to Dst when it is appended.
template <class Dst, class Src>
std::vector<Dst> MakeVkHandles(const std::vector<Src> &v) {
    std::vector<Dst> result;
    result.reserve(v.size());
    for (const Src &object : v) {
        result.push_back(object.handle());
    }
    return result;
}

// Pointer flavour: collects the handle of every pointed-to wrapper in `v`, in order.
// A null pointer contributes VK_NULL_HANDLE instead of being dereferenced.
template <class Dst, class Src>
std::vector<Dst> MakeVkHandles(const std::vector<Src *> &v) {
    std::vector<Dst> result;
    result.reserve(v.size());
    for (const Src *object : v) {
        result.push_back(object ? object->handle() : VK_NULL_HANDLE);
    }
    return result;
}

// Span flavour: collects the handle of every pointed-to wrapper in `v`, in order.
// A null pointer contributes VK_NULL_HANDLE rather than being dereferenced, which keeps this
// overload consistent with the std::vector<Src *> overload.
template <class Dst, class Src>
std::vector<Dst> MakeVkHandles(const vvl::span<Src *> &v) {
    std::vector<Dst> handles;
    handles.reserve(v.size());
    std::transform(v.begin(), v.end(), std::back_inserter(handles),
                   [](const Src *o) { return (o) ? o->handle() : VK_NULL_HANDLE; });
    return handles;
}

// Forward declarations of the wrapper classes so they can refer to one another before their
// full definitions appear (e.g. Device stores Queue pointers, Event takes CommandBuffer references).
// NOTE(review): not every name listed here is defined in the visible part of this header —
// confirm that none of them are stale.
class PhysicalDevice;
class Device;
class Queue;
class DeviceMemory;
class Fence;
class Semaphore;
class Event;
class QueryPool;
class Buffer;
class BufferView;
class Image;
class ImageView;
class DepthStencilView;
class PipelineCache;
class Pipeline;
class PipelineDelta;
class Sampler;
class DescriptorSetLayout;
class PipelineLayout;
class DescriptorSetPool;
class DescriptorSet;
class CommandBuffer;
class CommandPool;

// Global (non-device) layer and extension enumeration helpers; only declared in this header.
// NOTE(review): presumably thin wrappers over vkEnumerateInstanceLayerProperties /
// vkEnumerateInstanceExtensionProperties, with pLayerName selecting a layer's extensions — confirm in the implementation.
std::vector<VkLayerProperties> GetGlobalLayers();
std::vector<VkExtensionProperties> GetGlobalExtensions();
std::vector<VkExtensionProperties> GetGlobalExtensions(const char *pLayerName);

namespace internal {

// Base class for all Vulkan object wrappers: stores one raw handle of type T.
//
// The class is move-only. It never destroys the Vulkan object itself; releasing the object is left
// to derived classes, so move assignment simply overwrites whatever handle was held before.
template <typename T>
class Handle {
  public:
    // Raw Vulkan handle (value-initialized T when nothing is held).
    const T &handle() const noexcept { return handle_; }
    // True when the stored handle differs from a value-initialized T.
    bool initialized() const noexcept { return (handle_ != T{}); }

    // Implicit conversion so a wrapper can be passed wherever the raw handle is expected.
    operator T() const noexcept { return handle(); }

    // Attaches a debug-utils name to this object via vk::SetDebugUtilsObjectNameEXT.
    // The call's result is not inspected.
    void SetName(VkDevice device, VkObjectType object_type, const char *name) {
        VkDebugUtilsObjectNameInfoEXT name_info = vku::InitStructHelper();
        name_info.objectType = object_type;
        name_info.objectHandle = CastToUint64(handle_);
        name_info.pObjectName = name;
        vk::SetDebugUtilsObjectNameEXT(device, &name_info);
    }

  protected:
    using handle_type = T;

    explicit Handle() noexcept : handle_{} {}
    explicit Handle(T handle) noexcept : handle_(handle) {}

    // handles are non-copyable
    Handle(const Handle &) = delete;
    Handle &operator=(const Handle &) = delete;

    // handles can be moved out; the source is left holding a value-initialized handle
    Handle(Handle &&src) noexcept : handle_{src.handle_} { src.handle_ = {}; }
    Handle &operator=(Handle &&src) noexcept {
        // Self-move must be a no-op: without this check the handle would be copied onto itself
        // and then cleared, silently dropping the object.
        if (this != &src) {
            handle_ = src.handle_;
            src.handle_ = {};
        }
        return *this;
    }

    // Stores `handle`; asserts that no handle is currently held.
    void init(T handle) noexcept {
        assert(!initialized());
        handle_ = handle;
    }

  protected:
    T handle_;
};

// Wrapper base for non-dispatchable handles: in addition to the handle stored by Handle<T>, it
// remembers the VkDevice the object was created on.
//
// Move-only. Moving transfers both the handle and the device and leaves the source holding
// null values for both.
template <typename T>
class NonDispHandle : public Handle<T> {
  protected:
    explicit NonDispHandle() noexcept : Handle<T>(), dev_handle_(VK_NULL_HANDLE) {}
    explicit NonDispHandle(VkDevice dev, T handle) noexcept : Handle<T>(handle), dev_handle_(dev) {}

    NonDispHandle(NonDispHandle &&src) noexcept : Handle<T>(std::move(src)), dev_handle_(src.dev_handle_) {
        // Handle<T>'s move constructor only touches handle_, so src.dev_handle_ is still valid above.
        src.dev_handle_ = VK_NULL_HANDLE;
    }
    NonDispHandle &operator=(NonDispHandle &&src) noexcept {
        // Self-move must be a no-op: otherwise dev_handle_ (and the handle) would be reset to null.
        if (this != &src) {
            Handle<T>::operator=(std::move(src));
            dev_handle_ = src.dev_handle_;
            src.dev_handle_ = VK_NULL_HANDLE;
        }
        return *this;
    }

    // Device the wrapped object belongs to (VK_NULL_HANDLE when none).
    const VkDevice &device() const noexcept { return dev_handle_; }

    // Stores device and handle; asserts that neither is currently set.
    void init(VkDevice dev, T handle) noexcept {
        assert(!Handle<T>::initialized() && dev_handle_ == VK_NULL_HANDLE);
        Handle<T>::init(handle);
        dev_handle_ = dev;
    }

    // Forgets the owning device. This does NOT reset the stored handle; derived classes are
    // expected to take care of that themselves.
    void destroy() noexcept { dev_handle_ = VK_NULL_HANDLE; }

  public:
    // Names the object through the base-class SetName, using the remembered device.
    void SetName(VkObjectType object_type, const char *name) { Handle<T>::SetName(dev_handle_, object_type, name); }

  private:
    VkDevice dev_handle_;
};

}  // namespace internal

// Wrapper around a VkPhysicalDevice that snapshots several device queries at construction time
// and exposes them as public const members.
class PhysicalDevice : public internal::Handle<VkPhysicalDevice> {
  public:
    // The initializer list relies on member declaration order: limits_ is copied out of
    // properties_, so properties_ must be declared (and therefore initialized) first.
    explicit PhysicalDevice(VkPhysicalDevice phy)
        : Handle(phy),
          properties_(properties()),
          limits_(properties_.limits),
          memory_properties_(memory_properties()),
          queue_properties_(queue_properties()) {}

    // Names this physical device for debug utils; a VkDevice must be supplied by the caller
    // because this wrapper does not store one.
    void SetName(VkDevice device, const char *name) {
        Handle<VkPhysicalDevice>::SetName(device, VK_OBJECT_TYPE_PHYSICAL_DEVICE, name);
    }
    VkPhysicalDeviceFeatures features() const;

    // NOTE(review): presumably picks a memory type allowed by type_bits that has all `properties`
    // and none of `forbid`, writes it into `info`, and reports success — confirm in the implementation.
    bool set_memory_type(const uint32_t type_bits, VkMemoryAllocateInfo *info, const VkMemoryPropertyFlags properties,
                         const VkMemoryPropertyFlags forbid = 0) const;

    // vkEnumerateDeviceExtensionProperties()
    std::vector<VkExtensionProperties> extensions(const char *pLayerName = nullptr) const;

    // vkEnumerateLayers()
    std::vector<VkLayerProperties> layers() const;

    // Values captured once by the constructor.
    const VkPhysicalDeviceProperties properties_;
    const VkPhysicalDeviceLimits limits_;  // copy of properties_.limits
    const VkPhysicalDeviceMemoryProperties memory_properties_;
    const std::vector<VkQueueFamilyProperties> queue_properties_;

  private:
    void add_extension_dependencies(uint32_t dependency_count, VkExtensionProperties *depencency_props,
                                    std::vector<VkExtensionProperties> &ext_list);
    // Query helpers used by the constructor to fill the const members above.
    VkPhysicalDeviceProperties properties() const;
    std::vector<VkQueueFamilyProperties> queue_properties() const;
    VkPhysicalDeviceMemoryProperties memory_properties() const;
};

// Owns an array of VkDeviceQueueCreateInfo built from a list of queue family properties,
// together with storage for per-family queue priorities.
// NOTE(review): presumably queue_info_ entries point into queue_priorities_, which would be why
// both are kept as members — confirm in the constructor's implementation.
class QueueCreateInfoArray {
  private:
    std::vector<VkDeviceQueueCreateInfo> queue_info_;
    std::vector<std::vector<float>> queue_priorities_;

  public:
    // NOTE(review): all_queue_count presumably requests every queue of each family rather than one — confirm.
    QueueCreateInfoArray(const std::vector<VkQueueFamilyProperties> &queue_props, bool all_queue_count = false);
    // Number of VkDeviceQueueCreateInfo entries.
    size_t size() const { return queue_info_.size(); }
    // Pointer to the first entry, e.g. for VkDeviceCreateInfo::pQueueCreateInfos.
    const VkDeviceQueueCreateInfo *data() const { return queue_info_.data(); }
};

// Wrapper around a VkDevice. Holds the PhysicalDevice it was created from, the list of enabled
// extensions, and the device's queues, grouped both by family and by capability.
class Device : public internal::Handle<VkDevice> {
  public:
    // Each constructor records the physical device and immediately calls the matching init() overload.
    explicit Device(VkPhysicalDevice phy) : phy_(phy) { init(); }
    explicit Device(VkPhysicalDevice phy, const VkDeviceCreateInfo &info) : phy_(phy) { init(info); }
    explicit Device(VkPhysicalDevice phy, std::vector<const char *> &extension_names, VkPhysicalDeviceFeatures *features = nullptr,
                    void *create_device_pnext = nullptr, bool all_queue_count = false)
        : phy_(phy) {
        init(extension_names, features, create_device_pnext, all_queue_count);
    }

    ~Device() noexcept;
    void destroy() noexcept;

    // vkCreateDevice()
    void init(const VkDeviceCreateInfo &info);
    void init(std::vector<const char *> &extensions, VkPhysicalDeviceFeatures *features = nullptr,
              void *create_device_pnext = nullptr, bool all_queue_count = false);  // all queues, all extensions, etc
    // Convenience overload: initializes with an empty extension list.
    void init() {
        std::vector<const char *> extensions;
        init(extensions);
    }

    void SetName(const char *name) { Handle<VkDevice>::SetName(handle_, VK_OBJECT_TYPE_DEVICE, name); }
    const PhysicalDevice &phy() const { return phy_; }

    // Returns a copy of the enabled extension name list.
    std::vector<const char *> GetEnabledExtensions() const { return enabled_extensions_; }
    bool IsEnabledExtension(const char *extension) const;

    // vkGetDeviceProcAddr()
    PFN_vkVoidFunction get_proc(const char *name) const { return vk::GetDeviceProcAddr(handle(), name); }

    // Queues grouped by capability; the pointers are non-owning.
    const std::vector<Queue *> &QueuesWithGraphicsCapability() const { return queues_[GRAPHICS]; }
    const std::vector<Queue *> &QueuesWithComputeCapability() const { return queues_[COMPUTE]; }
    const std::vector<Queue *> &QueuesWithTransferCapability() const { return queues_[TRANSFER]; }
    const std::vector<Queue *> &QueuesWithSparseCapability() const { return queues_[SPARSE]; }

    // Owning storage for the Queue objects of one family.
    using QueueFamilyQueues = std::vector<std::unique_ptr<Queue>>;
    const QueueFamilyQueues &QueuesFromFamily(uint32_t queue_family) const;

    // Queue family that has "with" capabilities and optionally without "without" capabilities.
    std::optional<uint32_t> QueueFamily(VkQueueFlags with, VkQueueFlags without = 0) const;

    // Queue family that does not have "without" capabilities
    std::optional<uint32_t> QueueFamilyWithoutCapabilities(VkQueueFlags without) const;
    Queue *QueueWithoutCapabilities(VkQueueFlags without) const;

    // Dedicated compute queue family: has compute but no graphics
    std::optional<uint32_t> ComputeOnlyQueueFamily() const;
    Queue *ComputeOnlyQueue() const;

    // Dedicated transfer queue family: has transfer but no graphics/compute
    std::optional<uint32_t> TransferOnlyQueueFamily() const;
    Queue *TransferOnlyQueue() const;

    // Compute or transfer
    std::optional<uint32_t> NonGraphicsQueueFamily() const;
    Queue *NonGraphicsQueue() const;

    // Defaults to vvl::kU32Max until assigned.
    uint32_t graphics_queue_node_index_ = vvl::kU32Max;

    const PhysicalDevice phy_;

    struct Format {
        VkFormat format;
        VkImageTiling tiling;
        VkFlags features;
    };

    VkFormatFeatureFlags2 FormatFeaturesLinear(VkFormat format) const;
    VkFormatFeatureFlags2 FormatFeaturesOptimal(VkFormat format) const;
    VkFormatFeatureFlags2 FormatFeaturesBuffer(VkFormat format) const;

    // vkDeviceWaitIdle()
    void Wait() const;

    // vkWaitForFences()
    VkResult Wait(const std::vector<const Fence *> &fences, bool wait_all, uint64_t timeout);
    // Single-fence convenience overload: waits for all (i.e. the one) fence with the maximum timeout value.
    VkResult Wait(const Fence &fence) {
        return Wait(std::vector<const Fence *>(1, &fence), true, static_cast<uint64_t>(-1));
    }

    // vkUpdateDescriptorSets()
    void update_descriptor_sets(const std::vector<VkWriteDescriptorSet> &writes, const std::vector<VkCopyDescriptorSet> &copies);
    // Writes-only overload: forwards with an empty copy list.
    void update_descriptor_sets(const std::vector<VkWriteDescriptorSet> &writes) {
        return update_descriptor_sets(writes, std::vector<VkCopyDescriptorSet>());
    }

    // Builders for VkWriteDescriptorSet, overloaded on the kind of descriptor payload
    // (image info, buffer info or buffer views), given either as pointer+count or as a vector.
    static VkWriteDescriptorSet write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                     VkDescriptorType type, uint32_t count,
                                                     const VkDescriptorImageInfo *image_info);
    static VkWriteDescriptorSet write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                     VkDescriptorType type, uint32_t count,
                                                     const VkDescriptorBufferInfo *buffer_info);
    static VkWriteDescriptorSet write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                     VkDescriptorType type, uint32_t count, const VkBufferView *buffer_views);
    static VkWriteDescriptorSet write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                     VkDescriptorType type, const std::vector<VkDescriptorImageInfo> &image_info);
    static VkWriteDescriptorSet write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                     VkDescriptorType type, const std::vector<VkDescriptorBufferInfo> &buffer_info);
    static VkWriteDescriptorSet write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                     VkDescriptorType type, const std::vector<VkBufferView> &buffer_views);

    // Builder for VkCopyDescriptorSet.
    static VkCopyDescriptorSet copy_descriptor_set(const DescriptorSet &src_set, uint32_t src_binding, uint32_t src_array_element,
                                                   const DescriptorSet &dst_set, uint32_t dst_binding, uint32_t dst_array_element,
                                                   uint32_t count);

  private:
    // Index into queues_; QUEUE_CAPABILITY_COUNT is the array size, not a capability.
    enum QueueCapabilityIndex {
        GRAPHICS = 0,
        COMPUTE = 1,
        TRANSFER = 2,
        SPARSE = 3,
        QUEUE_CAPABILITY_COUNT = 4,
    };

    void init_queues(const VkDeviceCreateInfo &info);

    std::vector<const char *> enabled_extensions_;

    // queue_families_ owns the Queue objects; queues_ holds non-owning pointers grouped by capability.
    std::vector<QueueFamilyQueues> queue_families_;
    std::vector<Queue *> queues_[QUEUE_CAPABILITY_COUNT];
};

// Wrapper around a VkDeviceMemory allocation. Keeps a VkMemoryAllocateInfo alongside the handle.
class DeviceMemory : public internal::NonDispHandle<VkDeviceMemory> {
  public:
    DeviceMemory() = default;
    DeviceMemory(const Device &dev, const VkMemoryAllocateInfo &info) { init(dev, info); }
    ~DeviceMemory() noexcept;
    void destroy() noexcept;
    // NOTE(review): the defaulted move assignment does not call destroy() on the target first;
    // callers that overwrite a live allocation must destroy() it themselves (Buffer::operator= does).
    DeviceMemory &operator=(DeviceMemory &&) = default;

    // vkAllocateMemory()
    // Fails the test when allocation is unsuccessful
    void init(const Device &dev, const VkMemoryAllocateInfo &info);
    // Does not fail the test when allocation is unsuccessful and instead returns error code
    VkResult try_init(const Device &dev, const VkMemoryAllocateInfo &info);
    void SetName(const char *name) { NonDispHandle<VkDeviceMemory>::SetName(VK_OBJECT_TYPE_DEVICE_MEMORY, name); }

    // vkMapMemory()
    const void *map(VkFlags flags) const;
    void *map(VkFlags flags);
    // Convenience overloads that map with flags == 0.
    const void *map() const { return map(0); }
    void *map() { return map(0); }

    // vkUnmapMemory()
    void unmap() const;

    // Read-only access to the stored allocate info (zero-initialized until something assigns it).
    const auto &get_memory_allocate_info() const { return memory_allocate_info_; }

    // NOTE(review): presumably builds an allocate info satisfying `reqs` with memory properties
    // `mem_props`, chaining `alloc_info_pnext` — confirm in the implementation.
    static VkMemoryAllocateInfo get_resource_alloc_info(const Device &dev, const VkMemoryRequirements &reqs,
                                                        VkMemoryPropertyFlags mem_props, void *alloc_info_pnext = nullptr);

  private:
    VkMemoryAllocateInfo memory_allocate_info_{};
};

// Wrapper around a VkFence. Move-only; a default-constructed Fence holds no Vulkan object.
class Fence : public internal::NonDispHandle<VkFence> {
  public:
    Fence() = default;
    // Creates the fence using the create info returned by the parameterless create_info().
    Fence(const Device &dev) { init(dev, create_info()); }
    Fence(const Device &dev, const VkFenceCreateInfo &info) { init(dev, info); }
    Fence(Fence &&rhs) noexcept : NonDispHandle(std::move(rhs)) {}
    Fence &operator=(Fence &&) noexcept;
    ~Fence() noexcept;
    void destroy() noexcept;

    // vkCreateFence()
    void init(const Device &dev, const VkFenceCreateInfo &info);
    void SetName(const char *name) { NonDispHandle<VkFence>::SetName(VK_OBJECT_TYPE_FENCE, name); }

    // vkGetFenceStatus()
    VkResult status() const { return vk::GetFenceStatus(device(), handle()); }
    // NOTE(review): presumably vkWaitForFences() on this fence; the unit of `timeout` is not shown here — confirm.
    VkResult wait(uint64_t timeout) const;

    // NOTE(review): presumably vkResetFences() — confirm in the implementation.
    VkResult reset();

    // External fence handle export/import; the HANDLE overloads exist only on Win32 builds,
    // the int overloads take a file descriptor.
#ifdef VK_USE_PLATFORM_WIN32_KHR
    VkResult export_handle(HANDLE &win32_handle, VkExternalFenceHandleTypeFlagBits handle_type);
    VkResult import_handle(HANDLE win32_handle, VkExternalFenceHandleTypeFlagBits handle_type, VkFenceImportFlags flags = 0);
#endif
    VkResult export_handle(int &fd_handle, VkExternalFenceHandleTypeFlagBits handle_type);
    VkResult import_handle(int fd_handle, VkExternalFenceHandleTypeFlagBits handle_type, VkFenceImportFlags flags = 0);

    // Builders for VkFenceCreateInfo, with or without explicit flags.
    static VkFenceCreateInfo create_info(VkFenceCreateFlags flags);
    static VkFenceCreateInfo create_info();
};

// Default-constructed fence (holds no Vulkan object); used as the default `fence` argument of the
// Queue submit helpers to mean "no fence".
inline const Fence no_fence;

// Wrapper around a VkSemaphore. Move-only; a default-constructed Semaphore holds no Vulkan object.
class Semaphore : public internal::NonDispHandle<VkSemaphore> {
  public:
    Semaphore() = default;
    // Defaults to a binary semaphore; `initial_value` defaults to 0.
    // NOTE(review): initial_value is presumably only meaningful for VK_SEMAPHORE_TYPE_TIMELINE — confirm.
    Semaphore(const Device &dev, VkSemaphoreType type = VK_SEMAPHORE_TYPE_BINARY, uint64_t initial_value = 0);
    Semaphore(const Device &dev, const VkSemaphoreCreateInfo &info) { init(dev, info); }
    Semaphore(Semaphore &&rhs) noexcept : NonDispHandle(std::move(rhs)) {}
    Semaphore &operator=(Semaphore &&) noexcept;
    ~Semaphore() noexcept;
    void destroy() noexcept;

    // vkCreateSemaphore()
    void init(const Device &dev, const VkSemaphoreCreateInfo &info);
    void SetName(const char *name) { NonDispHandle<VkSemaphore>::SetName(VK_OBJECT_TYPE_SEMAPHORE, name); }

    // Host-side wait/signal on a semaphore value.
    // NOTE(review): the *KHR variants presumably call the KHR-suffixed entry points
    // (vkWaitSemaphoresKHR / vkSignalSemaphoreKHR) — confirm in the implementation.
    VkResult Wait(uint64_t value, uint64_t timeout);
    VkResult WaitKHR(uint64_t value, uint64_t timeout);
    VkResult Signal(uint64_t value);
    VkResult SignalKHR(uint64_t value);

    // External semaphore handle export/import; the HANDLE overloads exist only on Win32 builds,
    // the int overloads take a file descriptor.
#ifdef VK_USE_PLATFORM_WIN32_KHR
    VkResult export_handle(HANDLE &win32_handle, VkExternalSemaphoreHandleTypeFlagBits handle_type);
    VkResult import_handle(HANDLE win32_handle, VkExternalSemaphoreHandleTypeFlagBits handle_type,
                           VkSemaphoreImportFlags flags = 0);
#endif
    VkResult export_handle(int &fd_handle, VkExternalSemaphoreHandleTypeFlagBits handle_type);
    VkResult import_handle(int fd_handle, VkExternalSemaphoreHandleTypeFlagBits handle_type, VkSemaphoreImportFlags flags = 0);
};

// Wrapper around a VkEvent, with host-side and command-buffer-side helpers.
class Event : public internal::NonDispHandle<VkEvent> {
  public:
    Event() = default;
    // Creates the event from a VkEventCreateInfo produced by vku::InitStruct.
    Event(const Device &dev) { init(dev, vku::InitStruct<VkEventCreateInfo>()); }
    Event(const Device &dev, const VkEventCreateInfo &info) { init(dev, info); }
    ~Event() noexcept;
    void destroy() noexcept;

    // vkCreateEvent()
    void init(const Device &dev, const VkEventCreateInfo &info);
    void SetName(const char *name) { NonDispHandle<VkEvent>::SetName(VK_OBJECT_TYPE_EVENT, name); }

    // vkGetEventStatus()
    // vkSetEvent()
    // vkResetEvent()
    VkResult status() const { return vk::GetEventStatus(device(), handle()); }
    void set();
    // NOTE(review): the cmd_* helpers presumably record vkCmdSetEvent / vkCmdResetEvent /
    // vkCmdWaitEvents for this event into `cmd` — confirm in the implementation.
    void cmd_set(const CommandBuffer &cmd, VkPipelineStageFlags stage_mask);
    void cmd_reset(const CommandBuffer &cmd, VkPipelineStageFlags stage_mask);
    void cmd_wait(const CommandBuffer &cmd, VkPipelineStageFlags src_stage_mask, VkPipelineStageFlags dst_stage_mask,
                  const std::vector<VkMemoryBarrier> &memory_barriers, const std::vector<VkBufferMemoryBarrier> &buffer_barriers,
                  const std::vector<VkImageMemoryBarrier> &image_barriers);
    void reset();

    // Builder for VkEventCreateInfo with the given flags.
    static VkEventCreateInfo create_info(VkFlags flags);
};

// Tag types used to select Queue submit overloads: passing `wait` selects the overload whose
// semaphore is waited on, passing `signal` the one whose semaphore is signaled.
// Declared `inline` so every translation unit including this header shares one object each.
struct WaitT {};
inline constexpr WaitT wait{};
struct SignalT {};
inline constexpr SignalT signal{};

// Wrapper around a VkQueue together with the index of the queue family it belongs to.
// The many Submit* overloads differ in which semaphores are waited on / signaled; the WaitT and
// SignalT tag parameters disambiguate the single-semaphore forms. Every overload takes an
// optional fence that defaults to no_fence.
class Queue : public internal::Handle<VkQueue> {
  public:
    explicit Queue(VkQueue queue, uint32_t index) : Handle(queue), family_index(index) {}
    void SetName(const Device &device, const char *name) { Handle<VkQueue>::SetName(device.handle(), VK_OBJECT_TYPE_QUEUE, name); }

    // vkQueueSubmit()
    VkResult Submit(const CommandBuffer &cmd, const Fence &fence = no_fence);
    VkResult Submit(const vvl::span<CommandBuffer *> &cmds, const Fence &fence = no_fence);

    // Binary-semaphore forms: wait only, signal only, or wait + signal.
    VkResult Submit(const CommandBuffer &cmd, WaitT, const Semaphore &wait_semaphore,
                    VkPipelineStageFlags wait_stage_mask = VK_PIPELINE_STAGE_ALL_COMMANDS_BIT, const Fence &fence = no_fence);
    VkResult Submit(const CommandBuffer &cmd, SignalT, const Semaphore &signal_semaphore, const Fence &fence = no_fence);
    VkResult Submit(const CommandBuffer &cmd, const Semaphore &wait_semaphore, VkPipelineStageFlags wait_stage_mask,
                    const Semaphore &signal_semaphore, const Fence &fence = no_fence);

    // Timeline-semaphore forms: same three shapes, each semaphore paired with a value.
    VkResult SubmitWithTimelineSemaphore(const CommandBuffer &cmd, WaitT, const Semaphore &wait_semaphore, uint64_t wait_value,
                                         VkPipelineStageFlags wait_stage_mask = VK_PIPELINE_STAGE_ALL_COMMANDS_BIT,
                                         const Fence &fence = no_fence);
    VkResult SubmitWithTimelineSemaphore(const CommandBuffer &cmd, SignalT, const Semaphore &signal_semaphore,
                                         uint64_t signal_value, const Fence &fence = no_fence);
    VkResult SubmitWithTimelineSemaphore(const CommandBuffer &cmd, const Semaphore &wait_semaphore, uint64_t wait_value,
                                         const Semaphore &signal_semaphore, uint64_t signal_value, const Fence &fence = no_fence);

    // vkQueueSubmit2()
    // NOTE(review): use_khr presumably selects the KHR-suffixed entry point — confirm in the implementation.
    // NOTE(review): this multi-command overload takes a span of CommandBuffer objects, whereas
    // Submit() above takes a span of CommandBuffer pointers — confirm the difference is intended.
    VkResult Submit2(const CommandBuffer &cmd, const Fence &fence = no_fence, bool use_khr = false);
    VkResult Submit2(const vvl::span<const CommandBuffer> &cmds, const Fence &fence = no_fence, bool use_khr = false);

    VkResult Submit2(const CommandBuffer &cmd, WaitT, const Semaphore &wait_semaphore,
                     VkPipelineStageFlags2 wait_stage_mask = VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT, const Fence &fence = no_fence,
                     bool use_khr = false);
    VkResult Submit2(const CommandBuffer &cmd, SignalT, const Semaphore &signal_semaphore,
                     VkPipelineStageFlags2 signal_stage_mask = VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT, const Fence &fence = no_fence,
                     bool use_khr = false);
    VkResult Submit2(const CommandBuffer &cmd, const Semaphore &wait_semaphore, VkPipelineStageFlags2 wait_stage_mask,
                     const Semaphore &signal_semaphore, VkPipelineStageFlags2 signal_stage_mask, const Fence &fence = no_fence,
                     bool use_khr = false);

    // Timeline-semaphore forms of Submit2.
    VkResult Submit2WithTimelineSemaphore(const CommandBuffer &cmd, WaitT, const Semaphore &wait_semaphore, uint64_t value,
                                          VkPipelineStageFlags2 wait_stage_mask = VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT,
                                          const Fence &fence = no_fence, bool use_khr = false);
    VkResult Submit2WithTimelineSemaphore(const CommandBuffer &cmd, SignalT, const Semaphore &signal_semaphore, uint64_t value,
                                          VkPipelineStageFlags2 signal_stage_mask = VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT,
                                          const Fence &fence = no_fence, bool use_khr = false);
    VkResult Submit2WithTimelineSemaphore(const CommandBuffer &cmd, const Semaphore &wait_semaphore, uint64_t wait_value,
                                          const Semaphore &signal_semaphore, uint64_t signal_value, const Fence &fence = no_fence,
                                          bool use_khr = false);

    // vkQueueWaitIdle()
    VkResult Wait();

    // Queue family index supplied at construction.
    const uint32_t family_index;
};

// Wrapper around a VkQueryPool.
class QueryPool : public internal::NonDispHandle<VkQueryPool> {
  public:
    QueryPool() = default;
    QueryPool(const Device &dev, const VkQueryPoolCreateInfo &info) { init(dev, info); }
    // Convenience constructor: builds the create info through create_info() and initializes from it.
    QueryPool(const Device &dev, VkQueryType query_type, uint32_t query_count) {
        init(dev, create_info(query_type, query_count));
    }
    ~QueryPool() noexcept;
    void destroy() noexcept;

    // vkCreateQueryPool()
    void init(const Device &dev, const VkQueryPoolCreateInfo &info);
    void SetName(const char *name) { NonDispHandle<VkQueryPool>::SetName(VK_OBJECT_TYPE_QUERY_POOL, name); }

    // vkGetQueryPoolResults()
    VkResult results(uint32_t first, uint32_t count, size_t size, void *data, size_t stride);

    // Builder for VkQueryPoolCreateInfo from a query type and a slot count.
    static VkQueryPoolCreateInfo create_info(VkQueryType type, uint32_t slot_count);
};

// Tag types used to select resource constructors:
//   no_mem         - create the object without allocating/binding memory (see Buffer / Image).
//   device_address - create a Buffer suitable for device-address queries (see Buffer).
//   set_layout     - selects the Image constructor taking SetLayoutT.
// Declared `inline` so every translation unit including this header shares one object each.
struct NoMemT {};
inline constexpr NoMemT no_mem{};
struct DeviceAddressT {};
inline constexpr DeviceAddressT device_address{};
struct SetLayoutT {};
inline constexpr SetLayoutT set_layout{};

// Wrapper around a VkBuffer that also owns a DeviceMemory object (internal_mem_) and remembers
// the VkBufferCreateInfo it was created with. Move-only.
class Buffer : public internal::NonDispHandle<VkBuffer> {
  public:
    // Empty buffer: no Vulkan object, create_info_ set up by vku::InitStruct.
    explicit Buffer() : NonDispHandle(), create_info_(vku::InitStruct<decltype(create_info_)>()) {}

    // Create from a full create info.
    explicit Buffer(const Device &dev, const VkBufferCreateInfo &info, VkMemoryPropertyFlags mem_props = 0,
                    void *alloc_info_pnext = nullptr) {
        init(dev, info, mem_props, alloc_info_pnext);
    }

    // Create from just a size and usage flags.
    explicit Buffer(const Device &dev, VkDeviceSize size, VkBufferUsageFlags usage, VkMemoryPropertyFlags mem_props = 0,
                    void *alloc_info_pnext = nullptr) {
        init(dev, size, usage, mem_props, alloc_info_pnext);
    }

    // Create the buffer object only, via init_no_mem().
    explicit Buffer(const Device &dev, const VkBufferCreateInfo &info, NoMemT) { init_no_mem(dev, info); }

    // Various spots need a host visible buffer they can call GetBufferDeviceAddress on
    explicit Buffer(const Device &dev, VkDeviceSize size, VkBufferUsageFlags usage, DeviceAddressT) {
        VkMemoryAllocateFlagsInfo device_address_alloc = vku::InitStructHelper();
        device_address_alloc.flags = VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT;
        // The device-address usage bit is always added on top of whatever the caller asked for.
        usage |= VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT;
        init(dev, create_info(size, usage), VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
             &device_address_alloc);
    }

    // Move construction: takes over handle/device through the base class, then the create info
    // and the memory object by assignment.
    Buffer(Buffer &&rhs) noexcept : NonDispHandle(std::move(rhs)) {
        create_info_ = std::move(rhs.create_info_);
        internal_mem_ = std::move(rhs.internal_mem_);
    }

    // Move assignment: releases what this buffer currently holds before taking over rhs.
    // Self-assignment is a no-op.
    Buffer &operator=(Buffer &&rhs) noexcept {
        if (this != &rhs) {
            // Release the current buffer first, then its memory.
            destroy();
            internal_mem_.destroy();
            NonDispHandle::operator=(std::move(rhs));
            create_info_ = std::move(rhs.create_info_);
            internal_mem_ = std::move(rhs.internal_mem_);
        }
        return *this;
    }
    ~Buffer() noexcept;
    void destroy() noexcept;

    // vkCreateBuffer()
    void init(const Device &dev, const VkBufferCreateInfo &info, VkMemoryPropertyFlags mem_props = 0,
              void *alloc_info_pnext = nullptr);
    // Size/usage overload: builds the create info through create_info() and forwards.
    void init(const Device &dev, VkDeviceSize size, VkBufferUsageFlags usage, VkMemoryPropertyFlags mem_props = 0,
              void *alloc_info_pnext = nullptr, const std::vector<uint32_t> &queue_families = {}) {
        init(dev, create_info(size, usage, &queue_families), mem_props, alloc_info_pnext);
    }
    void init_no_mem(const Device &dev, const VkBufferCreateInfo &info);
    void SetName(const char *name) { NonDispHandle<VkBuffer>::SetName(VK_OBJECT_TYPE_BUFFER, name); }

    // get the internal memory
    const DeviceMemory &memory() const { return internal_mem_; }
    DeviceMemory &memory() { return internal_mem_; }

    // vkGetObjectMemoryRequirements()
    VkMemoryRequirements memory_requirements() const;

    // Allocate and bind memory
    // The assumption that this object was created in no_mem configuration
    void allocate_and_bind_memory(const Device &dev, VkMemoryPropertyFlags mem_props = 0, void *alloc_info_pnext = nullptr);

    // Bind to existing memory object
    void bind_memory(const DeviceMemory &mem, VkDeviceSize mem_offset);

    // The create info this buffer remembers.
    const VkBufferCreateInfo &create_info() const { return create_info_; }
    // Builder for VkBufferCreateInfo.
    static VkBufferCreateInfo create_info(VkDeviceSize size, VkFlags usage, const std::vector<uint32_t> *queue_families = nullptr,
                                          void *create_info_pnext = nullptr);

    // Builds a VkBufferMemoryBarrier for [offset, offset + size) of this buffer.
    // output_mask becomes srcAccessMask and input_mask becomes dstAccessMask. For buffers created
    // with VK_SHARING_MODE_CONCURRENT both queue family indices are VK_QUEUE_FAMILY_IGNORED;
    // otherwise they keep the value left by vku::InitStructHelper().
    VkBufferMemoryBarrier buffer_memory_barrier(VkFlags output_mask, VkFlags input_mask, VkDeviceSize offset,
                                                VkDeviceSize size) const {
        VkBufferMemoryBarrier result = vku::InitStructHelper();
        if (create_info_.sharingMode == VK_SHARING_MODE_CONCURRENT) {
            result.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
            result.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        }
        result.srcAccessMask = output_mask;
        result.dstAccessMask = input_mask;
        result.buffer = handle();
        result.offset = offset;
        result.size = size;
        return result;
    }

    // Synchronization2 variant: additionally carries source/destination stage masks.
    // Queue family handling is the same as in the overload above.
    VkBufferMemoryBarrier2KHR buffer_memory_barrier(VkPipelineStageFlags2KHR src_stage, VkPipelineStageFlags2KHR dst_stage,
                                                    VkAccessFlags2KHR src_access, VkAccessFlags2KHR dst_access, VkDeviceSize offset,
                                                    VkDeviceSize size) const {
        VkBufferMemoryBarrier2KHR result = vku::InitStructHelper();
        if (create_info_.sharingMode == VK_SHARING_MODE_CONCURRENT) {
            result.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
            result.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        }
        result.srcStageMask = src_stage;
        result.srcAccessMask = src_access;
        result.dstStageMask = dst_stage;
        result.dstAccessMask = dst_access;
        result.buffer = handle();
        result.offset = offset;
        result.size = size;
        return result;
    }

    [[nodiscard]] VkDeviceAddress address() const;

  private:
    VkBufferCreateInfo create_info_;
    DeviceMemory internal_mem_;
};

// Wrapper around a VkBufferView.
class BufferView : public internal::NonDispHandle<VkBufferView> {
  public:
    BufferView() = default;
    BufferView(const Device &dev, const VkBufferViewCreateInfo &info) { init(dev, info); }
    ~BufferView() noexcept;
    void destroy() noexcept;

    // vkCreateBufferView()
    void init(const Device &dev, const VkBufferViewCreateInfo &info);
    // Builder for VkBufferViewCreateInfo; by default the view starts at offset 0 and uses
    // VK_WHOLE_SIZE as its range.
    static VkBufferViewCreateInfo createInfo(VkBuffer buffer, VkFormat format, VkDeviceSize offset = 0,
                                             VkDeviceSize range = VK_WHOLE_SIZE);
    void SetName(const char *name) { NonDispHandle<VkBufferView>::SetName(VK_OBJECT_TYPE_BUFFER_VIEW, name); }
};

// Fills a VkBufferViewCreateInfo for `buffer` interpreted as `format`, covering `range` bytes
// starting at `offset`. Flags are always zero.
inline VkBufferViewCreateInfo BufferView::createInfo(VkBuffer buffer, VkFormat format, VkDeviceSize offset, VkDeviceSize range) {
    VkBufferViewCreateInfo view_ci = vku::InitStructHelper();
    view_ci.buffer = buffer;
    view_ci.format = format;
    view_ci.offset = offset;
    view_ci.range = range;
    view_ci.flags = VkFlags(0);
    return view_ci;
}

// Wrapper around a VkImage handle. Besides the handle it keeps the creation parameters, a DeviceMemory object and a
// tracked VkImageLayout, and offers static helpers for building subresource / barrier structures.
class Image : public internal::NonDispHandle<VkImage> {
  public:
    // Default construction performs no initialization of the underlying handle.
    explicit Image() : NonDispHandle() {}
    explicit Image(const Device &dev, const VkImageCreateInfo &info);
    explicit Image(const Device &dev, const VkImageCreateInfo &info, VkMemoryPropertyFlags mem_props,
                   void *alloc_info_pnext = nullptr);
    explicit Image(const Device &dev, uint32_t const width, uint32_t const height, uint32_t const mip_levels, VkFormat const format,
                   VkFlags const usage);

    // Tag-dispatched constructors; the NoMemT / SetLayoutT tag types are declared outside this class.
    explicit Image(const Device &dev, const VkImageCreateInfo &info, NoMemT);
    explicit Image(const Device &dev, const VkImageCreateInfo &info, SetLayoutT);

    ~Image() noexcept;
    void destroy() noexcept;

    void init(const Device &dev, const VkImageCreateInfo &info, VkMemoryPropertyFlags mem_props, void *alloc_info_pnext = nullptr);
    void Init(const Device &dev, uint32_t const width, uint32_t const height, uint32_t const mip_levels, VkFormat const format,
              VkFlags const usage);
    void init_no_mem(const Device &dev, const VkImageCreateInfo &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_IMAGE as the object type.
    void SetName(const char *name) { NonDispHandle<VkImage>::SetName(VK_OBJECT_TYPE_IMAGE, name); }

    // Builds a VkImageCreateInfo from the given dimensions, format and usage. Tiling defaults to
    // VK_IMAGE_TILING_OPTIMAL and no queue family list is passed unless the caller supplies one.
    static VkImageCreateInfo ImageCreateInfo2D(uint32_t const width, uint32_t const height, uint32_t const mip_levels,
                                               uint32_t const layers, VkFormat const format, VkFlags const usage,
                                               VkImageTiling const requested_tiling = VK_IMAGE_TILING_OPTIMAL,
                                               const std::vector<uint32_t> *queue_families = nullptr);

    static bool IsCompatible(const Device &dev, VkImageUsageFlags usages, VkFormatFeatureFlags2 features);

    // get the internal memory
    const DeviceMemory &memory() const { return internal_mem_; }
    DeviceMemory &memory() { return internal_mem_; }

    // vkGetObjectMemoryRequirements()
    VkMemoryRequirements memory_requirements() const;

    // Allocate and bind memory
    // The assumption is that this object was created in the no_mem configuration
    void allocate_and_bind_memory(const Device &dev, VkMemoryPropertyFlags mem_props = 0, void *alloc_info_pnext = nullptr);

    // Bind to existing memory object
    void bind_memory(const DeviceMemory &mem, VkDeviceSize mem_offset);

    // Accessors for values recorded in the stored VkImageCreateInfo.
    uint32_t width() const { return create_info_.extent.width; }
    uint32_t height() const { return create_info_.extent.height; }
    VkFormat format() const { return create_info_.format; }
    VkImageUsageFlags usage() const { return create_info_.usage; }

    // Builds a VkImageMemoryBarrier for this image. `output_mask` becomes srcAccessMask and `input_mask` becomes
    // dstAccessMask; both queue family indices are set to VK_QUEUE_FAMILY_IGNORED.
    VkImageMemoryBarrier image_memory_barrier(VkFlags output_mask, VkFlags input_mask, VkImageLayout old_layout,
                                              VkImageLayout new_layout, const VkImageSubresourceRange &range) const {
        VkImageMemoryBarrier barrier = vku::InitStructHelper();
        barrier.srcAccessMask = output_mask;
        barrier.dstAccessMask = input_mask;
        barrier.oldLayout = old_layout;
        barrier.newLayout = new_layout;
        barrier.image = handle();
        barrier.subresourceRange = range;
        barrier.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        barrier.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        return barrier;
    }

    // VkImageMemoryBarrier2KHR flavour of the helper above; additionally records the source and destination
    // pipeline stages.
    VkImageMemoryBarrier2KHR image_memory_barrier(VkPipelineStageFlags2KHR src_stage, VkPipelineStageFlags2KHR dst_stage,
                                                  VkAccessFlags2KHR src_access, VkAccessFlags2KHR dst_access,
                                                  VkImageLayout old_layout, VkImageLayout new_layout,
                                                  const VkImageSubresourceRange &range) const {
        VkImageMemoryBarrier2KHR barrier = vku::InitStructHelper();
        barrier.srcStageMask = src_stage;
        barrier.dstStageMask = dst_stage;
        barrier.srcAccessMask = src_access;
        barrier.dstAccessMask = dst_access;
        barrier.oldLayout = old_layout;
        barrier.newLayout = new_layout;
        barrier.image = handle();
        barrier.subresourceRange = range;
        barrier.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        barrier.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED;
        return barrier;
    }

    void ImageMemoryBarrier(CommandBuffer *cmd, VkImageAspectFlags aspect, VkFlags output_mask, VkFlags input_mask,
                            VkImageLayout image_layout, VkPipelineStageFlags src_stages = VK_PIPELINE_STAGE_ALL_COMMANDS_BIT,
                            VkPipelineStageFlags dest_stages = VK_PIPELINE_STAGE_ALL_COMMANDS_BIT);

    static VkImageCreateInfo create_info();

    // Helpers that build VkImageSubresource / VkImageSubresourceLayers values. The overloads taking a range add
    // the range's base mip level / base array layer to the given indices.
    static VkImageSubresource subresource(VkImageAspectFlags aspect, uint32_t mip_level, uint32_t array_layer);
    static VkImageSubresource subresource(const VkImageSubresourceRange &range, uint32_t mip_level, uint32_t array_layer);
    static VkImageSubresourceLayers subresource(VkImageAspectFlags aspect, uint32_t mip_level, uint32_t array_layer,
                                                uint32_t array_size);
    static VkImageSubresourceLayers subresource(const VkImageSubresourceRange &range, uint32_t mip_level, uint32_t array_layer,
                                                uint32_t array_size);

    // Range covering every mip level and array layer recorded in this image's create info.
    VkImageSubresourceRange subresource_range(VkImageAspectFlags aspect) const { return subresource_range(create_info_, aspect); }
    static VkImageSubresourceRange subresource_range(VkImageAspectFlags aspect_mask, uint32_t base_mip_level, uint32_t mip_levels,
                                                     uint32_t base_array_layer, uint32_t num_layers);
    static VkImageSubresourceRange subresource_range(const VkImageCreateInfo &info, VkImageAspectFlags aspect_mask);
    static VkImageSubresourceRange subresource_range(const VkImageSubresource &subres);

    static VkImageAspectFlags aspect_mask(VkFormat format);

    // Setter / getter for the tracked layout value only; neither records any command.
    void Layout(VkImageLayout const layout) { image_layout_ = layout; }
    VkImageLayout Layout() const { return image_layout_; }

    void SetLayout(CommandBuffer *cmd_buf, VkImageAspectFlags aspect, VkImageLayout image_layout);
    void SetLayout(VkImageAspectFlags aspect, VkImageLayout image_layout);
    // Derives the aspect mask from this image's format, then forwards to the two-argument overload.
    void SetLayout(VkImageLayout image_layout) { SetLayout(aspect_mask(format()), image_layout); }

    VkImageViewCreateInfo BasicViewCreatInfo(VkImageAspectFlags aspect_mask = VK_IMAGE_ASPECT_COLOR_BIT) const;
    ImageView CreateView(VkImageAspectFlags aspect = VK_IMAGE_ASPECT_COLOR_BIT, void *pNext = nullptr) const;
    ImageView CreateView(VkImageViewType type, uint32_t baseMipLevel = 0, uint32_t levelCount = VK_REMAINING_MIP_LEVELS,
                         uint32_t baseArrayLayer = 0, uint32_t layerCount = VK_REMAINING_ARRAY_LAYERS,
                         VkImageAspectFlags aspect = VK_IMAGE_ASPECT_COLOR_BIT) const;

  private:
    // We need this to do ImageView and SetLayout actions
    const Device *device_ = nullptr;

    // Value-initialized so that width()/height()/format()/usage() read zeros instead of indeterminate memory on an
    // Image that was default-constructed and never initialized.
    VkImageCreateInfo create_info_ = {};

    DeviceMemory internal_mem_;
    VkImageLayout image_layout_ = VK_IMAGE_LAYOUT_GENERAL;
};

// Move-only wrapper around a VkImageView handle.
class ImageView : public internal::NonDispHandle<VkImageView> {
  public:
    explicit ImageView() = default;
    // Convenience constructor: forwards straight to init().
    explicit ImageView(const Device &dev, const VkImageViewCreateInfo &info) { init(dev, info); }
    ImageView(ImageView &&rhs) noexcept : NonDispHandle(std::move(rhs)) {}
    // Releases whatever this object currently holds, then takes over the state of `src`.
    ImageView &operator=(ImageView &&src) noexcept {
        // Self-move guard: without it the object would tear down its own handle and then "move" from the
        // already-destroyed state.
        if (this != &src) {
            // NOTE(review): the explicit destructor call is kept from the original; if destroy() is confirmed to be
            // equivalent it would be the cleaner way to release the current handle.
            this->~ImageView();
            this->NonDispHandle::operator=(std::move(src));
        }
        return *this;
    }
    ~ImageView() noexcept;
    void destroy() noexcept;

    // vkCreateImageView()
    void init(const Device &dev, const VkImageViewCreateInfo &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_IMAGE_VIEW as the object type.
    void SetName(const char *name) { NonDispHandle<VkImageView>::SetName(VK_OBJECT_TYPE_IMAGE_VIEW, name); }
};

// Wrapper around a VkAccelerationStructureNV handle together with its info structure, a DeviceMemory object and an
// opaque 64-bit handle value.
class AccelerationStructureNV : public internal::NonDispHandle<VkAccelerationStructureNV> {
  public:
    // Forwards to init(); `init_memory` is passed through unchanged.
    explicit AccelerationStructureNV(const Device &dev, const VkAccelerationStructureCreateInfoNV &info, bool init_memory = true) {
        init(dev, info, init_memory);
    }
    ~AccelerationStructureNV() noexcept;
    void destroy() noexcept;

    // vkCreateAccelerationStructureNV
    void init(const Device &dev, const VkAccelerationStructureCreateInfoNV &info, bool init_memory = true);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_ACCELERATION_STRUCTURE_NV as the object type.
    void SetName(const char *name) {
        NonDispHandle<VkAccelerationStructureNV>::SetName(VK_OBJECT_TYPE_ACCELERATION_STRUCTURE_NV, name);
    }
    // vkGetAccelerationStructureMemoryRequirementsNV()
    VkMemoryRequirements2 memory_requirements() const;
    VkMemoryRequirements2 build_scratch_memory_requirements() const;

    // Returns the stored opaque handle value (0 until something assigns it).
    uint64_t opaque_handle() const { return opaque_handle_; }

    const VkAccelerationStructureInfoNV &info() const { return info_; }

    // Alias for the base-class device() accessor.
    const VkDevice &dev() const { return device(); }

    [[nodiscard]] Buffer create_scratch_buffer(const Device &device, VkBufferCreateInfo *pCreateInfo = nullptr,
                                               bool buffer_device_address = false) const;

  private:
    // Both members carry default initializers so that info() / opaque_handle() never expose indeterminate values if
    // an init() path (for example init_memory == false) does not assign them.
    VkAccelerationStructureInfoNV info_{};
    DeviceMemory memory_;
    uint64_t opaque_handle_ = 0;
};

// Wrapper around a VkShaderModule handle.
class ShaderModule : public internal::NonDispHandle<VkShaderModule> {
  public:
    // Default construction does not call init().
    ShaderModule() = default;
    // Convenience constructor: forwards straight to init().
    ShaderModule(const Device &dev, const VkShaderModuleCreateInfo &info) { init(dev, info); }
    ~ShaderModule() noexcept;
    void destroy() noexcept;

    // vkCreateShaderModule()
    void init(const Device &dev, const VkShaderModuleCreateInfo &info);
    // Variant of init() that hands a VkResult back to the caller.
    VkResult init_try(const Device &dev, const VkShaderModuleCreateInfo &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_SHADER_MODULE as the object type.
    void SetName(const char *name) { NonDispHandle<VkShaderModule>::SetName(VK_OBJECT_TYPE_SHADER_MODULE, name); }

    // Fills a VkShaderModuleCreateInfo; `code_size` is stored as codeSize and `code` as pCode.
    static VkShaderModuleCreateInfo create_info(size_t code_size, const uint32_t *code, VkFlags flags);
};

// Wrapper around a VkShaderEXT handle (shader objects).
class Shader : public internal::NonDispHandle<VkShaderEXT> {
  public:
    // Convenience constructor: forwards straight to init().
    Shader(const Device &dev, const VkShaderCreateInfoEXT &info) { init(dev, info); }
    // Builds a shader for `stage` from SPIR-V words; the descriptor set layout and push constant range are optional.
    Shader(const Device &dev, const VkShaderStageFlagBits stage, const std::vector<uint32_t> &spv,
           const VkDescriptorSetLayout *descriptorSetLayout = nullptr, const VkPushConstantRange* pushConstRange = nullptr);
    // Same as above but takes a byte vector.
    // NOTE(review): presumably a driver-specific shader binary rather than SPIR-V - confirm in the implementation.
    Shader(const Device &dev, const VkShaderStageFlagBits stage, const std::vector<uint8_t> &binary,
           const VkDescriptorSetLayout *descriptorSetLayout = nullptr, const VkPushConstantRange *pushConstRange = nullptr);
    ~Shader() noexcept;
    void destroy() noexcept;

    // NOTE(review): the handle type is VkShaderEXT, so this presumably wraps vkCreateShadersEXT() rather than
    // vkCreateShaderModule() - confirm in the implementation.
    void init(const Device &dev, const VkShaderCreateInfoEXT &info);
    // Variant of init() that hands a VkResult back to the caller.
    VkResult init_try(const Device &dev, const VkShaderCreateInfoEXT &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_SHADER_EXT as the object type.
    void SetName(const char *name) { NonDispHandle<VkShaderEXT>::SetName(VK_OBJECT_TYPE_SHADER_EXT, name); }
};

// Wrapper around a VkPipelineCache handle.
class PipelineCache : public internal::NonDispHandle<VkPipelineCache> {
  public:
    // Default construction does not call init().
    PipelineCache() = default;
    // Convenience constructor: forwards straight to init().
    PipelineCache(const Device &dev, const VkPipelineCacheCreateInfo &info) { init(dev, info); }
    ~PipelineCache() noexcept;
    void destroy() noexcept;

    // Creates the cache from `info` (defined out of line).
    void init(const Device &dev, const VkPipelineCacheCreateInfo &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_PIPELINE_CACHE as the object type.
    void SetName(const char *name) { NonDispHandle<VkPipelineCache>::SetName(VK_OBJECT_TYPE_PIPELINE_CACHE, name); }
};

// Wrapper around a VkPipeline handle. The constructors and init() overloads cover graphics, compute and
// ray tracing create infos.
class Pipeline : public internal::NonDispHandle<VkPipeline> {
  public:
    // Default construction does not call init().
    Pipeline() = default;
    // Each constructor below forwards to the init() overload with the matching parameter list.
    Pipeline(const Device &dev, const VkGraphicsPipelineCreateInfo &info) { init(dev, info); }
    Pipeline(const Device &dev, const VkGraphicsPipelineCreateInfo &info, const VkPipeline basePipeline) {
        init(dev, info, basePipeline);
    }
    Pipeline(const Device &dev, const VkComputePipelineCreateInfo &info) { init(dev, info); }
    Pipeline(const Device &dev, const VkRayTracingPipelineCreateInfoKHR &info) { init(dev, info); }
    ~Pipeline() noexcept;
    void destroy() noexcept;

    // vkCreateGraphicsPipeline()
    void init(const Device &dev, const VkGraphicsPipelineCreateInfo &info);
    // vkCreateGraphicsPipelineDerivative()
    void init(const Device &dev, const VkGraphicsPipelineCreateInfo &info, const VkPipeline basePipeline);
    // vkCreateComputePipeline()
    void init(const Device &dev, const VkComputePipelineCreateInfo &info);
    // vkCreateRayTracingPipelinesKHR
    void init(const Device &dev, const VkRayTracingPipelineCreateInfoKHR &info);
    // vkLoadPipeline()
    // NOTE(review): the load/store entry point names in these comments do not match any current Vulkan command -
    // confirm whether these overloads are still implemented.
    void init(const Device &dev, size_t size, const void *data);
    // vkLoadPipelineDerivative()
    void init(const Device &dev, size_t size, const void *data, VkPipeline basePipeline);

    // vkCreateGraphicsPipeline with error return
    VkResult init_try(const Device &dev, const VkGraphicsPipelineCreateInfo &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_PIPELINE as the object type.
    void SetName(const char *name) { NonDispHandle<VkPipeline>::SetName(VK_OBJECT_TYPE_PIPELINE, name); }

    // vkStorePipeline()
    size_t store(size_t size, void *data);
};

// Move-only wrapper around a VkPipelineLayout handle.
class PipelineLayout : public internal::NonDispHandle<VkPipelineLayout> {
  public:
    PipelineLayout() noexcept : NonDispHandle() {}
    // Forwards to init(dev, info, layouts).
    PipelineLayout(const Device &dev, VkPipelineLayoutCreateInfo &info,
                   const std::vector<const DescriptorSetLayout *> &layouts) {
        init(dev, info, layouts);
    }
    // Forwards to init(dev, info).
    PipelineLayout(const Device &dev, VkPipelineLayoutCreateInfo &info) {
        init(dev, info);
    }
    // Builds the VkPipelineLayoutCreateInfo locally from the push constant ranges and flags, then forwards to
    // init(dev, info, layouts). `push_constant_ranges` only has to outlive this constructor call.
    PipelineLayout(const Device &dev, const std::vector<const DescriptorSetLayout *> &layouts = {},
                   const std::vector<VkPushConstantRange> &push_constant_ranges = {},
                   VkPipelineLayoutCreateFlags flags = static_cast<VkPipelineLayoutCreateFlags>(0)) {
        VkPipelineLayoutCreateInfo info = vku::InitStructHelper();
        info.flags = flags;
        info.pushConstantRangeCount = static_cast<uint32_t>(push_constant_ranges.size());
        info.pPushConstantRanges = push_constant_ranges.data();

        init(dev, info, layouts);
    }
    ~PipelineLayout() noexcept;
    void destroy() noexcept;

    // Move constructor for Visual Studio 2013
    PipelineLayout(PipelineLayout &&src) noexcept : NonDispHandle(std::move(src)) {}

    // Releases whatever this object currently holds, then takes over the state of `src`.
    PipelineLayout &operator=(PipelineLayout &&src) noexcept {
        // Self-move guard: without it the object would tear down its own handle and then "move" from the
        // already-destroyed state.
        if (this != &src) {
            this->~PipelineLayout();
            this->NonDispHandle::operator=(std::move(src));
        }
        return *this;
    }

    // vkCreatePipelineLayout()
    void init(const Device &dev, VkPipelineLayoutCreateInfo &info, const std::vector<const DescriptorSetLayout *> &layouts);
    void init(const Device &dev, VkPipelineLayoutCreateInfo &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_PIPELINE_LAYOUT as the object type.
    void SetName(const char *name) { NonDispHandle<VkPipelineLayout>::SetName(VK_OBJECT_TYPE_PIPELINE_LAYOUT, name); }
};

// Wrapper around a VkSampler handle.
class Sampler : public internal::NonDispHandle<VkSampler> {
  public:
    // Default construction does not call init().
    Sampler() = default;
    // Convenience constructor: forwards straight to init().
    Sampler(const Device &dev, const VkSamplerCreateInfo &info) { init(dev, info); }
    ~Sampler() noexcept;
    void destroy() noexcept;

    // vkCreateSampler()
    void init(const Device &dev, const VkSamplerCreateInfo &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_SAMPLER as the object type.
    void SetName(const char *name) { NonDispHandle<VkSampler>::SetName(VK_OBJECT_TYPE_SAMPLER, name); }
};

// Move-only wrapper around a VkDescriptorSetLayout handle.
class DescriptorSetLayout : public internal::NonDispHandle<VkDescriptorSetLayout> {
  public:
    DescriptorSetLayout() noexcept : NonDispHandle() {}
    // Convenience constructor: forwards straight to init().
    DescriptorSetLayout(const Device &dev, const VkDescriptorSetLayoutCreateInfo &info) { init(dev, info); }
    // Builds the VkDescriptorSetLayoutCreateInfo locally from the bindings, flags and pNext chain, then forwards to
    // init(). `descriptor_set_bindings` only has to outlive this constructor call.
    DescriptorSetLayout(const Device &dev, const std::vector<VkDescriptorSetLayoutBinding> &descriptor_set_bindings = {},
                        VkDescriptorSetLayoutCreateFlags flags = 0, void *pNext = nullptr) {
        VkDescriptorSetLayoutCreateInfo info = vku::InitStructHelper(pNext);
        info.flags = flags;
        info.bindingCount = static_cast<uint32_t>(descriptor_set_bindings.size());
        info.pBindings = descriptor_set_bindings.data();
        init(dev, info);
    }

    ~DescriptorSetLayout() noexcept;
    void destroy() noexcept;

    // Move constructor for Visual Studio 2013
    DescriptorSetLayout(DescriptorSetLayout &&src) noexcept : NonDispHandle(std::move(src)) {}

    // Releases whatever this object currently holds, then takes over the state of `src`.
    DescriptorSetLayout &operator=(DescriptorSetLayout &&src) noexcept {
        // Self-move guard: without it the object would tear down its own handle and then "move" from the
        // already-destroyed state.
        if (this != &src) {
            this->~DescriptorSetLayout();
            this->NonDispHandle::operator=(std::move(src));
        }
        return *this;
    }

    // vkCreateDescriptorSetLayout()
    void init(const Device &dev, const VkDescriptorSetLayoutCreateInfo &info);
};

// Wrapper around a VkDescriptorPool handle that also remembers a "dynamic usage" flag for the pool.
class DescriptorPool : public internal::NonDispHandle<VkDescriptorPool> {
  public:
    // Default construction does not call init(); the dynamic-usage flag starts out false.
    DescriptorPool() = default;
    // Convenience constructor: forwards straight to init().
    DescriptorPool(const Device &dev, const VkDescriptorPoolCreateInfo &info) { init(dev, info); }
    ~DescriptorPool() noexcept;
    void destroy() noexcept;

    // vkCreateDescriptorPool()
    void init(const Device &dev, const VkDescriptorPoolCreateInfo &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_DESCRIPTOR_POOL as the object type.
    void SetName(const char *name) { NonDispHandle<VkDescriptorPool>::SetName(VK_OBJECT_TYPE_DESCRIPTOR_POOL, name); }

    // vkResetDescriptorPool()
    void reset();

    // Setter / getter for the dynamic-usage flag. These only touch the stored bool; they make no Vulkan call.
    // NOTE(review): presumably consulted when deciding whether individual sets may be freed - confirm at the use sites.
    void setDynamicUsage(bool isDynamic) { dynamic_usage_ = isDynamic; }
    bool getDynamicUsage() const { return dynamic_usage_; }

    // vkAllocateDescriptorSets()
    std::vector<DescriptorSet *> alloc_sets(const Device &dev, const std::vector<const DescriptorSetLayout *> &layouts);
    std::vector<DescriptorSet *> alloc_sets(const Device &dev, const DescriptorSetLayout &layout, uint32_t count);
    DescriptorSet *alloc_sets(const Device &dev, const DescriptorSetLayout &layout);

    // Fills a VkDescriptorPoolCreateInfo. `pool_sizes` must provide size() and data(); the returned structure
    // points into it, so it has to stay alive while the structure is in use.
    template <typename PoolSizes>
    static VkDescriptorPoolCreateInfo create_info(VkDescriptorPoolCreateFlags flags, uint32_t max_sets,
                                                  const PoolSizes &pool_sizes);

  private:
    // Track whether this pool's usage is VK_DESCRIPTOR_POOL_USAGE_DYNAMIC.
    // Initialized so that getDynamicUsage() is well defined on a default-constructed pool.
    bool dynamic_usage_ = false;
};

// Fills a VkDescriptorPoolCreateInfo from the given flags, set limit and pool-size container.
// pPoolSizes is left null when the container reports zero elements; otherwise it points at the container's
// storage, which therefore has to outlive the returned structure.
template <typename PoolSizes>
inline VkDescriptorPoolCreateInfo DescriptorPool::create_info(VkDescriptorPoolCreateFlags flags, uint32_t max_sets,
                                                              const PoolSizes &pool_sizes) {
    VkDescriptorPoolCreateInfo pool_ci = vku::InitStructHelper();
    pool_ci.maxSets = max_sets;
    pool_ci.flags = flags;
    pool_ci.poolSizeCount = pool_sizes.size();
    if (pool_ci.poolSizeCount) {
        pool_ci.pPoolSizes = pool_sizes.data();
    } else {
        pool_ci.pPoolSizes = nullptr;
    }
    return pool_ci;
}

// Wrapper around a VkDescriptorSet handle that remembers which DescriptorPool it came from.
class DescriptorSet : public internal::NonDispHandle<VkDescriptorSet> {
  public:
    ~DescriptorSet() noexcept;
    void destroy() noexcept;

    // Default construction leaves the wrapper without a handle and without a pool.
    explicit DescriptorSet() : NonDispHandle() {}
    // Adopts an already-allocated `set` and records the pool it belongs to. The pool pointer is stored as-is and
    // is not owned by this object.
    explicit DescriptorSet(const Device &dev, DescriptorPool *pool, VkDescriptorSet set)
        : NonDispHandle(dev.handle(), set), containing_pool_(pool) {}
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_DESCRIPTOR_SET as the object type.
    void SetName(const char *name) { NonDispHandle<VkDescriptorSet>::SetName(VK_OBJECT_TYPE_DESCRIPTOR_SET, name); }

  private:
    // Null for default-constructed sets; previously this was left uninitialized in that case.
    DescriptorPool *containing_pool_ = nullptr;
};

// Wrapper around a VkCommandPool handle.
class CommandPool : public internal::NonDispHandle<VkCommandPool> {
  public:
    ~CommandPool() noexcept;
    void destroy() noexcept;

    // Default construction does not call Init().
    explicit CommandPool() : NonDispHandle() {}
    // Convenience constructor: forwards straight to Init().
    explicit CommandPool(const Device &dev, const VkCommandPoolCreateInfo &info) { Init(dev, info); }
    // Creates a pool for the given queue family; `flags` defaults to 0 (defined out of line).
    explicit CommandPool(const Device &dev, uint32_t queue_family_index, VkCommandPoolCreateFlags flags = 0);
    void Init(const Device &dev, const VkCommandPoolCreateInfo &info);
    void Init(const Device &dev, uint32_t queue_family_index, VkCommandPoolCreateFlags flags = 0);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_COMMAND_POOL as the object type.
    void SetName(const char *name) { NonDispHandle<VkCommandPool>::SetName(VK_OBJECT_TYPE_COMMAND_POOL, name); }
};

// Wrapper around a VkCommandBuffer handle. It remembers the VkDevice and VkCommandPool handles it is associated
// with and offers thin helpers for common recording commands.
class CommandBuffer : public internal::Handle<VkCommandBuffer> {
  public:
    ~CommandBuffer() noexcept;
    void destroy() noexcept;

    // Default construction leaves the wrapper without a command buffer; the device and pool handles are null.
    explicit CommandBuffer() : Handle() {}
    // Convenience constructor: forwards straight to init().
    explicit CommandBuffer(const Device &dev, const VkCommandBufferAllocateInfo &info) { init(dev, info); }
    // Convenience constructor: forwards straight to Init(); the level defaults to primary.
    explicit CommandBuffer(const Device &dev, const CommandPool &pool,
                           VkCommandBufferLevel level = VK_COMMAND_BUFFER_LEVEL_PRIMARY) {
        Init(dev, pool, level);
    }
    // Takes over the handle plus the device / pool association of `rhs` and nulls them out in `rhs`.
    CommandBuffer(CommandBuffer &&rhs) noexcept : Handle(std::move(rhs)) {
        dev_handle_ = rhs.dev_handle_;
        rhs.dev_handle_ = VK_NULL_HANDLE;
        cmd_pool_ = rhs.cmd_pool_;
        rhs.cmd_pool_ = VK_NULL_HANDLE;
    }

    // vkAllocateCommandBuffers()
    void init(const Device &dev, const VkCommandBufferAllocateInfo &info);
    void Init(const Device &dev, const CommandPool &pool, VkCommandBufferLevel level = VK_COMMAND_BUFFER_LEVEL_PRIMARY);
    // Forwards to the base-class SetName() with the device handle and VK_OBJECT_TYPE_COMMAND_BUFFER.
    void SetName(const Device &device, const char *name) {
        Handle<VkCommandBuffer>::SetName(device.handle(), VK_OBJECT_TYPE_COMMAND_BUFFER, name);
    }

    // vkBeginCommandBuffer()
    void begin(const VkCommandBufferBeginInfo *info);
    void begin(VkCommandBufferUsageFlags flags);
    // Begins recording with no usage flags.
    void begin() { begin(0u); }

    // vkEndCommandBuffer()
    // vkResetCommandBuffer()
    void end();
    void reset(VkCommandBufferResetFlags flags);
    // Resets with VK_COMMAND_BUFFER_RESET_RELEASE_RESOURCES_BIT.
    void reset() { reset(VK_COMMAND_BUFFER_RESET_RELEASE_RESOURCES_BIT); }

    static VkCommandBufferAllocateInfo create_info(VkCommandPool const &pool);

    // Render pass / dynamic rendering helpers.
    void BeginRenderPass(const VkRenderPassBeginInfo &info, VkSubpassContents contents = VK_SUBPASS_CONTENTS_INLINE);
    void BeginRenderPass(VkRenderPass rp, VkFramebuffer fb, uint32_t render_area_width = 1, uint32_t render_area_height = 1,
                         uint32_t clear_count = 0, VkClearValue *clear_values = nullptr);
    void NextSubpass(VkSubpassContents contents = VK_SUBPASS_CONTENTS_INLINE);
    void EndRenderPass();
    void BeginRendering(const VkRenderingInfoKHR &renderingInfo);
    void BeginRenderingColor(const VkImageView imageView, VkRect2D render_area);
    void EndRendering();

    // Shader object binding helpers.
    void BindVertFragShader(const vkt::Shader &vert_shader, const vkt::Shader &frag_shader);
    void BindCompShader(const vkt::Shader &comp_shader);

    // Video coding helpers.
    void BeginVideoCoding(const VkVideoBeginCodingInfoKHR &beginInfo);
    void ControlVideoCoding(const VkVideoCodingControlInfoKHR &controlInfo);
    void DecodeVideo(const VkVideoDecodeInfoKHR &decodeInfo);
    void EncodeVideo(const VkVideoEncodeInfoKHR &encodeInfo);
    void EndVideoCoding(const VkVideoEndCodingInfoKHR &endInfo);

    // Event helpers: the set / reset variants delegate to Event::cmd_set() / Event::cmd_reset() on this command
    // buffer; WaitEvents() passes every argument through to vk::CmdWaitEvents().
    void SetEvent(Event &event, VkPipelineStageFlags stageMask) { event.cmd_set(*this, stageMask); }
    void ResetEvent(Event &event, VkPipelineStageFlags stageMask) { event.cmd_reset(*this, stageMask); }
    void WaitEvents(uint32_t eventCount, const VkEvent *pEvents, VkPipelineStageFlags srcStageMask,
                    VkPipelineStageFlags dstStageMask, uint32_t memoryBarrierCount, const VkMemoryBarrier *pMemoryBarriers,
                    uint32_t bufferMemoryBarrierCount, const VkBufferMemoryBarrier *pBufferMemoryBarriers,
                    uint32_t imageMemoryBarrierCount, const VkImageMemoryBarrier *pImageMemoryBarriers) {
        vk::CmdWaitEvents(handle(), eventCount, pEvents, srcStageMask, dstStageMask, memoryBarrierCount, pMemoryBarriers,
                          bufferMemoryBarrierCount, pBufferMemoryBarriers, imageMemoryBarrierCount, pImageMemoryBarriers);
    }

    void Copy(const Buffer &src, const Buffer &dst);
    void ExecuteCommands(const CommandBuffer &secondary);

  private:
    // Both handles start out null so that a default-constructed CommandBuffer (such as the global vkt::no_cmd)
    // never carries indeterminate handle values.
    VkDevice dev_handle_ = VK_NULL_HANDLE;
    VkCommandPool cmd_pool_ = VK_NULL_HANDLE;
};

// Shortcut for a default-constructed vkt::CommandBuffer{}. Useful in submit helpers: queue->submit(vkt::no_cmd, ...).
// Declared `inline` so that every translation unit including this header shares the one object.
inline const CommandBuffer no_cmd;

// Wrapper around a VkRenderPass handle; accepts both the original and the "2" create info structures.
class RenderPass : public internal::NonDispHandle<VkRenderPass> {
  public:
    // Default construction does not call init().
    RenderPass() = default;
    // vkCreateRenderPass()
    RenderPass(const Device &dev, const VkRenderPassCreateInfo &info) { init(dev, info); }
    // vkCreateRenderPass2()
    RenderPass(const Device &dev, const VkRenderPassCreateInfo2 &info) { init(dev, info); }
    ~RenderPass() noexcept;
    void destroy() noexcept;

    // vkCreateRenderPass()
    void init(const Device &dev, const VkRenderPassCreateInfo &info);
    // vkCreateRenderPass2()
    void init(const Device &dev, const VkRenderPassCreateInfo2 &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_RENDER_PASS as the object type.
    void SetName(const char *name) { NonDispHandle<VkRenderPass>::SetName(VK_OBJECT_TYPE_RENDER_PASS, name); }
};


// Wrapper around a VkFramebuffer handle.
class Framebuffer : public internal::NonDispHandle<VkFramebuffer> {
  public:
    Framebuffer() = default;
    // Convenience constructor: forwards straight to init().
    Framebuffer(const Device &dev, const VkFramebufferCreateInfo &info) { init(dev, info); }
    // Shortcut for the most common case: a single-layer framebuffer over the given attachments (32x32 unless the
    // caller says otherwise). Anything outside of this should create their own VkFramebufferCreateInfo.
    Framebuffer(const Device &dev, VkRenderPass rp, uint32_t attchment_count, const VkImageView *attchments, uint32_t width = 32,
                uint32_t height = 32) {
        VkFramebufferCreateInfo fb_ci = vku::InitStructHelper();

        // Render pass and attachment list.
        fb_ci.renderPass = rp;
        fb_ci.pAttachments = attchments;
        fb_ci.attachmentCount = attchment_count;

        // Dimensions; always exactly one layer.
        fb_ci.width = width;
        fb_ci.height = height;
        fb_ci.layers = 1;

        init(dev, fb_ci);
    }
    ~Framebuffer() noexcept;
    void destroy() noexcept;

    // vkCreateFramebuffer()
    void init(const Device &dev, const VkFramebufferCreateInfo &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_FRAMEBUFFER as the object type.
    void SetName(const char *name) { NonDispHandle<VkFramebuffer>::SetName(VK_OBJECT_TYPE_FRAMEBUFFER, name); }
};

// Wrapper around a VkSamplerYcbcrConversion handle.
class SamplerYcbcrConversion : public internal::NonDispHandle<VkSamplerYcbcrConversion> {
  public:
    // Default construction does not call init().
    SamplerYcbcrConversion() = default;
    // Creates a conversion for `format` using the settings returned by DefaultConversionInfo().
    SamplerYcbcrConversion(const Device &dev, VkFormat format) {
        init(dev, DefaultConversionInfo(format));
    }
    // Convenience constructor: forwards straight to init().
    SamplerYcbcrConversion(const Device &dev, const VkSamplerYcbcrConversionCreateInfo &info) {
        init(dev, info);
    }
    ~SamplerYcbcrConversion() noexcept;
    void destroy() noexcept;

    // Creates the conversion object from `info` (defined out of line).
    void init(const Device &dev, const VkSamplerYcbcrConversionCreateInfo &info);
    // Forwards to the base-class SetName() with VK_OBJECT_TYPE_SAMPLER_YCBCR_CONVERSION as the object type.
    void SetName(const char *name) {
        NonDispHandle<VkSamplerYcbcrConversion>::SetName(VK_OBJECT_TYPE_SAMPLER_YCBCR_CONVERSION, name);
    }
    // Returns a VkSamplerYcbcrConversionInfo.
    // NOTE(review): presumably one that references this conversion, for chaining into pNext - confirm in the
    // implementation.
    VkSamplerYcbcrConversionInfo ConversionInfo();

    // Returns the create info used by the (dev, format) constructor.
    static VkSamplerYcbcrConversionCreateInfo DefaultConversionInfo(VkFormat format);
};

// Fills a VkBufferCreateInfo with the given size, usage and pNext chain.
// Concurrent sharing is only requested when `queue_families` is non-null and lists more than one family; in that
// case the structure points into the caller's vector, which must outlive it.
inline VkBufferCreateInfo Buffer::create_info(VkDeviceSize size, VkFlags usage, const std::vector<uint32_t> *queue_families,
                                              void *create_info_pnext) {
    VkBufferCreateInfo buffer_ci = vku::InitStructHelper(create_info_pnext);
    buffer_ci.usage = usage;
    buffer_ci.size = size;

    const bool multiple_families = (queue_families != nullptr) && (queue_families->size() > 1);
    if (!multiple_families) {
        // Zero or one queue family: leave the sharing fields as initialized by the helper.
        return buffer_ci;
    }

    buffer_ci.sharingMode = VK_SHARING_MODE_CONCURRENT;
    buffer_ci.pQueueFamilyIndices = queue_families->data();
    buffer_ci.queueFamilyIndexCount = static_cast<uint32_t>(queue_families->size());
    return buffer_ci;
}

// Returns a VkFenceCreateInfo whose flags field is set to `flags`.
inline VkFenceCreateInfo Fence::create_info(VkFenceCreateFlags flags) {
    VkFenceCreateInfo fence_ci = vku::InitStructHelper();

    fence_ci.flags = flags;

    return fence_ci;
}

// Returns a VkFenceCreateInfo exactly as produced by vku::InitStructHelper(), with no field overridden.
inline VkFenceCreateInfo Fence::create_info() {
    VkFenceCreateInfo fence_ci = vku::InitStructHelper();

    return fence_ci;
}

// Returns a VkEventCreateInfo whose flags field is set to `flags`.
inline VkEventCreateInfo Event::create_info(VkFlags flags) {
    VkEventCreateInfo event_ci = vku::InitStructHelper();

    event_ci.flags = flags;

    return event_ci;
}

// Returns a VkQueryPoolCreateInfo for `slot_count` queries of the given type.
inline VkQueryPoolCreateInfo QueryPool::create_info(VkQueryType type, uint32_t slot_count) {
    VkQueryPoolCreateInfo pool_ci = vku::InitStructHelper();

    pool_ci.queryCount = slot_count;
    pool_ci.queryType = type;

    return pool_ci;
}

// Returns a baseline VkImageCreateInfo: 1x1x1 extent, one mip level, one array layer, one sample per texel.
// Every other field keeps the value produced by vku::InitStructHelper().
inline VkImageCreateInfo Image::create_info() {
    VkImageCreateInfo image_ci = vku::InitStructHelper();

    // Smallest possible extent.
    image_ci.extent.depth = 1;
    image_ci.extent.height = 1;
    image_ci.extent.width = 1;

    // Single subresource, single sample.
    image_ci.arrayLayers = 1;
    image_ci.mipLevels = 1;
    image_ci.samples = VK_SAMPLE_COUNT_1_BIT;

    return image_ci;
}

// Builds a VkImageSubresource naming one mip level and one array layer of the given aspect.
// An empty aspect mask trips the assert in builds where assert() is active.
inline VkImageSubresource Image::subresource(VkImageAspectFlags aspect, uint32_t mip_level, uint32_t array_layer) {
    if (aspect == 0) {
        assert(false && "Invalid VkImageAspectFlags");
    }

    VkImageSubresource result = {};
    result.arrayLayer = array_layer;
    result.mipLevel = mip_level;
    result.aspectMask = aspect;
    return result;
}

// Same as the aspect-based overload, with `mip_level` / `array_layer` taken relative to the base level and base
// layer of `range`.
inline VkImageSubresource Image::subresource(const VkImageSubresourceRange &range, uint32_t mip_level, uint32_t array_layer) {
    const uint32_t absolute_mip = range.baseMipLevel + mip_level;
    const uint32_t absolute_layer = range.baseArrayLayer + array_layer;
    return subresource(range.aspectMask, absolute_mip, absolute_layer);
}

// Builds a VkImageSubresourceLayers covering `array_size` layers starting at `array_layer` of one mip level.
// Only four aspect masks are accepted: color, depth, stencil, or depth combined with stencil. Anything else trips
// the assert in builds where assert() is active.
inline VkImageSubresourceLayers Image::subresource(VkImageAspectFlags aspect, uint32_t mip_level, uint32_t array_layer,
                                                   uint32_t array_size) {
    const bool is_color = (aspect == VK_IMAGE_ASPECT_COLOR_BIT);
    const bool is_depth = (aspect == VK_IMAGE_ASPECT_DEPTH_BIT);
    const bool is_stencil = (aspect == VK_IMAGE_ASPECT_STENCIL_BIT);
    const bool is_depth_stencil = (aspect == (VK_IMAGE_ASPECT_DEPTH_BIT | VK_IMAGE_ASPECT_STENCIL_BIT));
    if (!(is_color || is_depth || is_stencil || is_depth_stencil)) {
        assert(false && "Invalid VkImageAspectFlags");
    }

    VkImageSubresourceLayers result = {};
    result.layerCount = array_size;
    result.baseArrayLayer = array_layer;
    result.mipLevel = mip_level;
    result.aspectMask = aspect;
    return result;
}

// Same as the aspect-based layers overload, with `mip_level` / `array_layer` taken relative to the base level and
// base layer of `range`.
inline VkImageSubresourceLayers Image::subresource(const VkImageSubresourceRange &range, uint32_t mip_level, uint32_t array_layer,
                                                   uint32_t array_size) {
    const uint32_t absolute_mip = range.baseMipLevel + mip_level;
    const uint32_t absolute_layer = range.baseArrayLayer + array_layer;
    return subresource(range.aspectMask, absolute_mip, absolute_layer, array_size);
}

// Builds a VkImageSubresourceRange from explicit mip and layer bounds.
// An empty aspect mask trips the assert in builds where assert() is active.
inline VkImageSubresourceRange Image::subresource_range(VkImageAspectFlags aspect_mask, uint32_t base_mip_level,
                                                        uint32_t mip_levels, uint32_t base_array_layer, uint32_t num_layers) {
    if (aspect_mask == 0) {
        assert(false && "Invalid VkImageAspectFlags");
    }

    VkImageSubresourceRange result = {};
    result.aspectMask = aspect_mask;

    // Mip levels.
    result.baseMipLevel = base_mip_level;
    result.levelCount = mip_levels;

    // Array layers.
    result.baseArrayLayer = base_array_layer;
    result.layerCount = num_layers;

    return result;
}

// Range spanning every mip level and array layer declared in `info`, starting at level 0 / layer 0.
inline VkImageSubresourceRange Image::subresource_range(const VkImageCreateInfo &info, VkImageAspectFlags aspect_mask) {
    const uint32_t first_mip = 0;
    const uint32_t first_layer = 0;
    return subresource_range(aspect_mask, first_mip, info.mipLevels, first_layer, info.arrayLayers);
}

// Range covering exactly the single mip level and single array layer named by `subres`.
inline VkImageSubresourceRange Image::subresource_range(const VkImageSubresource &subres) {
    const uint32_t one_level = 1;
    const uint32_t one_layer = 1;
    return subresource_range(subres.aspectMask, subres.mipLevel, one_level, subres.arrayLayer, one_layer);
}

// Fills a VkShaderModuleCreateInfo: `code_size` is stored as codeSize, `code` as pCode and `flags` as flags.
// The structure only stores the pointer, so `code` must stay valid while the structure is in use.
inline VkShaderModuleCreateInfo ShaderModule::create_info(size_t code_size, const uint32_t *code, VkFlags flags) {
    VkShaderModuleCreateInfo module_ci = vku::InitStructHelper();

    module_ci.flags = flags;
    module_ci.pCode = code;
    module_ci.codeSize = code_size;

    return module_ci;
}

// Builds a VkWriteDescriptorSet that writes `count` image descriptors of `type` into `set`, starting at
// (binding, array_element). Only pImageInfo is set among the descriptor payload pointers.
inline VkWriteDescriptorSet Device::write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                         VkDescriptorType type, uint32_t count,
                                                         const VkDescriptorImageInfo *image_info) {
    VkWriteDescriptorSet wds = vku::InitStructHelper();

    // Destination.
    wds.dstSet = set.handle();
    wds.dstBinding = binding;
    wds.dstArrayElement = array_element;

    // Payload.
    wds.descriptorType = type;
    wds.descriptorCount = count;
    wds.pImageInfo = image_info;

    return wds;
}

// Builds a VkWriteDescriptorSet that writes `count` buffer descriptors of `type` into `set`, starting at
// (binding, array_element). Only pBufferInfo is set among the descriptor payload pointers.
inline VkWriteDescriptorSet Device::write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                         VkDescriptorType type, uint32_t count,
                                                         const VkDescriptorBufferInfo *buffer_info) {
    VkWriteDescriptorSet wds = vku::InitStructHelper();

    // Destination.
    wds.dstSet = set.handle();
    wds.dstBinding = binding;
    wds.dstArrayElement = array_element;

    // Payload.
    wds.descriptorType = type;
    wds.descriptorCount = count;
    wds.pBufferInfo = buffer_info;

    return wds;
}

// Builds a VkWriteDescriptorSet that writes `count` texel buffer view descriptors of `type` into `set`, starting
// at (binding, array_element). Only pTexelBufferView is set among the descriptor payload pointers.
inline VkWriteDescriptorSet Device::write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                         VkDescriptorType type, uint32_t count, const VkBufferView *buffer_views) {
    VkWriteDescriptorSet wds = vku::InitStructHelper();

    // Destination.
    wds.dstSet = set.handle();
    wds.dstBinding = binding;
    wds.dstArrayElement = array_element;

    // Payload.
    wds.descriptorType = type;
    wds.descriptorCount = count;
    wds.pTexelBufferView = buffer_views;

    return wds;
}

// Convenience overload: forwards to the pointer/count overload with descriptorCount = image_info.size().
// The returned struct points into image_info's storage, so the vector must stay alive (and unmodified)
// for as long as the result is used.
inline VkWriteDescriptorSet Device::write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                         VkDescriptorType type,
                                                         const std::vector<VkDescriptorImageInfo> &image_info) {
    // data() rather than &image_info[0]: indexing element 0 of an empty vector is undefined behavior.
    return write_descriptor_set(set, binding, array_element, type, static_cast<uint32_t>(image_info.size()),
                                image_info.data());
}

// Convenience overload: forwards to the pointer/count overload with descriptorCount = buffer_info.size().
// The returned struct points into buffer_info's storage, so the vector must stay alive (and unmodified)
// for as long as the result is used.
inline VkWriteDescriptorSet Device::write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                         VkDescriptorType type,
                                                         const std::vector<VkDescriptorBufferInfo> &buffer_info) {
    // data() rather than &buffer_info[0]: indexing element 0 of an empty vector is undefined behavior.
    return write_descriptor_set(set, binding, array_element, type, static_cast<uint32_t>(buffer_info.size()),
                                buffer_info.data());
}

// Convenience overload: forwards to the pointer/count overload with descriptorCount = buffer_views.size().
// The returned struct points into buffer_views' storage, so the vector must stay alive (and unmodified)
// for as long as the result is used.
inline VkWriteDescriptorSet Device::write_descriptor_set(const DescriptorSet &set, uint32_t binding, uint32_t array_element,
                                                         VkDescriptorType type, const std::vector<VkBufferView> &buffer_views) {
    // data() rather than &buffer_views[0]: indexing element 0 of an empty vector is undefined behavior.
    return write_descriptor_set(set, binding, array_element, type, static_cast<uint32_t>(buffer_views.size()),
                                buffer_views.data());
}

// Describes a copy of `count` descriptors from (src_set, src_binding, src_array_element)
// to (dst_set, dst_binding, dst_array_element). Only fills in the struct; no Vulkan call is made here.
inline VkCopyDescriptorSet Device::copy_descriptor_set(const DescriptorSet &src_set, uint32_t src_binding,
                                                       uint32_t src_array_element, const DescriptorSet &dst_set,
                                                       uint32_t dst_binding, uint32_t dst_array_element, uint32_t count) {
    VkCopyDescriptorSet descriptor_copy = vku::InitStructHelper();
    descriptor_copy.descriptorCount = count;

    // Source location
    descriptor_copy.srcSet = src_set.handle();
    descriptor_copy.srcBinding = src_binding;
    descriptor_copy.srcArrayElement = src_array_element;

    // Destination location
    descriptor_copy.dstSet = dst_set.handle();
    descriptor_copy.dstBinding = dst_binding;
    descriptor_copy.dstArrayElement = dst_array_element;

    return descriptor_copy;
}

}  // namespace vkt
