/*
 * Decompiled with CFR 0.152.
 */
package me.cortex.voxy.client;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.regex.Pattern;
import me.cortex.voxy.common.Logger;
import org.lwjgl.system.APIUtil;
import org.lwjgl.system.JNI;
import org.lwjgl.system.MemoryStack;
import org.lwjgl.system.MemoryUtil;
import org.lwjgl.system.SharedLibrary;
import org.lwjgl.system.windows.GDI32;
import org.lwjgl.system.windows.Kernel32;

/**
 * Windows-only GPU selection by hot-patching {@code gdi32} D3DKMT entry points in the current process.
 *
 * <p>{@link #doSelector(int)} enumerates adapters through {@code D3DKMTEnumAdapters2}, keeps those whose
 * adapter-type bits say "render supported" and not "software device", and for the chosen one:
 * <ol>
 *   <li>overwrites {@code D3DKMTOpenAdapterFromHdc} with a stub that opens the chosen adapter by LUID,</li>
 *   <li>redirects {@code D3DKMTQueryAdapterInfo} into a heap-allocated trampoline stub (a different stub
 *       is used when the adapter's OpenGL ICD file name looks like an Intel driver),</li>
 *   <li>calls {@code D3DKMTSetProperties} (property types 1 and 2) with the adapter's PCI ids.</li>
 * </ol>
 *
 * <p>All native structures are assembled by hand with x86-64 offsets and the stubs are raw x86-64 machine
 * code, so this class only makes sense in a 64-bit Windows process. There is no uninstall method and the
 * trampoline memory is never freed once installed.
 */
public class GPUSelectorWindows2 {
    // Entry points resolved at class-load time; each is 0 when the running Windows build does not export it.
    private static final long D3DKMTSetProperties = optionalFunction(GDI32.getLibrary(), "D3DKMTSetProperties");
    private static final long D3DKMTEnumAdapters2 = optionalFunction(GDI32.getLibrary(), "D3DKMTEnumAdapters2");
    private static final long D3DKMTCloseAdapter = optionalFunction(GDI32.getLibrary(), "D3DKMTCloseAdapter");
    private static final long D3DKMTQueryAdapterInfo = optionalFunction(GDI32.getLibrary(), "D3DKMTQueryAdapterInfo");
    private static final long D3DKMTOpenAdapterFromLuid = optionalFunction(GDI32.getLibrary(), "D3DKMTOpenAdapterFromLuid");
    private static final long D3DKMTOpenAdapterFromHdc = optionalFunction(GDI32.getLibrary(), "D3DKMTOpenAdapterFromHdc");
    private static final long VirtualProtect = optionalFunction(Kernel32.getLibrary(), "VirtualProtect");

    /**
     * Replacement body for {@code D3DKMTOpenAdapterFromHdc} (x86-64). Decoded:
     * <pre>
     *  0: add  rcx, 12        ; rcx -> caller struct + 12
     *  4: mov  rax, imm64     ; imm64 @6  = selected adapter LUID
     * 14: mov  [rcx], rax     ; write LUID into caller struct + 12
     * 17: mov  rax, imm64     ; imm64 @19 = D3DKMTOpenAdapterFromLuid
     * 27: push rcx
     * 28: call rax            ; argument = caller struct + 12
     * 30: pop  rcx
     * 31: mov  eax, [rcx+8]   ; dword at caller struct + 20 (presumably the opened hAdapter)
     * 34: mov  [rcx-4], eax   ; ... copied to caller struct + 8
     * 37: xor  rax, rax
     * 40: mov  [rcx+8], eax   ; zero caller struct + 20
     * 43: ret                 ; always returns 0
     * </pre>
     * NOTE(review): the status returned by D3DKMTOpenAdapterFromLuid is discarded and no 32-byte shadow
     * space is reserved before the call; confirm both are acceptable for the target Windows builds.
     */
    private static final int[] HDC_STUB = new int[]{72, 131, 193, 12, 72, 184, 255, 255, 255, 255, 255, 255, 255, 31, 72, 137, 1, 72, 184, 255, 255, 255, 255, 255, 255, 255, 47, 81, 255, 208, 89, 139, 65, 8, 137, 65, 252, 72, 49, 192, 137, 65, 8, 195};

