/// One of the four amphipod kinds.
///
/// The explicit discriminants matter: `make_graph` casts a kind to `usize`
/// and uses the result as the id of that kind's sideroom.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Amphipod {
    A = 0,
    B = 1,
    C = 2,
    D = 3,
}

impl Amphipod {
    /// Energy this kind spends per step: 1, 10, 100 and 1000 for `A` through `D`.
    fn energy(&self) -> u64 {
        // Each kind is ten times as expensive as the previous one, so the
        // discriminant is exactly the power of ten.
        10_u64.pow(*self as u32)
    }
}

/// A place in the burrow that can hold amphipods: either a sideroom or a
/// hallway "reserve" spot.
///
/// Nodes are identified by their index in the graph vector; the `edges`,
/// `prev` and `dist` tables are all indexed by such node ids.
#[derive(Eq, PartialEq, Clone, Debug)]
struct Node {
    /// Human-readable label (`sideroom_<kind>` or the name given to a reserve).
    name: String,
    /// Amphipods currently in the node. Index 0 is treated as the bottom of
    /// the node; the last element is the one that moves out next.
    occupants: Vec<Amphipod>,
    /// Number of slots in the node. `make_graph` lowers it for every
    /// amphipod that already sits settled at the bottom of its own sideroom.
    capacity: usize,
    /// `Some(kind)` for the sideroom belonging to `kind`, `None` for reserves.
    exit_for: Option<Amphipod>,
    /// Per-node-id edge data (direct adjacency, path cost, blocking nodes).
    edges: Vec<Edge>,
    /// `prev[s]` is the next node on the shortest path from this node back
    /// to source `s`, as filled in by `dijkstra`; `None` until computed.
    prev: Vec<Option<usize>>,
    /// `dist[s]` is the shortest distance between this node and source `s`,
    /// as filled in by `dijkstra`; `None` until computed.
    dist: Vec<Option<u64>>,
}
/// Connection data from one node to another.
#[derive(Eq, PartialEq, Clone, Debug)]
struct Edge {
    /// `true` when the two nodes are physically adjacent in the burrow.
    direct: bool,
    /// Length of the shortest path between the two nodes; `u64::MAX` until
    /// it has been computed.
    cost: u64,
    /// Ids of the intermediate nodes on that shortest path (both endpoints
    /// excluded).
    blockers: Vec<usize>,
}

/// The "no connection known yet" edge: not direct, maximal cost, no blockers.
///
/// Implemented by hand (rather than derived) because the default cost is
/// `u64::MAX`, not zero.
impl Default for Edge {
    fn default() -> Self {
        Self {
            direct: false,
            cost: u64::MAX,
            blockers: Vec::new(),
        }
    }
}

impl Edge {
    /// Returns `n` default edges, one slot per node id.
    fn default_vec(n: usize) -> Vec<Self> {
        vec![Self::default(); n]
    }
}

impl Node {
    /// Length of every per-node table (`edges`, `prev`, `dist`): one slot for
    /// each of the nine nodes that `make_graph` creates.
    const TABLE_LEN: usize = 9;

    /// Shared constructor: fills in the per-node tables with "nothing known yet".
    fn with_tables(
        name: String,
        occupants: Vec<Amphipod>,
        capacity: usize,
        exit_for: Option<Amphipod>,
    ) -> Self {
        Self {
            name,
            occupants,
            capacity,
            exit_for,
            edges: Edge::default_vec(Self::TABLE_LEN),
            prev: vec![None; Self::TABLE_LEN],
            dist: vec![None; Self::TABLE_LEN],
        }
    }

    /// Creates the sideroom belonging to `color`.
    ///
    /// `occupants` lists the initial amphipods from the bottom of the room
    /// (index 0) to the top; the room's capacity is the number of occupants.
    fn new_sideroom(color: Amphipod, occupants: &[Amphipod]) -> Self {
        Self::with_tables(
            format!("sideroom_{color:?}"),
            occupants.to_vec(),
            occupants.len(),
            Some(color),
        )
    }

