import { TagArithmeticsOperation } from '@lightly/api-spec';

/**
 * Raises `x` to the power `y` for bigints.
 *
 * Uses native BigInt exponentiation (square-and-multiply) instead of `y`
 * repeated multiplications, which was quadratic for large exponents.
 *
 * @param x Base.
 * @param y Exponent. Non-positive exponents yield `1n`, matching the previous
 *   loop-based implementation; native `**` would throw a RangeError for a
 *   negative bigint exponent instead.
 * @returns `x ** y`, or `1n` when `y <= 0n`.
 */
export function simplePow(x: bigint, y: bigint): bigint {
	if (y <= 0n) {
		return 1n;
	}
	return x ** y;
}



/**
 * Returns a bigint with the lowest `n` bits set, i.e. 2^n - 1.
 * Non-positive `n` yields `0n` (an empty mask).
 */
function lowBitsMask(n: bigint): bigint {
	return n > 0n ? (1n << n) - 1n : 0n;
}

/**
 * Arbitrary-length bit mask backed by a `bigint`. Bit k (counted from the least
 * significant bit) corresponds to index k as reported by `toIndices()`.
 *
 * Mutating operations change this instance in place and return `this` so they
 * can be chained. Masks are expected to be non-negative.
 */
export class BitMask {
	// The mask value.
	#mask: bigint;
	// Binary length of the mask at construction time. It is intentionally NOT
	// updated by mutating operations; it is used as the padding width in
	// getAmountBitUnset().
	#size: number;

	/**
	 * @param mask Either a string accepted by `BigInt()` (e.g. a `0x…` hex
	 *   string as produced by `toString()`), or a bigint count `n` which yields
	 *   a mask with the lowest `n` bits set (`0n` for `n <= 0n`).
	 */
	constructor(mask: string | bigint) {
		this.#mask = typeof mask === 'string' ? BigInt(mask) : lowBitsMask(mask);
		this.#size = this.maxSize();
	}

	/** Hex representation, e.g. `0x1f`. */
	toString() {
		return `0x${this.#mask.toString(16)}`
	}

	/** Number of binary digits of the current mask (`1` for an empty mask). */
	maxSize() {
		return this.#mask.toString(2).length
	}

	get [Symbol.toStringTag]() {
		return this.toString()
	}

	/** Serialises to the same hex string as `toString()`. */
	toJSON() {
		return this.toString()
	}

	/** Binary representation prefixed with `0b`, left-padded with zeros to `length`. */
	toBinary(length = 0) {
		return `0b${this.#mask.toString(2).padStart(length, '0')}`
	}

	/** Returns the raw mask value. */
	toBigInt() {
		return this.#mask
	}

	/** Flips the lowest `nSamples` bits in place. */
	invert(nSamples: bigint) {
		this.#mask ^= lowBitsMask(nSamples);
		return this;
	}

	/** Alias of `invert`. */
	complement(nSamples: bigint) {
		this.invert(nSamples)
		return this;
	}

	/** In-place bitwise OR with `other`. */
	union(other: BitMask) {
		this.#mask |= other.toBigInt();
		return this;
	}

	/** In-place bitwise AND with `other`. */
	intersect(other: BitMask) {
		this.#mask &= other.toBigInt();
		return this;
	}

	/**
	 * Keeps only the bits that are set here but not in `other`, restricted to
	 * the lowest `nSamples` bits (higher bits of this mask are cleared).
	 * `other` is NOT modified.
	 */
	difference(other: BitMask, nSamples: bigint) {
		// (other ^ lowBits) is "other inverted within nSamples bits"; computing it
		// locally avoids flipping the caller's mask in place.
		this.#mask &= other.toBigInt() ^ lowBitsMask(nSamples);
		return this;
	}

	/** Replaces the mask with one that has exactly the bits at `indices` set. */
	fromIndices(indices: number[]) {
		this.#mask = 0n;
		indices.forEach((i) => this.setKthBit(BigInt(i)));
		return this;
	}