    // KMTQUERYADAPTERINFOTYPE values (d3dkmthk.h) used with D3DKMTQueryAdapterInfo.
    private static final int KMTQAITYPE_UMOPENGLINFO = 2;
    private static final int KMTQAITYPE_ADAPTERTYPE = 15;
    private static final int KMTQAITYPE_PHYSICALADAPTERDEVICEIDS = 31;

    // D3DKMT_OPENGLINFO: WCHAR[MAX_PATH] ICD file name (520 bytes) followed by two ULONGs.
    private static final int OPENGL_INFO_SIZE = 528;
    private static final int ICD_PATH_MAX_BYTES = 520;

    // D3DKMT_ADAPTERTYPE bit flags.
    private static final int ADAPTER_TYPE_RENDER_SUPPORTED = 1;
    private static final int ADAPTER_TYPE_SOFTWARE_DEVICE = 4;

    // D3DKMT_ADAPTERINFO on x64: hAdapter @0, LUID @4, NumOfSources @12, bPrecisePresentRegionsPreferred @16.
    private static final int ADAPTER_INFO_SIZE = 20;

    // VirtualProtect flNewProtect value.
    private static final int PAGE_EXECUTE_READWRITE = 0x40;

    // Size of the heap block reserved for the D3DKMTQueryAdapterInfo trampoline.
    private static final int MAX_STUB_SIZE = 1024;

    // Total size of the Intel trampoline: 105 code bytes, a counter byte, then four 8-byte data slots.
    private static final int INTEL_STUB_SIZE = 138;

    // Intel OpenGL ICD file names, e.g. "...\ig9icd64.dll"; find() so a full path matches.
    private static final Pattern INTEL_ICD_PATTERN =
            Pattern.compile("(?:^|[\\\\/])ig[a-z0-9]+icd(?:32|64)\\.dll$", Pattern.CASE_INSENSITIVE);

    /** Resolves {@code name} in {@code library}, returning 0 when the export is missing. */
    private static long optionalFunction(SharedLibrary library, String name) {
        return APIUtil.apiGetFunctionAddressOptional(library, name);
    }

    /**
     * Sends a 16-byte PCI-id payload {vendor, device, subSys, 0} through {@link #setProperties}.
     *
     * @return the NTSTATUS from D3DKMTSetProperties, or -1 when that export is unavailable
     */
    private static int setPCIProperties(int type, int vendor, int device, int subSys) {
        try (MemoryStack stack = MemoryStack.stackPush()) {
            ByteBuffer payload = stack.calloc(16).order(ByteOrder.nativeOrder());
            payload.putInt(0, vendor);
            payload.putInt(4, device);
            payload.putInt(8, subSys);
            payload.putInt(12, 0);
            return setProperties(type, payload);
        }
    }

    /**
     * Calls D3DKMTSetProperties with a 24-byte argument: type @0, payload size @4, payload pointer @16.
     * NOTE(review): this struct appears undocumented; the layout is taken as-is from the original code.
     *
     * @return the NTSTATUS, or -1 when the export is unavailable
     */
    private static int setProperties(int type, ByteBuffer payload) {
        if (D3DKMTSetProperties == 0L) {
            return -1;
        }
        try (MemoryStack stack = MemoryStack.stackPush()) {
            ByteBuffer args = stack.calloc(24).order(ByteOrder.nativeOrder());
            args.putInt(0, type);
            args.putInt(4, payload.remaining());
            args.putLong(16, MemoryUtil.memAddress(payload));
            return JNI.callPI(MemoryUtil.memAddress(args), D3DKMTSetProperties);
        }
    }

