use std::{
    cmp::{max, min, Reverse},
    collections::{BTreeMap, BTreeSet},
    ops::Index,
    str::FromStr,
};

use aoc_2022::*;
use itertools::iproduct;

/// Parsed puzzle input: the valve network.
type Input = Graph;
/// Puzzle answer: the maximum total pressure released.
type Output = i32;

/// One valve: its flow rate and the valves reachable through a single tunnel.
#[derive(Debug)]
struct Vertex {
    /// Pressure released per minute once this valve has been opened.
    rate: i32,
    /// Names of the directly connected valves.
    next: Vec<String>,
}

impl Vertex {
    /// Builds a valve from its flow `rate` and its neighbour list `next`.
    fn new(rate: i32, next: Vec<String>) -> Self {
        Self { rate, next }
    }
}

lazy_static! {
    /// Matches one line of puzzle input, capturing the valve name (`v`), its flow
    /// rate (`r`), and the comma-separated list of neighbouring valves (`next`).
    /// The optional `(s)?` groups accept both singular and plural wording.
    static ref REGEX: Regex = {
        let pattern = r"Valve (?P<v>\w+) has flow rate=(?P<r>-?\d+); tunnel(s)? lead(s)? to valve(s)? (?P<next>.*)";
        Regex::new(pattern).unwrap()
    };
}

/// The valve network, plus data precomputed once for the search.
#[derive(Debug)]
struct Graph {
    /// Valve name -> its flow rate and neighbouring valve names.
    g: BTreeMap<String, Vertex>,
    /// Names of valves with a strictly positive flow rate (filled in by `Graph::new`).
    to_open: BTreeSet<String>,
    /// Shortest path length for every ordered (from, to) pair of valves,
    /// as computed by `floyd_warshall`.
    distances: BTreeMap<(String, String), i32>,
}

impl FromStr for Graph {
    type Err = Report;

    /// Parses the puzzle input: one valve per line, in the form
    /// `Valve AA has flow rate=0; tunnels lead to valves DD, II, BB`.
    ///
    /// Trailing whitespace at the end of the input is ignored. If a valve name
    /// appears on more than one line, the later line wins.
    ///
    /// # Errors
    /// Returns an error if a line does not match `REGEX`, or if its flow rate
    /// does not parse as an `i32`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut g: BTreeMap<String, Vertex> = BTreeMap::new();

        for line in s.trim_end().lines() {
            // Report malformed lines through the `Result` instead of panicking.
            // `std::io::Error` converts into `Report` via `?`, just like the
            // `ParseIntError` from the rate parse below.
            let caps = REGEX.captures(line).ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("malformed valve description: {:?}", line),
                )
            })?;

            let name = caps["v"].to_string();
            let rate: i32 = caps["r"].parse()?;
            let next = caps["next"]
                .split(", ")
                .map(|n| n.to_string())
                .collect_vec();

            g.insert(name, Vertex::new(rate, next));
        }

        Ok(Graph::new(g))
    }
}

impl Index<&str> for Graph {
    type Output = Vertex;

    /// Looks up a valve by name.
    /// Panics (via `BTreeMap` indexing) if no valve with that name exists.
    fn index(&self, name: &str) -> &Vertex {
        self.g.index(name)
    }
}

fn floyd_warshall(g: &BTreeMap<String, Vertex>) -> BTreeMap<(String, String), i32> {
    let upper_bound = i32::MAX / 2;

    let mut distances = BTreeMap::new();
    for (u, v) in iproduct!(g.keys(), g.keys()) {
        let is_neighbour = g[u].next.contains(v);
        distances.insert(
            (u.clone(), v.clone()),
            if is_neighbour { 1 } else { upper_bound },
        );
    }

    for (w, u, v) in iproduct!(g.keys(), g.keys(), g.keys()) {
        let key = (u.clone(), v.clone());
        let current = *distances.get(&key).unwrap();

        distances.insert(
            key,
            min(
                current,
                distances[&(u.clone(), w.clone())] + distances[&(w.clone(), v.clone())],
            ),
        );
    }

    distances
}

impl Graph {
    /// Wraps the adjacency map and precomputes two things:
    /// - the valves worth opening (flow rate > 0);
    /// - the shortest path length between every pair of valves.
    fn new(g: BTreeMap<String, Vertex>) -> Self {
        let to_open = g
            .iter()
            .filter(|(_, vertex)| vertex.rate > 0)
            .map(|(name, _)| name.clone())
            .collect::<BTreeSet<String>>();

        let distances = floyd_warshall(&g);

        Self { g, to_open, distances }
    }