    /// Creates an initially empty hallway spot with room for `capacity`
    /// amphipods.
    fn new_reserve(capacity: usize, name: &str) -> Self {
        Self::with_tables(name.to_string(), Vec::new(), capacity, None)
    }

    /// Returns `(neighbour id, edge)` for every directly adjacent node.
    fn edges(&self) -> Vec<(usize, Edge)> {
        self.edges
            .iter()
            .enumerate()
            .filter(|(_, edge)| edge.direct)
            // Clone after filtering so only the surviving edges are copied.
            .map(|(id, edge)| (id, edge.clone()))
            .collect()
    }
}

/// Computes shortest paths from `start` to every node reachable over direct
/// edges, storing the results in each node's `dist[start]` and `prev[start]`.
///
/// Every direct edge is treated as having length 2; `Edge::cost` is not
/// consulted. `dist[start]` of the source node is seeded with 0 if the caller
/// has not already set it.
///
/// # Panics
///
/// Panics if `start` is not a valid index into `graph` or into the nodes'
/// `dist`/`prev` tables.
fn dijkstra(graph: &mut [Node], start: usize) {
    /// Removes and returns the queued node that is currently closest to `start`.
    fn next_node(graph: &[Node], queue: &mut Vec<usize>, start: usize) -> Option<usize> {
        // Stable sort, farthest first, so the nearest node ends up at the tail
        // where `pop` can take it.
        queue.sort_by_key(|&n| {
            std::cmp::Reverse(
                graph[n].dist[start].expect("queued nodes always have a known distance"),
            )
        });
        queue.pop()
    }

    // The source is at distance zero from itself; keep a caller-provided seed.
    if graph[start].dist[start].is_none() {
        graph[start].dist[start] = Some(0);
    }

    let mut queue = vec![start];
    while let Some(node) = next_node(graph, &mut queue, start) {
        let dist = graph[node].dist[start].expect("queued nodes always have a known distance");
        let new_dist = dist + 2;

        for (other_id, _) in graph[node].edges() {
            // `Option<u64>` is `Copy`: read the one value we need instead of
            // cloning the whole neighbouring node.
            let known = graph[other_id].dist[start];

            // First sighting: schedule the neighbour for expansion.
            if known.is_none() {
                queue.push(other_id);
            }
            // Record the path through `node` when it is the first or a
            // strictly shorter one.
            if known.is_none_or(|prev_dist| prev_dist > new_dist) {
                graph[other_id].dist[start] = Some(new_dist);
                graph[other_id].prev[start] = Some(node);
            }
        }
    }
}