    /**
     * Calls D3DKMTQueryAdapterInfo. The argument is laid out like the x64 D3DKMT_QUERYADAPTERINFO:
     * hAdapter @0, Type @4, pPrivateDriverData @8, PrivateDriverDataSize @16.
     *
     * @return the NTSTATUS (negative on failure), or -1 when the export is unavailable
     */
    private static int query(int handle, int type, ByteBuffer payload) {
        if (D3DKMTQueryAdapterInfo == 0L) {
            return -1;
        }
        try (MemoryStack stack = MemoryStack.stackPush()) {
            ByteBuffer args = stack.calloc(20).order(ByteOrder.nativeOrder());
            args.putInt(0, handle);
            args.putInt(4, type);
            args.putLong(8, MemoryUtil.memAddress(payload));
            args.putInt(16, payload.remaining());
            return JNI.callPI(MemoryUtil.memAddress(args), D3DKMTQueryAdapterInfo);
        }
    }

    /**
     * Closes an adapter handle through D3DKMTCloseAdapter.
     *
     * @return the NTSTATUS, or -1 when the export is unavailable
     */
    private static int closeHandle(int handle) {
        if (D3DKMTCloseAdapter == 0L) {
            return -1;
        }
        try (MemoryStack stack = MemoryStack.stackPush()) {
            ByteBuffer args = stack.calloc(4).order(ByteOrder.nativeOrder());
            args.putInt(0, handle);
            return JNI.callPI(MemoryUtil.memAddress(args), D3DKMTCloseAdapter);
        }
    }

    /** Stores the adapter-type bit field in {@code out[0]}; returns a negative status on failure, else 0. */
    private static int queryAdapterType(int handle, int[] out) {
        try (MemoryStack stack = MemoryStack.stackPush()) {
            ByteBuffer result = stack.calloc(4).order(ByteOrder.nativeOrder());
            int ret = query(handle, KMTQAITYPE_ADAPTERTYPE, result);
            if (ret < 0) {
                return ret;
            }
            out[0] = result.getInt(0);
        }
        return 0;
    }

    /** Stores the OpenGL ICD file name in {@code out[0]}; returns a negative status on failure, else 0. */
    private static int queryAdapterIcd(int handle, String[] out) {
        try (MemoryStack stack = MemoryStack.stackPush()) {
            ByteBuffer result = stack.calloc(OPENGL_INFO_SIZE).order(ByteOrder.nativeOrder());
            int ret = query(handle, KMTQAITYPE_UMOPENGLINFO, result);
            if (ret < 0) {
                return ret;
            }
            // Never decode past the fixed-size name field, even if it lacks a NUL terminator.
            int byteLength = Math.min(MemoryUtil.memLengthNT2(result), ICD_PATH_MAX_BYTES);
            out[0] = MemoryUtil.memUTF16(result.limit(byteLength));
        }
        return 0;
    }

    /**
     * Stores the PCI ids of physical adapter {@code index} in {@code deviceOut[0]}. The 28-byte payload is
     * PhysicalAdapterIndex @0 followed by six UINTs (vendor, device, subVendor, subSystem, revision, busType).
     *
     * @return a negative status on failure, else 0
     */
    private static int queryPCIAddress(int handle, int index, PCIDeviceId[] deviceOut) {
        try (MemoryStack stack = MemoryStack.stackPush()) {
            ByteBuffer result = stack.calloc(28).order(ByteOrder.nativeOrder());
            result.putInt(0, index);
            int ret = query(handle, KMTQAITYPE_PHYSICALADAPTERDEVICEIDS, result);
            if (ret < 0) {
                return ret;
            }
            deviceOut[0] = new PCIDeviceId(result.getInt(4), result.getInt(8), result.getInt(12),
                    result.getInt(16), result.getInt(20), result.getInt(24));
            return 0;
        }
    }

