import { getMesh } from "@/data/assetLoader/asset.store";
import { FogOutlines } from "@/lib/shaders/outline";
import type { Material } from "@/types";
import { useFrame } from "@react-three/fiber";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { PointOctree } from "sparse-octree";
import { Matrix4, type Mesh, Quaternion, Vector3 } from "three";

/** Mutable per-fish simulation state, shared by the flock simulation and the fish renderer. */
interface Boid {
	/** World-space position; the same vector is used as this boid's key in the octree. */
	position: Vector3;
	/** Current velocity, updated each frame by the flocking forces. */
	velocity: Vector3;
	/** Heading the fish mesh is rotated to face; eased toward the velocity each frame. */
	direction: Vector3;
	/** Uniform scale applied to the fish mesh. */
	size: number;
	/** Randomly assigned colour value (not read by FishBoid in this file). */
	color: number;
}

/** Number of fish spawned by the boid system. */
const maxBoids = 4;

/**
 * Renders one herring mesh and, every frame, steps its boid simulation via
 * `applyForces`, then orients the mesh along `boid.direction`.
 *
 * The orientation keeps the fish mostly upright: it computes the rotation that
 * points the model's +Z nose along the heading, removes most of the resulting
 * roll/pitch so the model's up axis stays near world up, and blends back a bit
 * of the "natural" tilt controlled by `tiltFactor`.
 *
 * @param boid - Mutable simulation state owned by the parent boid system.
 * @param applyForces - Advances `boid` by `delta` seconds and returns its new position.
 */
export const FishBoid = ({
	boid,
	applyForces,
}: {
	boid: Boid;
	applyForces: (boid: Boid, delta: number) => Vector3;
}) => {
	const boidRef = useRef<Mesh>(null!);
	const { size } = boid;

	// Scratch objects reused every frame so the render loop does not allocate
	// fresh vectors/quaternions (and churn the GC) for every fish on every frame.
	const scratch = useMemo(
		() => ({
			forward: new Vector3(0, 0, 1),
			worldUp: new Vector3(0, 1, 0),
			rotatedUp: new Vector3(),
			heading: new Quaternion(),
			upright: new Quaternion(),
		}),
		[],
	);

	useFrame((_state, delta) => {
		const position = applyForces(boid, delta);
		const { forward, worldUp, rotatedUp, heading, upright } = scratch;

		// Normalized in place on purpose: the simulation eases from this value next frame.
		const targetDirection = boid.direction.normalize();
		// Rotation that turns the model's +Z nose toward the travel direction.
		heading.setFromUnitVectors(forward, targetDirection);
		// Where the model's up axis points after that rotation...
		rotatedUp.copy(worldUp).applyQuaternion(heading);
		// ...and the correction that brings it back to world up, applied after the heading.
		upright.setFromUnitVectors(rotatedUp, worldUp).multiply(heading);

		const tiltFactor = 0.3; // Control the amount of tilt: 0 means fully vertical, 1 means full natural tilt
		upright.slerp(heading, tiltFactor);

		boidRef.current.quaternion.copy(upright);
		boidRef.current.position.copy(position);
	});

	const { mesh, tex } = useMemo(() => {
		const mesh = getMesh("herring/herring/Herring_LOD1");
		// Asset not loaded/registered: render with no geometry or texture.
		if (!mesh) return {};
		return { mesh: mesh.geometry, tex: (mesh.material as Material).map };
	}, []);

	return (
		<mesh
			ref={boidRef}
			position={boid.position}
			scale={[size, size, size]}
			geometry={mesh}
		>
			<meshToonMaterial color={0xffffff} map={tex} />
			{/* <FogOutlines
				thickness={0.8}
				color="#534646"
				opacity={0.8}
				transparent
				screenspace
			/> */}
		</mesh>
	);
};

/**
 * Builds an empty spatial index for boid neighbour queries, bounded by the cube
 * [-10, 10] on every axis (same constructor arguments the system always used).
 */
const createBoidOctree = () =>
	new PointOctree<Boid>(
		new Vector3(-10, -10, -10),
		new Vector3(10, 10, 10),
		0.01,
		30,
	);

/**
 * Spawns `maxBoids` fish near the origin and simulates them as a flock.
 *
 * Each frame, `applyForces` (invoked by every FishBoid) combines:
 * - alignment with nearby neighbours' velocities (`alignRadius`),
 * - separation from very close neighbours (`steerRadius`),
 * - cohesion toward the local centre of mass (`coheseRadius`),
 * - a containment spring toward `boidTarget`, applied when a fish strays more
 *   than `containmentRadius` from it (or is nearly stationary).
 *
 * `boidTarget` starts at the origin and is re-randomized inside [-2, 2]³ every
 * `retargetIntervalSeconds`. Neighbour lookups go through a sparse octree keyed
 * by each boid's position vector.
 */
