2021-09-09 20:42:29 -04:00

305 lines
9.0 KiB
Plaintext

// Kernel entry points. Each line names a kernel and sets the macros that specialize the shared
// code below for it:
//   BITONIC_SORT / MERGE_PASS : name given to the kernel entry function.
//   ELEMENTS_COUNT            : number of elements sorted by one thread group of the bitonic sort.
//   ITERATIONS_COUNT          : number of compare/exchange pairs each thread handles per sub step.
//   FINAL_PASS                : 1 writes only the sorted values (uint), 0 writes full key/value pairs.
#pragma kernel BitonicSort128 BITONIC_SORT=BitonicSort128 ELEMENTS_COUNT=128 ITERATIONS_COUNT=1 FINAL_PASS=1
#pragma kernel BitonicSort1024 BITONIC_SORT=BitonicSort1024 ELEMENTS_COUNT=1024 ITERATIONS_COUNT=2 FINAL_PASS=1
#pragma kernel BitonicSort4096 BITONIC_SORT=BitonicSort4096 ELEMENTS_COUNT=4096 ITERATIONS_COUNT=2 FINAL_PASS=1
#pragma kernel BitonicPrePass BITONIC_SORT=BitonicPrePass ELEMENTS_COUNT=4096 ITERATIONS_COUNT=2 FINAL_PASS=0
#pragma kernel MergePass MERGE_PASS=MergePass FINAL_PASS=0
#pragma kernel MergeFinalPass MERGE_PASS=MergeFinalPass FINAL_PASS=1
// Variant keyword: when VFX_SORT_USE_ELEMENT_COUNT_BUFFER is defined, the element count is loaded
// directly from a buffer instead of being derived from a constant minus a loaded count.
#pragma multi_compile __ VFX_SORT_USE_ELEMENT_COUNT_BUFFER
#include "HLSLSupport.cginc"
// 1 sorts keys in decreasing order, 0 in increasing order.
#define DECREASING_ORDER 1
// 1 caps the merge pass search loop to a fixed number of iterations (debug aid).
#define DEBUG_NO_INFINITE_LOOP 0
// Fallbacks used when a kernel pragma does not provide these values.
#ifndef ELEMENTS_COUNT
#define ELEMENTS_COUNT 1024
#endif
#ifndef ITERATIONS_COUNT
#define ITERATIONS_COUNT 1
#endif
// 1 to use the alternative representation of the bitonic network:
// No comparison flipping but dedicated first sub pass.
// see https://en.wikipedia.org/wiki/Bitonic_sorter
#if defined(SHADER_API_METAL)
// Workaround: BitonicSort128 doesn't behave as expected on Metal; the actual cause isn't identified yet.
#define USE_ALTERNATE_BITONIC_NETWORK 1
#else
#define USE_ALTERNATE_BITONIC_NETWORK 0
#endif
// Each compare/exchange touches two elements, so a thread owns 2 elements per iteration.
#define ELEMENTS_PER_THREAD (2 * ITERATIONS_COUNT)
// Parenthesized so the division cannot re-associate with operators at the expansion site.
#define BITONIC_THREADS_COUNT (ELEMENTS_COUNT / ELEMENTS_PER_THREAD)
// Not referenced in the visible part of this file.
#define MERGE_THREADS_COUNT ELEMENTS_COUNT
// The scratch memory holds ELEMENTS_COUNT keys followed by ELEMENTS_COUNT values.
#define HALF_SCRATCH_SIZE ELEMENTS_COUNT
#define LDS_VALUES_OFFSET HALF_SCRATCH_SIZE
#define SCRATCH_SIZE (HALF_SCRATCH_SIZE << 1)
// Index of the compare/exchange pair handled by thread `id` during iteration `it`.
#define ITERATION_INDEX(id,it) ((it) * BITONIC_THREADS_COUNT + (id))
// Index in the input/output buffers of element `threadId` of thread group `groupId`.
// Arguments are parenthesized so that expressions can be passed safely.
#define DST_INDEX(groupId,threadId) ((groupId) * BITONIC_THREADS_COUNT * ELEMENTS_PER_THREAD + (threadId))
// Key given to slots past the element count: it is the last value in the chosen sort order.
#if DECREASING_ORDER
#define REJECTED_VALUE asfloat(0xff7fffff) // -MAX_FLOAT
#else
#define REJECTED_VALUE asfloat(0x7f7fffff) // MAX_FLOAT
#endif
#pragma warning(disable : 3557) // disable warning for auto unrolling of single iteration loop
// Key/value pair handled by the sort: elements are ordered by `key`, `value` is the payload
// carried along with it (only `value` is written out by the final passes).
struct KVP
{
float key; // sort key, compared with CompareKeys
uint value; // payload moved together with its key
};
// Key/value pairs to sort (read only).
StructuredBuffer<KVP> inputSequence;
// Output of the pass: final passes write only the values, the other passes write full
// key/value pairs (presumably so that a later pass can read them back as its inputSequence).
#if FINAL_PASS
RWStructuredBuffer<uint> sortedSequence;
#else
RWStructuredBuffer<KVP> sortedSequence;
#endif
// Raw buffer from which GetElementCount loads a count at elementCountOffset.
ByteAddressBuffer deadElementCount;
// With VFX_SORT_USE_ELEMENT_COUNT_BUFFER the same buffer binding is read as the element count itself.
#ifdef VFX_SORT_USE_ELEMENT_COUNT_BUFFER
#define elementCountBuffer deadElementCount
#endif
// Layout of the scratch memory is as follows:
// First all keys
// Then all values
// Values are uint bit patterns stored through asfloat and read back through asuint.
groupshared float scratch[SCRATCH_SIZE];
CBUFFER_START(SortConstants)
uint elementCount; // count the loaded value is subtracted from when the count buffer variant is off
uint elementCountOffset; // address passed to ByteAddressBuffer.Load to fetch the count
CBUFFER_END
// Returns the number of elements to sort.
// With VFX_SORT_USE_ELEMENT_COUNT_BUFFER the count is loaded as-is from the count buffer;
// otherwise the value loaded from deadElementCount is subtracted from elementCount.
uint GetElementCount()
{
    uint liveCount;
#ifdef VFX_SORT_USE_ELEMENT_COUNT_BUFFER
    liveCount = elementCountBuffer.Load(elementCountOffset);
#else
    const uint deadCount = deadElementCount.Load(elementCountOffset);
    liveCount = elementCount - deadCount;
#endif
    return liveCount;
}
// Strict ordering predicate of the sort: true when key0 is strictly greater than key1 if
// DECREASING_ORDER is set, strictly smaller otherwise. Equal keys always yield false.
bool CompareKeys(float key0, float key1)
{
#if DECREASING_ORDER
    const bool isOrdered = (key1 < key0);
#else
    const bool isOrdered = (key1 > key0);
#endif
    return isOrdered;
}
// Maps a logical element index to its slot in the scratch memory.
// The mapping is currently the identity (no padding is applied).
uint GetLDSIndex(uint index)
{
    const uint slot = index;
    return slot;
}
// Copies one key/value pair from inputSequence into the scratch memory.
//   ldsIndex : destination element index in the scratch memory.
//   memIndex : source index in inputSequence.
// Indices at or past the element count are filled with { REJECTED_VALUE, 0 } instead of being read.
void LoadFromMemory(uint ldsIndex,uint memIndex)
{
    KVP entry;
    if (memIndex < GetElementCount())
    {
        entry = inputSequence[memIndex];
    }
    else
    {
        entry.key = REJECTED_VALUE;
        entry.value = 0;
    }

    const uint keySlot = GetLDSIndex(ldsIndex);
    const uint valueSlot = keySlot + LDS_VALUES_OFFSET;
    scratch[keySlot] = entry.key;
    // The uint payload is stored bit for bit in the float scratch array.
    scratch[valueSlot] = asfloat(entry.value);
}
// Writes one element of the scratch memory to sortedSequence.
//   memIndex : destination index in sortedSequence.
//   ldsIndex : source element index in the scratch memory.
// Nothing is written for indices at or past the element count. Final passes output only the
// value; the other passes output the full key/value pair.
void StoreToMemory(uint memIndex, uint ldsIndex)
{
    if (memIndex >= GetElementCount())
        return;

    const uint keySlot = GetLDSIndex(ldsIndex);
    const uint valueSlot = keySlot + LDS_VALUES_OFFSET;
#if FINAL_PASS
    sortedSequence[memIndex] = asuint(scratch[valueSlot]);
#else
    KVP entry;
    entry.key = scratch[keySlot];
    entry.value = asuint(scratch[valueSlot]);
    sortedSequence[memIndex] = entry;
#endif
}
// Bitonic sort on small chunks of kvp of size ELEMENTS_COUNT - execute in O(log²(ELEMENTS_COUNT))
// Each thread group sorts one chunk of ELEMENTS_COUNT consecutive elements in the scratch memory:
//   id      : index of the thread inside its group.
//   groupId : groupId.x selects the chunk handled by the group.
// Slots past the element count are loaded as REJECTED_VALUE and are never written back.
[numthreads(BITONIC_THREADS_COUNT,1,1)]
void BITONIC_SORT(uint id : SV_GroupIndex, uint3 groupId : SV_GroupID)
{
    // Skip groups whose whole chunk starts at or past the element count: they have nothing to
    // sort and nothing to store. The test depends only on the group id, so every thread of a
    // group takes the same path and no thread is left waiting at the barriers below.
    if (groupId.x * ELEMENTS_COUNT >= GetElementCount())
        return;

    // Load data from memory to LDS
    //[unroll]
    for (uint i = 0; i < ELEMENTS_PER_THREAD; ++i)
    {
        uint index = BITONIC_THREADS_COUNT * i + id;
        uint memIndex = DST_INDEX(groupId.x, index);
        LoadFromMemory(index, memIndex);
    }
    GroupMemoryBarrierWithGroupSync(); // LDS Writes visible

    for (uint step = 1; step < ELEMENTS_COUNT; step <<= 1) // O(log(ELEMENTS_COUNT))
        for (uint subStep = step; subStep != 0; subStep >>= 1) // O(log(step))
        {
            [unroll]
            for (uint i = 0; i < ITERATIONS_COUNT; ++i)
            {
                // Pair of slots compared by this thread: index0 is the lower slot,
                // index1 its partner subStep slots further.
                uint index = ITERATION_INDEX(id,i);
                uint lsb = index & (subStep - 1);
                uint index0 = (2 * index) - lsb;
                uint index1 = index0 + subStep;
#if USE_ALTERNATE_BITONIC_NETWORK
                // Dedicated first sub pass of the alternative network: the partner is mirrored
                // inside the block instead of flipping the comparison.
                if (subStep == step)
                    index1 += step - (2 * lsb) - 1;
#endif
                float key0 = scratch[index0];
                float key1 = scratch[index1];
#if USE_ALTERNATE_BITONIC_NETWORK
                bool reverse = false;
#else
                // Comparison direction alternates between blocks.
                bool reverse = index & step;
#endif
                if (CompareKeys(key1,key0) != reverse)
                {
                    // swap keys
                    scratch[index0] = key1;
                    scratch[index1] = key0;
                    // swap values
                    float value0 = scratch[index0 + LDS_VALUES_OFFSET];
                    scratch[index0 + LDS_VALUES_OFFSET] = scratch[index1 + LDS_VALUES_OFFSET];
                    scratch[index1 + LDS_VALUES_OFFSET] = value0;
                }
            }
            GroupMemoryBarrierWithGroupSync(); // LDS Writes visible
        }

    // Store sorted data from LDS to memory
    //[unroll]
    for (uint j = 0; j < ELEMENTS_PER_THREAD; ++j)
    {
        uint index = BITONIC_THREADS_COUNT * j + id;
        uint memIndex = DST_INDEX(groupId.x, index);
        StoreToMemory(memIndex, index);
    }
}
// Constants of the merge passes.
CBUFFER_START(MergePassConstants)
uint subArraySize; // size of each sorted input sub array; a merge pass outputs arrays twice this size
uint dispatchWidth; // multiplied by groupId.y when flattening the thread index (presumably the dispatch size in groups along x - confirm against the caller)
CBUFFER_END
// Returns the key stored at `index` in inputSequence, or REJECTED_VALUE when `index` is at or
// past the element count. The bound is tested before the buffer is accessed so that no read is
// issued for such indices, in the same way LoadFromMemory guards its read.
float GetKeyWithCheck(uint index)
{
    float result = REJECTED_VALUE;
    if (index < GetElementCount()) // TODO Handle elementCount more efficiently
        result = inputSequence[index].key;
    return result;
}
// Merge pass: take N sorted sub arrays as input and output N/2 sorted arrays twice bigger - execute in O(log(subArraySize))
// One thread produces one output element. It searches for a pair of positions (i0 in the left
// sub array, i1 in the right sub array) such that i0 + i1 equals its position in the merged
// array, then outputs the element selected between those two positions.
#define NB_THREADS_PER_GROUP_MERGE_PASS 64
[numthreads(NB_THREADS_PER_GROUP_MERGE_PASS, 1, 1)]
void MERGE_PASS(uint3 groupId : SV_GroupID,
uint3 groupThreadId : SV_GroupThreadID)
{
// Flatten thread and group ids (x and y) into a single output element index.
uint id = groupThreadId.x + groupId.x * NB_THREADS_PER_GROUP_MERGE_PASS + groupId.y * dispatchWidth * NB_THREADS_PER_GROUP_MERGE_PASS;
// Threads past the element count have nothing to output.
if (id >= GetElementCount())
return;
// arraySize: size of a merged output array; arrayStart: first index of the output array containing id.
const int arraySize = subArraySize << 1;
const int arrayStart = arraySize * (id / arraySize);
// If the current array considered is less than one half filled (due to element count), we can copy it directly as it is already sorted
if (GetElementCount() - (uint)arrayStart <= subArraySize)
{
#if FINAL_PASS
sortedSequence[id] = inputSequence[id].value;
#else
sortedSequence[id] = inputSequence[id];
#endif
return;
}
// arrayIndex: position of this output element inside the merged array.
// lastIndex: last valid position inside one sub array.
const int arrayIndex = id - arrayStart;
const int lastIndex = subArraySize - 1;
// determine initial frame of the window
// window.x / window.y are the inclusive lower / upper bounds of the candidate positions in the left sub array.
int2 window = uint2(max(0, arrayIndex - (int)subArraySize), min(arrayIndex, lastIndex));
int index0, index1;
float key0, key1;
bool reverse = false;
bool done = false;
#if DEBUG_NO_INFINITE_LOOP
uint nbIter = 0;
#endif
// Binary search - O(log(subArraySize))
do
{
// Probe the middle of the window (rounded up); both positions are clamped to the sub array bounds.
int windowIndex = (window.x + window.y + 1) >> 1;
int i0 = min(lastIndex, windowIndex);
int i1 = min(lastIndex, arrayIndex - windowIndex);
// index0 addresses the left sub array, index1 the right one (which may extend past the element count).
index0 = i0 + arrayStart;
index1 = i1 + arrayStart + subArraySize;
key0 = inputSequence[index0].key;
key1 = GetKeyWithCheck(index1);
// The sum matches only when neither position was clamped above.
if (i0 + i1 == arrayIndex)
{
// key1 orders strictly before the element preceding index0: the left position is too far right.
if (i0 > 0 && CompareKeys(key1, inputSequence[max(0, index0 - 1)].key))
window.y = windowIndex - 1; // move window left
// key0 orders strictly before the element preceding index1: the left position is too far left.
else if (i1 > 0 && CompareKeys(key0, GetKeyWithCheck(max(0, index1 - 1))))
window.x = windowIndex + (windowIndex == lastIndex ? 2 : 1); // move window right (Special handling at the right bound so that i1 can go down to 0)
else
done = true;
}
else // special case handling
{
// A position was clamped: stop here and flip the comparison used by the final selection.
reverse = true;
done = true;
}
#if DEBUG_NO_INFINITE_LOOP
// Debug guard: give up after 2048 iterations and force equal keys.
if (++nbIter > 2048)
{
key0 = key1 = 0.5f;
break;
}
#endif
} while (!done);
// Select left or right: Must handle special case for equality
// With equal keys, odd ids take the left element; otherwise the (possibly flipped) key comparison decides.
bool select0 = (key0 == key1 && (id & 1)) || (CompareKeys(key0,key1) != reverse);
uint value = inputSequence[(select0 ? index0 : index1)].value;
#if FINAL_PASS
sortedSequence[id] = value;
#else
float key = select0 ? key0 : key1;
KVP kvp = { key, value };
sortedSequence[id] = kvp;
#endif
}