    /**
     * Enumerates adapters with D3DKMTEnumAdapters2 and hands each fully-described one to {@code consumer}.
     * Adapters whose queries fail are logged and skipped. Every handle returned by the enumeration is
     * closed, even when a query or the consumer throws.
     *
     * @return a negative status if enumeration fails or the export is missing, else 0
     * @throws IllegalStateException if enumeration completed but closing any adapter handle failed
     */
    private static int enumAdapters(Consumer<AdapterInfo> consumer) {
        if (D3DKMTEnumAdapters2 == 0L) {
            return -1;
        }
        try (MemoryStack stack = MemoryStack.stackPush()) {
            // NumAdapters @0, adapter-array pointer @8.
            ByteBuffer query = stack.calloc(16).order(ByteOrder.nativeOrder());
            // First call, with a null array pointer, only reports the adapter count.
            int ret = JNI.callPI(MemoryUtil.memAddress(query), D3DKMTEnumAdapters2);
            if (ret < 0) {
                return ret;
            }
            int capacity = query.getInt(0);
            if (capacity <= 0) {
                return 0;
            }
            ByteBuffer adapterList = stack.calloc(ADAPTER_INFO_SIZE * capacity).order(ByteOrder.nativeOrder());
            query.putLong(8, MemoryUtil.memAddress(adapterList));
            ret = JNI.callPI(MemoryUtil.memAddress(query), D3DKMTEnumAdapters2);
            if (ret < 0) {
                return ret;
            }
            // Never index past the buffer we allocated, whatever count comes back.
            int adapterCount = Math.min(query.getInt(0), capacity);
            boolean completed = false;
            try {
                for (int adapterIndex = 0; adapterIndex < adapterCount; ++adapterIndex) {
                    ByteBuffer entry = adapterEntry(adapterList, adapterIndex);
                    AdapterInfo info = describeAdapter(entry.getInt(0), entry.getLong(4));
                    if (info != null) {
                        consumer.accept(info);
                    }
                }
                completed = true;
            } finally {
                int failedCloses = closeAdapters(adapterList, adapterCount);
                // Only throw here when no other exception is already propagating, so it is not masked.
                if (completed && failedCloses > 0) {
                    throw new IllegalStateException("Failed to close " + failedCloses + " adapter handle(s)");
                }
            }
        }
        return 0;
    }

    /** Returns a native-order view of entry {@code index} of a D3DKMT_ADAPTERINFO array. */
    private static ByteBuffer adapterEntry(ByteBuffer adapterList, int index) {
        return adapterList.slice(index * ADAPTER_INFO_SIZE, ADAPTER_INFO_SIZE).order(ByteOrder.nativeOrder());
    }

    /**
     * Queries type, OpenGL ICD path and PCI ids for one adapter handle (does not close it).
     *
     * @return the adapter description, or {@code null} (after logging) if any query failed
     */
    private static AdapterInfo describeAdapter(int handle, long luid) {
        int[] type = new int[1];
        int ret = queryAdapterType(handle, type);
        if (ret < 0) {
            Logger.error("Query type error: " + ret);
            return null;
        }
        String[] icd = new String[1];
        ret = queryAdapterIcd(handle, icd);
        if (ret < 0) {
            Logger.error("Query icd error: " + ret);
            return null;
        }
        PCIDeviceId[] ids = new PCIDeviceId[1];
        ret = queryPCIAddress(handle, 0, ids);
        if (ret < 0) {
            Logger.error("Query pci error: " + ret);
            return null;
        }
        PCIDeviceId pci = ids[0];
        int subSys = pci.subSystem() << 16 | pci.subVendor();
        return new AdapterInfo(icd[0], type[0], luid, pci.vendor(), pci.device(), subSys);
    }

    /**
     * Closes every handle in the first {@code adapterCount} entries of {@code adapterList}, logging failures.
     *
     * @return the number of handles that failed to close
     */
    private static int closeAdapters(ByteBuffer adapterList, int adapterCount) {
        int failures = 0;
        for (int adapterIndex = 0; adapterIndex < adapterCount; ++adapterIndex) {
            int status = closeHandle(adapterEntry(adapterList, adapterIndex).getInt(0));
            if (status < 0) {
                Logger.error("Close adapter error: " + status);
                ++failures;
            }
        }
        return failures;
    }

    /** Writes {@code l} into {@code out[offset..offset+7]} in little-endian byte order. */
    private static void insertLong(long l, byte[] out, int offset) {
        for (int i = 0; i < 8; ++i) {
            out[i + offset] = (byte) (l & 0xFFL);
            l >>= 8;
        }
    }