export const BoidSystem = () => {
	const octree = useRef(createBoidOctree());
	const [boids, setBoids] = useState<Boid[]>([]);
	const boidTarget = useRef(new Vector3());
	const lastRetargetTime = useRef(0);

	const forceAlignment = 1.3;
	const alignRadius = 0.2;
	const forceCohesion = 1;
	const coheseRadius = 0.35;
	const forceSteer = 0.5;
	const steerRadius = 0.15;

	const forceContainment = 10;
	const containmentRadius = 0.35;
	const speed = 0.7;
	// Neighbours farther than this never influence a boid.
	const range = 0.5;
	const retargetIntervalSeconds = 33;

	useEffect(() => {
		const b = Array.from(
			{ length: maxBoids },
			(): Boid => ({
				position: new Vector3(
					Math.random() * 0.5 - 0.25,
					Math.random() * 0.5 - 0.25,
					Math.random() * 0.5 - 0.25,
				),
				velocity: new Vector3(),
				direction: new Vector3(),
				size: (Math.random() * 0.5 + 0.5) * 0.06,
				color: Math.random() * 0xffffff,
			}),
		);
		for (const boid of b) octree.current.set(boid.position, boid);
		setBoids(b);
		return () => {
			// Start from an empty index if the effect runs again (e.g. remount).
			octree.current = createBoidOctree();
		};
	}, []);

	useFrame((state) => {
		// elapsedTime is a float in seconds, so an exact `elapsed % 33 === 0`
		// test would essentially never succeed; compare against the last retarget.
		const elapsed = state.clock.elapsedTime;
		if (elapsed - lastRetargetTime.current >= retargetIntervalSeconds) {
			lastRetargetTime.current = elapsed;
			boidTarget.current.set(
				Math.random() * 4 - 2,
				Math.random() * 4 - 2,
				Math.random() * 4 - 2,
			);
		}
	});

	const applyForces = useCallback((boid: Boid, delta: number): Vector3 => {
		const tree = octree.current;
		// Only boids within `range` matter, so query just that sphere instead of
		// the whole octree. The result includes this boid itself (distance 0).
		const neighbors = tree.findPoints(boid.position, range);
		const neighborCount = neighbors.length;

		const force = boid.velocity.clone();
		const avgDir = new Vector3();
		const steer = new Vector3();
		const cohesionCenter = new Vector3();
		const scratch = new Vector3();
		let alignCount = 0;
		let steerCount = 0;
		let coheseCount = 0;

		for (const { data: neighbor } of neighbors) {
			if (!neighbor) continue;
			const d = boid.position.distanceTo(neighbor.position);
			// d <= 0 skips the boid itself (and any exactly coincident neighbour).
			if (d <= 0 || d > range) continue;
			if (d <= alignRadius) {
				alignCount++;
				avgDir.add(scratch.copy(neighbor.velocity).normalize().divideScalar(d));
			}
			if (d <= steerRadius) {
				steerCount++;
				// Push away from the neighbour, harder the closer it is.
				steer.add(
					scratch
						.copy(boid.position)
						.sub(neighbor.position)
						.normalize()
						.divideScalar(d),
				);
			}
			if (d <= coheseRadius) {
				coheseCount++;
				cohesionCenter.add(neighbor.position);
			}
		}
		// Alignment and separation are scaled by the number of boids found in
		// range (including self), matching the original force tuning.
		if (alignCount > 0)
			avgDir.divideScalar(neighborCount).multiplyScalar(forceAlignment);
		if (steerCount > 0)
			steer.divideScalar(neighborCount).multiplyScalar(forceSteer);
		// Cohesion steers toward the local centre of mass. Unlike a 1/d-weighted
		// pull, this fades as boids close in, so separation wins at short range.
		const cohesionDir = new Vector3();
		if (coheseCount > 0)
			cohesionDir
				.copy(cohesionCenter)
				.divideScalar(coheseCount)
				.sub(boid.position)
				.multiplyScalar(forceCohesion);

		// Containment: spring toward the target, measured from the target itself
		// (not the origin) so the trigger and the push direction agree.
		const target = boidTarget.current;
		const distToTarget = boid.position.distanceTo(target);
		if (distToTarget > containmentRadius || boid.velocity.length() < 0.0001) {
			const toTarget = target.clone().sub(boid.position).normalize();
			force.add(toTarget.multiplyScalar(distToTarget * forceContainment));
		}
		force.add(avgDir).add(steer).add(cohesionDir);

		// Blend the force into the velocity, then damp it (halved every step).
		boid.velocity.add(force.multiplyScalar(delta * 0.5)).divideScalar(2);
		const targetPos = boid.position
			.clone()
			.add(scratch.copy(boid.velocity).multiplyScalar(delta * 10));

		// The octree is keyed by position: remove before moving, re-insert after.
		tree.remove(boid.position);
		// Lerp factors are clamped to 1 so a long frame (e.g. after a tab switch)
		// cannot extrapolate past the target and fling boids out of the index.
		boid.position.lerp(targetPos, Math.min(1, speed * delta * 50));
		boid.direction.lerp(
			scratch.copy(boid.velocity).normalize(),
			Math.min(1, delta * 5.8),
		);
		tree.set(boid.position, boid);
		return boid.position;
	}, []);

	return (
		<group>
			{boids.map((boid, idx) => {
				return <FishBoid key={idx} boid={boid} applyForces={applyForces} />;
			})}
		</group>
	);
};