/// Builds the burrow graph with its initial occupants.
///
/// The result holds four siderooms (ids 0..=3, in `A`..`D` order) followed by
/// five hallway reserves. For every ordered pair of distinct nodes the
/// shortest-path cost and the intermediate "blocker" nodes are precomputed
/// into `edges`. Amphipods that already sit at the bottom of their own
/// sideroom are removed and the room's capacity is reduced accordingly.
fn make_graph() -> Vec<Node> {
    use Amphipod::*;

    let mut graph: Vec<Node> = vec![];
    // Appends a node and returns its index, which doubles as its id.
    let mut create = |n: Node| {
        graph.push(n);
        graph.len() - 1
    };

    // Siderooms go first and in A..D order: the clean-up step below relies on
    // `amph as usize` being the id of that kind's sideroom.
    let sr_a = create(Node::new_sideroom(A, &[A, D, D, B]));
    let sr_b = create(Node::new_sideroom(B, &[D, B, C, C]));
    let sr_c = create(Node::new_sideroom(C, &[C, A, B, B]));
    let sr_d = create(Node::new_sideroom(D, &[A, C, A, D]));
    let hw_left = create(Node::new_reserve(2, "hw_left"));
    let hw_right = create(Node::new_reserve(2, "hw_right"));
    let pause1 = create(Node::new_reserve(1, "pause_1"));
    let pause2 = create(Node::new_reserve(1, "pause_2"));
    let pause3 = create(Node::new_reserve(1, "pause_3"));

    // Physical adjacency; every link is usable in both directions.
    for (a, b) in [
        (hw_left, pause1),
        (pause1, pause2),
        (pause2, pause3),
        (pause3, hw_right),
        (sr_a, hw_left),
        (sr_a, pause1),
        (sr_b, pause1),
        (sr_b, pause2),
        (sr_c, pause2),
        (sr_c, pause3),
        (sr_d, pause3),
        (sr_d, hw_right),
    ] {
        graph[a].edges[b].direct = true;
        graph[b].edges[a].direct = true;
    }

    // Run dijkstra from every node and flatten each shortest path into an
    // edge carrying its cost and the list of "blockers" along the way.
    for end in 0..graph.len() {
        graph[end].dist[end] = Some(0);
        dijkstra(&mut graph, end);
        for start in 0..graph.len() {
            if start == end {
                continue;
            }
            let cost = graph[start].dist[end].expect("burrow graph is connected");

            // Follow the predecessor chain from `start` towards `end`,
            // collecting everything strictly between the two.
            let mut blockers = vec![];
            let mut current = graph[start].prev[end].expect("burrow graph is connected");
            while current != end {
                blockers.push(current);
                current = graph[current].prev[end].expect("burrow graph is connected");
            }

            let edge = &mut graph[end].edges[start];
            edge.blockers = blockers;
            edge.cost = cost;
        }
    }

    // Remove amphipods already at the bottom of their exit room. Counting the
    // settled prefix first also copes with a room that is completely settled,
    // where indexing `occupants[0]` in a loop would run off the empty vector.
    for amph in [A, B, C, D] {
        let room = &mut graph[amph as usize];
        let settled = room
            .occupants
            .iter()
            .take_while(|&&occupant| occupant == amph)
            .count();
        room.occupants.drain(..settled);
        room.capacity -= settled;
    }
    graph
}

fn do_move(graph: &mut [Node], from: usize, dest: usize) -> u64 {
    let amphipod = graph[from].occupants.pop().unwrap();
    // Cost of that motion
    let move_cost = amphipod.energy() * graph[from].edges[dest].cost;
    assert!(graph[from].edges[dest].cost == graph[dest].edges[from].cost);
    // Remaining occupants of the start node move one step towards this node's "exit"
    let leave_cost = graph[from]
        .occupants
        .iter()
        .map(Amphipod::energy)
        .sum::<u64>();
    // Existing occupants of the dest node move one step towards the back of that node.
    let entry_cost = graph[dest]
        .occupants
        .iter()
        .map(Amphipod::energy)
        .sum::<u64>();
    graph[dest].occupants.push(amphipod);
    leave_cost + move_cost + entry_cost
}

/// Replays a fixed list of moves on the burrow graph and prints the total
/// energy charged by `do_move`.
fn main() {
    let mut graph = make_graph();

    // Node ids, in the order `make_graph` creates the nodes.
    let (sr_a, sr_b, sr_c, sr_d) = (0, 1, 2, 3);
    let (hw_left, hw_right) = (4, 5);
    let (pause1, pause2, pause3) = (6, 7, 8);

    // Actual solution from the example
    let moves = [
        (sr_d, hw_right),
        (sr_d, hw_left),
        (sr_c, hw_right),
        (sr_c, pause3),
        (sr_c, hw_left),
        (sr_b, sr_c),
        (sr_b, sr_c),
        (sr_b, pause2),
        (sr_b, pause1),
        (pause2, sr_b),
        (pause3, sr_b),
        (hw_right, sr_b),
        (sr_d, sr_c),
        (sr_d, hw_right),
        (pause1, sr_d),
        (sr_a, sr_b),
        (sr_a, sr_d),
        (sr_a, pause1),
        (hw_left, sr_a),
        (hw_left, sr_a),
        (pause1, sr_d),
        (hw_right, sr_a),
        (hw_right, sr_d),
    ];

    // Apply the moves in order, accumulating what each one costs.
    let cost: u64 = moves
        .iter()
        .map(|&(start, dest)| do_move(&mut graph, start, dest))
        .sum();
    println!("Total: {cost}");
}