    /** Reads a little-endian long from {@code in[offset..offset+7]} (inverse of {@link #insertLong}). */
    private static long readLong(byte[] in, int offset) {
        return ByteBuffer.wrap(in).order(ByteOrder.LITTLE_ENDIAN).getLong(offset);
    }

    /** Builds the HDC stub with the target LUID and the D3DKMTOpenAdapterFromLuid address patched in. */
    private static byte[] createFinishedHDCStub(long luid) {
        byte[] stub = toByteArray(HDC_STUB);
        insertLong(luid, stub, 6);
        insertLong(D3DKMTOpenAdapterFromLuid, stub, 19);
        return stub;
    }

    /** Truncates each int to its low byte. */
    private static byte[] toByteArray(int... array) {
        byte[] res = new byte[array.length];
        for (int i = 0; i < array.length; ++i) {
            res[i] = (byte) array[i];
        }
        return res;
    }

    /**
     * Marks {@code [addr, addr+size)} as PAGE_EXECUTE_READWRITE.
     *
     * @return {@code false} (after logging) if VirtualProtect reported failure; callers must then not write
     */
    private static boolean makeReadWriteExecute(long addr, long size) {
        try (MemoryStack stack = MemoryStack.stackPush()) {
            ByteBuffer oldProtection = stack.calloc(4);
            int ok = JNI.callPPPPI(addr, size, PAGE_EXECUTE_READWRITE, MemoryUtil.memAddress(oldProtection), VirtualProtect);
            if (ok == 0) {
                Logger.error("VirtualProtect failed at: " + Long.toHexString(addr));
                return false;
            }
            return true;
        }
    }

    /** Copies {@code data} byte by byte to native address {@code ptr}. */
    private static void memcpy(long ptr, byte[] data) {
        for (int i = 0; i < data.length; ++i) {
            MemoryUtil.memPutByte(ptr + i, data[i]);
        }
    }

    /** Overwrites D3DKMTOpenAdapterFromHdc with the HDC stub bound to {@code adapterLuid}. */
    private static void installHDCStub(long adapterLuid) {
        if (D3DKMTOpenAdapterFromHdc == 0L || VirtualProtect == 0L || D3DKMTOpenAdapterFromLuid == 0L) {
            return;
        }
        Logger.info("AdapterLuid callback at: " + Long.toHexString(D3DKMTOpenAdapterFromLuid));
        byte[] stub = createFinishedHDCStub(adapterLuid);
        if (!makeReadWriteExecute(D3DKMTOpenAdapterFromHdc, stub.length)) {
            return;
        }
        memcpy(D3DKMTOpenAdapterFromHdc, stub);
    }

