#include "gzip.h"

// Builds a direct lookup table for a canonical Huffman code (RFC 1951, 3.2.2).
//
// Every symbol with a non-zero length is given its canonical code. The code is
// bit-reversed, because the stream is read least significant bit first, and
// then written to every table index whose low `length` bits equal the reversed
// code. A decoder can therefore read a fixed number of bits, index the table,
// and learn both the symbol and how many of those bits really belonged to it.
//
// Entries that no code maps to are left untouched, so callers that need to
// detect invalid codes should clear the output arrays beforehand.
// Lengths above 15 are not valid in DEFLATE and are treated as unused symbols.
void GzipPopulateTable(
	u8*  InputLengths,  // Lengths of symbol bits
	uptr InputCount,    // Number of symbols
	u8*  OutputLengths, // Output bit lengths
	u32* OutputSymbols, // Output symbols
	uptr OutputCount,   // Number of outputs
	uptr OutputBits     // Highest bit length
) {
	// The table size is taken from OutputCount, this is only kept for callers
	(void)OutputBits;

	// Count the symbols in each bit length. Lengths run from 0 to 15
	// inclusive, so sixteen slots are needed.
	u32 CodeCounts[16] = {};
	for (uptr i = 0; i < InputCount; i++) {
		if (InputLengths[i] < 16) CodeCounts[InputLengths[i]]++;
	}

	// Ignore null symbols
	CodeCounts[0] = 0;

	// Find lowest code for each bit length
	u32 Code = 0;
	u32 NextCode[16] = {}; // Next code to hand out for each length
	for (u32 Bits = 1; Bits < 16; Bits++) {
		Code = (Code + CodeCounts[Bits - 1]) << 1;
		NextCode[Bits] = Code;
	}

	// Hand out codes length by length, and within a length in symbol order
	for (u32 Bits = 1; Bits < 16; Bits++) {
		if (CodeCounts[Bits] == 0) continue;

		// Indices sharing the same low `Bits` bits are this far apart
		const uptr Step = (uptr)1 << Bits;

		for (uptr Symbol = 0; Symbol < InputCount; Symbol++) {
			if (InputLengths[Symbol] != Bits) continue;

			// Take the next canonical code of this length
			u32 Canonical = NextCode[Bits]++;

			// Huffman codes are packed most significant bit first while
			// the stream is read least significant bit first
			uptr Reversed = 0;
			for (u32 j = 0; j < Bits; j++) {
				Reversed |= (uptr)((Canonical >> j) & 1) << (Bits - j - 1);
			}

			// Fill every index whose low bits match the code
			for (uptr j = Reversed; j < OutputCount; j += Step) {
				OutputLengths[j] = Bits;
				OutputSymbols[j] = Symbol;
			}
		}
	}
}

// Reads `Length` bits from `Stream`, starting at bit offset `*Location`, and
// advances `*Location` past them. Bits are taken least significant first
// within each byte, and the first bit read becomes bit 0 of the result.
//
// No bounds checking is done, the caller has to make sure the bits exist.
// A u64 can only hold 64 bits, so any bits beyond the 64th are skipped over
// but not returned.
u64 GzipFetchBits(u8* Stream, u64* Location, uptr Length) {
	u64 Output = 0;
	for (uptr i = 0; i < Length; i++) {
		// Widen before shifting, an int would overflow from bit 31 onwards
		u64 Bit = (Stream[*Location / 8] >> (*Location % 8)) & 1;
		if (i < 64) Output |= Bit << i;
		*Location += 1;
	}
	return Output;
}