    /// Shortest number of tunnel steps from `u` to `v`.
    /// Panics if the pair is not present in the precomputed table.
    fn distance(&self, u: &str, v: &str) -> i32 {
        let pair = (u.to_owned(), v.to_owned());
        self.distances[&pair]
    }
}

fn max_flow(
    g: &Input,
    cache: &mut BTreeMap<String, Output>,
    valves: &mut BTreeSet<String>,
    start: &str,
    t: i32,
) -> Output {
    let key = format!(
        "{}-{}-{}",
        t,
        start,
        valves
            .iter()
            .sorted_by_key(|&v| Reverse(g.g[v].rate))
            .join("-")
    );

    if let Some(&flow) = cache.get(&key) {
        // we have already precomputed the flow
        return flow;
    }

    let flow_from_start = g[start].rate * t;

    let mut flow_from_next = 0;
    for v in valves.clone().iter().collect_vec() {
        let d = g.distance(start, v);

        if t < d + 1 {
            continue;
        }

        valves.remove(v);
        flow_from_next = max(flow_from_next, max_flow(g, cache, valves, v, t - d - 1));
        valves.insert(v.clone());
    }

    cache.insert(key, flow_from_start + flow_from_next);
    flow_from_start + flow_from_next
}

/// Enumerates every way to split `valves` between two interchangeable workers.
/// Each item is `(human, elephant)`, and the two sets are disjoint with union `valves`.
///
/// The last valve (in set order) is always given to the human, so each unordered
/// split appears exactly once:
/// - `2^(n-1)` pairs for `n >= 1` valves;
/// - a single `(∅, ∅)` pair for an empty set.
///
/// # Panics
/// Panics if `valves` has more than 64 elements (the mask would not fit in a `u64`).
fn pairings(
    valves: &BTreeSet<String>,
) -> impl Iterator<Item = (BTreeSet<String>, BTreeSet<String>)> + '_ {
    let mapping: Vec<&String> = valves.iter().collect();

    // Bit `i` of the mask chooses the worker for `mapping[i]`. The last valve never
    // gets a bit, which removes mirror-image duplicates. `saturating_sub` keeps an
    // empty set from underflowing; it then yields one empty split.
    let free_bits = mapping.len().saturating_sub(1);
    assert!(
        free_bits < u64::BITS as usize,
        "too many valves ({}) to enumerate all pairings",
        mapping.len()
    );
    let max_mask: u64 = 1u64 << free_bits;

    (0..max_mask).map(move |mask| {
        let mut human = BTreeSet::new();
        let mut elephant = BTreeSet::new();

        for (i, &v) in mapping.iter().enumerate() {
            if (mask & (1u64 << i)) == 0 {
                human.insert(v.clone());
            } else {
                elephant.insert(v.clone());
            }
        }

        (human, elephant)
    })
}

/// Day 16 solver, plugged into the shared `Solution` harness.
struct Day16;

impl Solution<Input, Output> for Day16 {
    /// Reads the puzzle file and parses it into a `Graph`; panics if parsing fails.
    fn parse_input<P: AsRef<Path>>(pathname: P) -> Input {
        let text = file_to_string(pathname);
        text.parse().unwrap()
    }

    /// Part 1: best pressure released by a single worker starting at `AA`
    /// with 30 minutes.
    fn part_1(input: &Input) -> Output {
        let mut memo = BTreeMap::new();
        let mut closed = input.to_open.clone();
        max_flow(input, &mut memo, &mut closed, "AA", 30)
    }

    /// Part 2: best combined pressure for two workers (human and elephant),
    /// each starting at `AA` with 26 minutes. Every split of the valves between
    /// them is tried, and one memo is shared across all splits.
    fn part_2(input: &Input) -> Output {
        let mut memo: BTreeMap<String, Output> = BTreeMap::new();
        let mut best: Option<Output> = None;

        for (mut human, mut elephant) in pairings(&input.to_open) {
            let human_flow = max_flow(input, &mut memo, &mut human, "AA", 26);
            let elephant_flow = max_flow(input, &mut memo, &mut elephant, "AA", 26);
            best = best.max(Some(human_flow + elephant_flow));
        }

        best.unwrap()
    }
}

/// Runs the Day 16 solution through the `Solution` harness.
/// To use the sample input instead, `Day16::run("sample")` can be called here.
fn main() -> Result<()> {
    Day16::main()
}

// Sample-input regression test. It presumably checks part 1 == 1651 and
// part 2 == 1707 on the sample input. The macro comes from the `aoc_2022`
// prelude (assumed); confirm its exact semantics there.
test_sample!(day_16, Day16, 1651, 1707);