    /**
     * Builds the Intel trampoline for D3DKMTQueryAdapterInfo. A counter byte (@105, starts at 3) is
     * decremented on every call:
     * <ul>
     *   <li>counter == 2: return -1, hook stays installed;</li>
     *   <li>counter == 0: restore the original 16 bytes, write 0 to the first dword of the caller's
     *       {@code [arg+8]} buffer (presumably pPrivateDriverData), return 0 (hook stays removed);</li>
     *   <li>otherwise: restore the original bytes, call the original with the caller's arguments,
     *       re-install the jump, return 0 (the original's status is discarded).</li>
     * </ul>
     * {@code r11} is the scratch register: it is volatile and not an argument register, so {@code rcx}
     * still holds the caller's argument at the {@code call} and at {@code mov rax,[rcx+8]}.
     * NOTE(review): no shadow space is reserved before {@code call rax}.
     */
    private static byte[] createIntelStub(long origA, long origB, long jmpA, long jmpB) {
        byte[] code = toByteArray(
                254, 13, 99, 0, 0, 0,            //   0: dec  byte [rip+0x63]    ; counter @105
                128, 61, 92, 0, 0, 0, 2,         //   6: cmp  byte [rip+0x5c], 2
                117, 7,                          //  13: jne  22
                72, 49, 192,                     //  15: xor  rax, rax
                72, 247, 208,                    //  18: not  rax                ; -1
                195,                             //  21: ret
                72, 184, 0, 0, 0, 0, 0, 0, 0, 1, //  22: mov  rax, imm64         ; imm64 @24 = D3DKMTQueryAdapterInfo
                76, 139, 29, 67, 0, 0, 0,        //  32: mov  r11, [rip+0x43]    ; origA @106
                76, 137, 24,                     //  39: mov  [rax], r11
                76, 139, 29, 65, 0, 0, 0,        //  42: mov  r11, [rip+0x41]    ; origB @114
                76, 137, 88, 8,                  //  49: mov  [rax+8], r11
                128, 61, 45, 0, 0, 0, 0,         //  53: cmp  byte [rip+0x2d], 0
                116, 29,                         //  60: je   91
                80,                              //  62: push rax
                255, 208,                        //  63: call rax                ; original, caller's args
                88,                              //  65: pop  rax
                76, 139, 29, 49, 0, 0, 0,        //  66: mov  r11, [rip+0x31]    ; jmpA @122
                76, 137, 24,                     //  73: mov  [rax], r11
                76, 139, 29, 47, 0, 0, 0,        //  76: mov  r11, [rip+0x2f]    ; jmpB @130
                76, 137, 88, 8,                  //  83: mov  [rax+8], r11
                72, 49, 192,                     //  87: xor  rax, rax
                195,                             //  90: ret
                72, 139, 65, 8,                  //  91: mov  rax, [rcx+8]
                199, 0, 0, 0, 0, 0,              //  95: mov  dword [rax], 0
                72, 49, 192,                     // 101: xor  rax, rax
                195);                            // 104: ret
        byte[] stub = new byte[INTEL_STUB_SIZE];
        System.arraycopy(code, 0, stub, 0, code.length);
        insertLong(D3DKMTQueryAdapterInfo, stub, 24);
        stub[105] = 3;
        insertLong(origA, stub, 106);
        insertLong(origB, stub, 114);
        insertLong(jmpA, stub, 122);
        insertLong(jmpB, stub, 130);
        return stub;
    }