	/**
	 * Ascending indices of the set bits, computed by repeated division.
	 * @deprecated Use `toIndices()`; this variant is quadratic on wide masks.
	 */
	toIndicesOld() {
		let x = BigInt(this.#mask)
		const indices: Array<number> = [];
		for (let i = 0; x > 0; i++) {
			// if the number is odd, there is a nonzero bit at offset
			if (x % 2n > 0) {
				indices.push(i)
			}
			// increment the offset and divide the number x by two (rounding down)
			x = x / 2n
		}
		return indices;
	}

	/** Ascending indices of the set bits. */
	toIndices(): number[] {
		const indices: number[] = [];
		this.doForIndices((index) => {
			indices.push(index);
		});
		return indices;
	}

	/** Number of set bits. */
	getAmountBitSet(): number {
		return this.#mask.toString(2).replace(/0/g, '').length
	}

	/** Number of unset bits within max(construction-time size, current length). */
	getAmountBitUnset(): number {
		return this.#mask.toString(2).padStart(this.#size, '0').replace(/1/g, '').length
	}

	/**
	 * Calls `fn` with the index of every set bit, in ascending order.
	 * Works on a snapshot, so `fn` may mutate this mask safely.
	 */
	doForIndices(fn: (indice: number) => void) {
		// Scanning the binary string is O(n); shifting the bigint once per bit
		// copies the whole number each time and is O(n^2) on wide masks.
		const bits = this.#mask.toString(2);
		// The last character is bit 0, so walk the string from the end.
		for (let pos = bits.length - 1, index = 0; pos >= 0; pos--, index++) {
			if (bits[pos] === '1') {
				fn(index);
			}
		}
	}

	// get kth bit from right (0n or 1n)
	getKthBit(k: bigint) {
		return (this.#mask & (1n << k)) >> k;
	}

	// set kth bit from right to 1
	setKthBit(k: bigint) {
		this.#mask |= 1n << k;
		return this;
	}

	// set kth bit from right to 0
	unsetKthBit(k: bigint) {
		this.#mask &= ~(1n << k);
		return this;
	}

	// check if kth bit from right is set
	isKthBitSet(k: bigint): boolean {
		return this.getKthBit(k) === 1n;
	}

	// check if kth bit from right is unset
	isKthBitUnset(k: bigint): boolean {
		return this.getKthBit(k) === 0n;
	}
}

/**
 * Counts how many bits were added and removed when going from `reference`
 * to `modified`.
 *
 * Only the lowest `totSize` bits are considered. When `totSize` is omitted
 * (or 0), the larger binary length of the two masks is used instead.
 * Both inputs are copied via their hex string form before any in-place
 * operation, so neither argument is mutated.
 *
 * @returns `added`: bits set in `modified` but not in `reference`;
 *          `removed`: bits set in `reference` but not in `modified`.
 */
export const getBitmaskDifference = (reference: BitMask, modified: BitMask, totSize?: number): {added: number, removed: number} => {
	const bitWidth = BigInt(totSize || Math.max(reference.maxSize(), modified.maxSize()) || 0);

	// added: bits of `modified` that are clear in `reference`
	const onlyInModified = new BitMask(modified.toString())
		.difference(new BitMask(reference.toString()), bitWidth);

	// removed: bits of `reference` that are clear in `modified`
	const onlyInReference = new BitMask(modified.toString())
		.invert(bitWidth)
		.intersect(reference);

	return {
		added: onlyInModified.getAmountBitSet(),
		removed: onlyInReference.getAmountBitSet(),
	};
};


/**
 * Applies a tag arithmetics operation to two bit masks.
 *
 * The result is written into `bitmask1`, which is mutated and returned.
 * `bitmask2` is left unchanged.
 *
 * @param bitmask1 Left operand; receives the result.
 * @param bitmask2 Right operand; not modified.
 * @param operation UNION, INTERSECTION or DIFFERENCE. Any other value leaves
 *   `bitmask1` unchanged.
 * @param nSamples Total number of samples. It is required for DIFFERENCE,
 *   which only keeps the lowest `nSamples` bits.
 * @returns `bitmask1`.
 * @throws Error when `operation` is DIFFERENCE and `nSamples` is undefined.
 */
export function performBitmaskArithmetics(bitmask1: BitMask, bitmask2: BitMask, operation: TagArithmeticsOperation, nSamples?: number) {
	switch (operation) {
		case TagArithmeticsOperation.UNION:
			bitmask1.union(bitmask2);
			break;
		case TagArithmeticsOperation.INTERSECTION:
			bitmask1.intersect(bitmask2);
			break;
		case TagArithmeticsOperation.DIFFERENCE:
			if (nSamples === undefined) {
				throw new Error('Cant perform bitmask arithmetics difference as nSamples is undefined')
			}
			// BitMask.difference may invert its argument in place, so operate on
			// a copy to keep the caller's bitmask2 intact.
			bitmask1.difference(new BitMask(bitmask2.toString()), BigInt(nSamples));
			break;
	}

	return bitmask1
}