// Decompresses a gzip file (RFC 1952) holding one or more members, each with a
// DEFLATE stream (RFC 1951), from `Input` into `Output`.
//
// Input      Start of the gzip data
// InputSize  Number of bytes available at Input
// Output     Destination buffer. Its capacity is not known here, so the
//            caller must make sure it is large enough for the whole result.
//
// Returns the number of bytes written to Output. There is no error channel:
// when the data is truncated or invalid, decoding stops and the number of
// bytes produced up to that point is returned. The CRC32 and ISIZE trailer
// fields are skipped without being verified.
uptr GzipDecompress(c8* Input, uptr InputSize, c8* Output) {
	// Stop decoding and report what has been written so far
	#define FAIL_UNLESS_(_COND) do { \
		if (!(_COND)) return (uptr)(Output - BinaryStart); \
	} while (0)

	// Make sure the compressed data still holds _COUNT unread bits
	#define NEED_BITS_(_COUNT) FAIL_UNLESS_(BitLocation + (_COUNT) <= BitLimit)

	// Base copy length for length symbols 257 to 285
	const u32 LengthAddends[29] = {
		3,   4,   5,  6,   7,   8,   9,   10,  11,  13,
		15,  17,  19, 23,  27,  31,  35,  43,  51,  59,
		67,  83,  99, 115, 131, 163, 195, 227, 258
	};

	// Base distance for distance symbols 0 to 29
	const u32 DistanceAddends[30] = {
		1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025,
		1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577
	};

	// Order in which the code length code lengths are stored
	const u8 CodeLengthLengthOffset[19] = { 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 };

	u8* Location    = (u8*)Input;
	u8* End         = Location + InputSize;
	c8* BinaryStart = Output;

	// One iteration per gzip member
	while (Location < End) {
		// Fixed header: ID1, ID2, CM, FLG, MTIME (4 bytes), XFL, OS
		FAIL_UNLESS_(End - Location >= 10);
		FAIL_UNLESS_(Location[0] == 0x1f && Location[1] == 0x8b); // Magic
		FAIL_UNLESS_(Location[2] == 8); // Deflate is the only defined method
		u8 Flags = Location[3];
		Location += 10;

		// Extra field, a little endian length followed by that many bytes
		if (Flags & 0x04) { // FLG.FEXTRA
			FAIL_UNLESS_(End - Location >= 2);
			uptr XLen = (uptr)Location[0] | ((uptr)Location[1] << 8);
			Location += 2;
			FAIL_UNLESS_((uptr)(End - Location) >= XLen);
			Location += XLen;
		}

		// Zero terminated filename
		if (Flags & 0x08) { // FLG.FNAME
			while (Location < End && *Location != 0) Location++;
			FAIL_UNLESS_(Location < End);
			Location++;
		}

		// Zero terminated comment
		if (Flags & 0x10) { // FLG.FCOMMENT
			while (Location < End && *Location != 0) Location++;
			FAIL_UNLESS_(Location < End);
			Location++;
		}

		// Header checksum, not verified
		if (Flags & 0x02) { // FLG.FHCRC
			FAIL_UNLESS_(End - Location >= 2);
			Location += 2;
		}

		u8* Stream      = Location;
		u64 BitLocation = 0;
		u64 BitLimit    = (u64)(End - Location) * 8;

		// One iteration per deflate block
		while (true) {
			NEED_BITS_(3);
			u8 Head = GzipFetchBits(Stream, &BitLocation, 3);
			u8 Type = (Head & 0x06) >> 1;

			// Block type 3 is reserved
			FAIL_UNLESS_(Type != 3);

			if (Type == 0) { // NO COMPRESSION
				// Stored data starts at the next byte boundary with LEN
				// and its one's complement NLEN
				BitLocation = (BitLocation + 7) & ~(u64)7;
				NEED_BITS_(32);
				u32 Len  = GzipFetchBits(Stream, &BitLocation, 16);
				u32 NLen = GzipFetchBits(Stream, &BitLocation, 16);
				FAIL_UNLESS_((Len ^ NLen) == 0xFFFF);

				NEED_BITS_((u64)Len * 8);
				for (u32 i = 0; i < Len; i++) {
					*Output++ = Stream[BitLocation / 8 + i];
				}
				BitLocation += (u64)Len * 8;
			} else { // FIXED HUFFMAN or DYNAMIC HUFFMAN
				// Code lengths of the literal/length alphabet directly
				// followed by those of the distance alphabet. There are at
				// most 288 of the former and 32 of the latter.
				u8  Lengths[288 + 32] = {};
				u64 OperationCount = 288;
				u64 DistanceCount  = 32;

				if (Type == 1) {
					// The fixed code is defined by these lengths
					for (int i = 0;   i < 144; i++) Lengths[i] = 8;
					for (int i = 144; i < 256; i++) Lengths[i] = 9;
					for (int i = 256; i < 280; i++) Lengths[i] = 7;
					for (int i = 280; i < 288; i++) Lengths[i] = 8;
					for (int i = 288; i < 320; i++) Lengths[i] = 5;
				} else {
					// Get huffman code counts
					NEED_BITS_(14);
					OperationCount = GzipFetchBits(Stream, &BitLocation, 5) + 257;
					DistanceCount  = GzipFetchBits(Stream, &BitLocation, 5) + 1;
					u64 CodeCount  = GzipFetchBits(Stream, &BitLocation, 4) + 4;

					// The lengths themselves are huffman coded, read the
					// lengths of that code first
					NEED_BITS_(CodeCount * 3);
					u8 CodeLengthLengthProto[19] = {};
					for (u64 i = 0; i < CodeCount; i++) {
						CodeLengthLengthProto[CodeLengthLengthOffset[i]] = GzipFetchBits(Stream, &BitLocation, 3);
					}

					// Build the table used to decode the lengths
					u8  CodeLengthLengths[0x80] = {};
					u32 CodeLengthSymbols[0x80] = {};
					GzipPopulateTable(CodeLengthLengthProto, 19, CodeLengthLengths, CodeLengthSymbols, 0x80, 7);

					// Collect lengths
					uptr TotalLen   = OperationCount + DistanceCount;
					uptr CurrentLen = 0;
					while (CurrentLen < TotalLen) {
						// Read as many bits as the longest code has, then
						// give back the ones this code did not use
						NEED_BITS_(7);
						u64 Code = GzipFetchBits(Stream, &BitLocation, 7);
						u8  Used = CodeLengthLengths[Code];
						FAIL_UNLESS_(Used != 0); // No code maps here
						BitLocation -= 7 - Used;
						Code = CodeLengthSymbols[Code];

						if (Code < 16) { // Literal
							Lengths[CurrentLen++] = Code;
						} else { // Repeat
							u64 RepeatCount = 0;
							u8  RepeatedLen = 0;

							if (Code == 16) { // Repeat the previous length
								FAIL_UNLESS_(CurrentLen > 0);
								NEED_BITS_(2);
								RepeatCount = 3 + GzipFetchBits(Stream, &BitLocation, 2);
								RepeatedLen = Lengths[CurrentLen - 1];
							} else if (Code == 17) { // Short run of zeroes
								NEED_BITS_(3);
								RepeatCount = 3 + GzipFetchBits(Stream, &BitLocation, 3);
							} else { // 18, long run of zeroes
								NEED_BITS_(7);
								RepeatCount = 11 + GzipFetchBits(Stream, &BitLocation, 7);
							}

							// A run may not go past the end of the lengths
							FAIL_UNLESS_(RepeatCount <= TotalLen - CurrentLen);
							for (u64 i = 0; i < RepeatCount; i++) {
								Lengths[CurrentLen++] = RepeatedLen;
							}
						}
					}
				}

				// Build operation table, sized by the longest code
				u32 OperationMaxBits = 0;
				for (u64 i = 0; i < OperationCount; i++) {
					if (Lengths[i] > OperationMaxBits) OperationMaxBits = Lengths[i];
				}
				u32 OperationTableLength = (u32)1 << OperationMaxBits;

				// Cleared first so that unused entries read as length zero
				u32 OperationSymbols[OperationTableLength];
				u8  OperationLengths[OperationTableLength];
				for (u32 i = 0; i < OperationTableLength; i++) {
					OperationSymbols[i] = 0;
					OperationLengths[i] = 0;
				}
				GzipPopulateTable(Lengths, OperationCount,
						  OperationLengths, OperationSymbols, OperationTableLength,
						  OperationMaxBits);

				// Build distance table the same way
				u32 DistanceMaxBits = 0;
				for (u64 i = 0; i < DistanceCount; i++) {
					if (Lengths[OperationCount + i] > DistanceMaxBits) {
						DistanceMaxBits = Lengths[OperationCount + i];
					}
				}
				u32 DistanceTableLength = (u32)1 << DistanceMaxBits;

				u32 DistanceSymbols[DistanceTableLength];
				u8  DistanceLengths[DistanceTableLength];
				for (u32 i = 0; i < DistanceTableLength; i++) {
					DistanceSymbols[i] = 0;
					DistanceLengths[i] = 0;
				}
				GzipPopulateTable(Lengths + OperationCount, DistanceCount,
						  DistanceLengths, DistanceSymbols, DistanceTableLength,
						  DistanceMaxBits);

				// Inflate
				while (true) {
					// Fetch the next literal/length symbol
					NEED_BITS_(OperationMaxBits);
					u32 Code = GzipFetchBits(Stream, &BitLocation, OperationMaxBits);
					u8  Used = OperationLengths[Code];
					FAIL_UNLESS_(Used != 0); // No code maps here
					BitLocation -= OperationMaxBits - Used;
					Code = OperationSymbols[Code];

					if (Code < 256) { // Literal
						*Output++ = Code;
						continue;
					}
					if (Code == 256) break; // End block

					// Symbols 286 and 287 exist in the alphabet but are
					// reserved and have no length assigned
					FAIL_UNLESS_(Code <= 285);

					// Symbols 265 to 284 carry 1 to 5 extra bits, four
					// symbols per step
					u32 LengthExtra = (Code > 264 && Code <= 284)
					                ? (Code - 1) / 4 - 65
					                : 0;
					NEED_BITS_(LengthExtra);
					u32 CopyLength = GzipFetchBits(Stream, &BitLocation, LengthExtra)
					               + LengthAddends[Code - 257];

					// Fetch the distance symbol
					NEED_BITS_(DistanceMaxBits);
					Code = GzipFetchBits(Stream, &BitLocation, DistanceMaxBits);
					Used = DistanceLengths[Code];
					FAIL_UNLESS_(Used != 0); // No code maps here
					BitLocation -= DistanceMaxBits - Used;
					Code = DistanceSymbols[Code];

					// Distance symbols 30 and 31 are reserved
					FAIL_UNLESS_(Code <= 29);

					// Symbols 4 to 29 carry 1 to 13 extra bits, two
					// symbols per step
					u32 DistanceExtra = Code > 3 ? (Code - 2) / 2 : 0;
					NEED_BITS_(DistanceExtra);
					u32 Distance = GzipFetchBits(Stream, &BitLocation, DistanceExtra)
					             + DistanceAddends[Code];

					// Never reach back past the start of the output
					FAIL_UNLESS_(Distance <= (uptr)(Output - BinaryStart));

					// Copy byte by byte going forwards, source and
					// destination overlap when the length exceeds the
					// distance and that is how runs are repeated
					c8* From = Output - Distance;
					for (u32 i = 0; i < CopyLength; i++) {
						Output[i] = From[i];
					}
					Output += CopyLength;
				}
			}

			if (Head & 0x01) break; // BFINAL
		}

		// The trailer starts at the next byte boundary, which is the
		// current position when the stream ended exactly on one
		Location += (BitLocation + 7) / 8;

		// Trailer: CRC32 and ISIZE, four bytes each, neither is verified
		FAIL_UNLESS_(End - Location >= 8);
		Location += 8;
	}

	return (uptr)(Output - BinaryStart);

	#undef NEED_BITS_
	#undef FAIL_UNLESS_
}