    /**
     * Builds the simple trampoline for D3DKMTQueryAdapterInfo: on its first call it writes the original 16
     * bytes back (removing the hook) and returns -1. Decoded:
     * <pre>
     *  0: mov rax, imm64 (@2 = D3DKMTQueryAdapterInfo)
     * 10: mov rcx, [rip+0x15] (origA @38) ; 17: mov [rax], rcx
     * 20: mov rcx, [rip+0x13] (origB @46) ; 27: mov [rax+8], rcx
     * 31: xor rax, rax ; 34: not rax ; 37: ret
     * </pre>
     */
    private static byte[] createSimpleStub(long origA, long origB) {
        byte[] stub = toByteArray(72, 184, 0, 0, 0, 0, 0, 0, 0, 0, 72, 139, 13, 21, 0, 0, 0, 72, 137, 8, 72, 139, 13, 19, 0, 0, 0, 72, 137, 72, 8, 72, 49, 192, 72, 247, 208, 195, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
        insertLong(D3DKMTQueryAdapterInfo, stub, 2);
        insertLong(origA, stub, 38);
        insertLong(origB, stub, 46);
        return stub;
    }

    /**
     * Redirects D3DKMTQueryAdapterInfo into a heap trampoline via a 12-byte {@code mov rax, imm64; jmp rax}.
     * The trampoline is fully written before the jump is installed, so callers never reach uninitialised
     * memory. NOTE(review): the 12-byte patch is still not atomic with respect to concurrent callers.
     */
    private static void installQueryStub(boolean installIntelBypass) {
        if (D3DKMTQueryAdapterInfo == 0L || VirtualProtect == 0L) {
            return;
        }
        long stubPtr = MemoryUtil.nmemAlloc(MAX_STUB_SIZE);
        if (stubPtr == 0L) {
            Logger.error("Failed to allocate QueryAdapterInfo stub");
            return;
        }
        if (!makeReadWriteExecute(stubPtr, MAX_STUB_SIZE)) {
            MemoryUtil.nmemFree(stubPtr);
            return;
        }
        Logger.info("Do stub at: " + Long.toHexString(stubPtr));

        long origA = MemoryUtil.memGetLong(D3DKMTQueryAdapterInfo);
        long origB = MemoryUtil.memGetLong(D3DKMTQueryAdapterInfo + 8L);
        // mov rax, imm64 (@2 = stubPtr); jmp rax
        byte[] jmpStub = toByteArray(72, 184, 0, 0, 0, 0, 0, 0, 0, 0, 255, 224);
        insertLong(stubPtr, jmpStub, 2);

        // The 16 bytes the function will hold once patched: the jump followed by the untouched original tail.
        byte[] patchedPrologue = new byte[16];
        insertLong(origA, patchedPrologue, 0);
        insertLong(origB, patchedPrologue, 8);
        System.arraycopy(jmpStub, 0, patchedPrologue, 0, jmpStub.length);
        long jmpA = readLong(patchedPrologue, 0);
        long jmpB = readLong(patchedPrologue, 8);

        byte[] stub = installIntelBypass
                ? createIntelStub(origA, origB, jmpA, jmpB)
                : createSimpleStub(origA, origB);
        memcpy(stubPtr, stub);

        if (!makeReadWriteExecute(D3DKMTQueryAdapterInfo, 16L)) {
            MemoryUtil.nmemFree(stubPtr);
            return;
        }
        memcpy(D3DKMTQueryAdapterInfo, jmpStub);
        Logger.info("D3DKMTQueryAdapterInfo at: " + Long.toHexString(D3DKMTQueryAdapterInfo));
        Logger.info("QueryAdapterInfo stubs installed");
    }

    /**
     * Selects the {@code index}-th render-capable, non-software adapter and installs the patches.
     * Logs every candidate adapter. Does nothing if enumeration fails; logs and returns if {@code index}
     * is out of range.
     */
    public static void doSelector(int index) {
        List<AdapterInfo> adapters = new ArrayList<>();
        int ret = enumAdapters(adapter -> {
            int mask = ADAPTER_TYPE_RENDER_SUPPORTED | ADAPTER_TYPE_SOFTWARE_DEVICE;
            if ((adapter.type() & mask) == ADAPTER_TYPE_RENDER_SUPPORTED) {
                adapters.add(adapter);
            }
        });
        if (ret < 0) {
            return;
        }
        for (AdapterInfo adapter : adapters) {
            Logger.info(adapter.toString());
        }
        if (index < 0 || index >= adapters.size()) {
            Logger.error("GPU index " + index + " is out of range; found " + adapters.size() + " adapter(s)");
            return;
        }
        AdapterInfo selected = adapters.get(index);
        installHDCStub(selected.luid());
        installQueryStub(INTEL_ICD_PATTERN.matcher(selected.icdPath()).find());
        for (int propertyType : new int[]{1, 2}) {
            int status = setPCIProperties(propertyType, selected.vendor(), selected.device(), selected.subSystem());
            if (status < 0) {
                Logger.error("Set PCI properties (type " + propertyType + ") error: " + status);
            }
        }
    }

    /** PCI identifiers of one physical adapter, in the order the 28-byte device-id query returns them. */
    private record PCIDeviceId(int vendor, int device, int subVendor, int subSystem, int revision, int busType) {
    }

    /** Description of one enumerated adapter; {@code subSystem} packs subSystem << 16 | subVendor. */
    private record AdapterInfo(String icdPath, int type, long luid, int vendor, int device, int subSystem) {
        @Override
        public String toString() {
            String LUID = Integer.toHexString((int) (this.luid >>> 32 & 0xFFFFFFFFL)) + "-" + Integer.toHexString((int) (this.luid & 0xFFFFFFFFL));
            return "{type=%s, luid=%s, vendor=%s, device=%s, subSys=%s, icd=\"%s\"}".formatted(Integer.toString(this.type), LUID, Integer.toHexString(this.vendor), Integer.toHexString(this.device), Integer.toHexString(this.subSystem), this.icdPath);
        }
    }
}

