summaryrefslogtreecommitdiff
path: root/src/bin/day12.rs
blob: 9a485e459b5c4a5a35f0d7c10cdb109dcf0226db (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use std::collections::HashMap;
use std::collections::HashSet;

/// Entry point: prints the path counts for both parts of the puzzle.
///
/// `advent::read_lines` lives outside this file; presumably it loads the
/// day-12 puzzle input as a list of lines — confirm against that helper.
fn main() {
    let lines = advent::read_lines(12);

    // Part one: a small cave may be entered at most once per path.
    let single_visit_paths = paths_through_caves(&lines, false);
    println!("12a: {}", single_visit_paths);

    // Part two: one small cave per path may be entered a second time.
    let double_visit_paths = paths_through_caves(&lines, true);
    println!("12b: {}", double_visit_paths);
}

/// Returns `true` when `cave` contains no uppercase character.
///
/// A name with at least one uppercase character is treated as a big cave;
/// every other name — including the empty string — counts as small.
fn is_small_cave(cave: &str) -> bool {
    cave.chars().all(|ch| !ch.is_uppercase())
}

/// Decides whether the current path may step into `cave`.
///
/// * `visited` – number of times each cave has been entered on this path.
/// * `cave` – the candidate cave to enter next.
/// * `allow_twice` – when `true`, a small cave that was already entered may be
///   entered again as long as no small cave has a visit count of 2 yet.
///
/// Rules, in order:
/// 1. `"start"` may never be re-entered.
/// 2. Big caves (see `is_small_cave`) are always allowed.
/// 3. A small cave that has not been entered yet is always allowed.
/// 4. A small cave that was already entered is refused unless `allow_twice`
///    is set and no small cave has been entered twice so far.
///
/// # Panics
///
/// Panics if `visited` holds an entry for `cave` with a count of zero.
fn visiting_allowed(visited: &HashMap<&str, usize>, cave: &str, allow_twice: bool) -> bool {
    match cave {
        "start" => false,
        big if !is_small_cave(big) => true,
        small => match (visited.get(small), allow_twice) {
            // Never entered on this path: always fine.
            (None, _) => true,
            // Already entered and second visits are disabled.
            (Some(_), false) => false,
            // Already entered: only fine while the single "second visit"
            // has not been spent on any small cave.
            (Some(count), true) => {
                assert!(*count > 0);
                visited
                    .iter()
                    .all(|(name, &times)| times != 2 || !is_small_cave(name))
            }
        },
    }
}

/// Counts the distinct paths leading from `from` to `to` through the cave graph.
///
/// * `connections` – adjacency map from a cave name to the caves linked to it.
/// * `visited` – how often each cave has been entered on the current path; the
///   entry for `from` is incremented by this call.
/// * `from` / `to` – current cave and destination cave.
/// * `allow_twice` – forwarded to `visiting_allowed` to select the visiting rule.
///
/// Returns `1` when `from == to`, `0` when `from` has no entry in
/// `connections` (for example when the input never mentions the starting
/// cave), and otherwise the sum of the path counts of every neighbour that
/// `visiting_allowed` admits.
fn count_paths<'a>(connections: &HashMap<&str, HashSet<&'a str>>, visited: &mut HashMap<&'a str, usize>, from: &'a str, to: &str, allow_twice: bool) -> usize {
    if from == to {
        return 1;
    }

    // A cave that is missing from the graph leads nowhere, so it contributes
    // zero paths instead of aborting the whole search with a panic.
    let neighbors = match connections.get(from) {
        Some(neighbors) => neighbors,
        None => return 0,
    };

    // Record that the current path has entered `from`.
    *visited.entry(from).or_insert(0) += 1;

    neighbors.iter()
             .filter(|neighbor| visiting_allowed(visited, neighbor, allow_twice))
             // Every branch explores with its own copy of the visit counts so
             // that sibling branches never observe each other's steps.
             .map(|neighbor| count_paths(connections, &mut visited.clone(), neighbor, to, allow_twice))
             .sum()
}

/// Parses the cave map in `input` and counts every path from `"start"` to `"end"`.
///
/// Each element of `input` is one `a-b` edge; the edge is stored in both
/// directions. `allow_twice` selects the visiting rule applied by
/// `count_paths` (see `visiting_allowed`).
///
/// # Panics
///
/// Panics if a line does not contain a `-` separator.
fn paths_through_caves<T: AsRef<str>>(input: &[T], allow_twice: bool) -> usize {
    let mut graph: HashMap<&str, HashSet<&str>> = HashMap::new();

    for line in input {
        let edge: &str = line.as_ref();
        let (left, right) = edge.split_once('-').unwrap();

        // Tunnels are bidirectional, so register both endpoints as neighbours.
        graph.entry(left).or_default().insert(right);
        graph.entry(right).or_default().insert(left);
    }

    let mut seen = HashMap::new();
    count_paths(&graph, &mut seen, "start", "end", allow_twice)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest worked example; checked under both visiting rules.
    #[test]
    fn test0() {
        let input = [
            "start-A",
            "start-b",
            "A-c",
            "A-b",
            "b-d",
            "A-end",
            "b-end",
        ];
        assert_eq!(paths_through_caves(&input, false), 10);
        assert_eq!(paths_through_caves(&input, true), 36);
    }

    /// Medium worked example; checked under both visiting rules.
    #[test]
    fn test1() {
        let input = [
            "dc-end",
            "HN-start",
            "start-kj",
            "dc-start",
            "dc-HN",
            "LN-dc",
            "HN-end",
            "kj-sa",
            "kj-HN",
            "kj-dc",
        ];
        assert_eq!(paths_through_caves(&input, false), 19);
        assert_eq!(paths_through_caves(&input, true), 103);
    }

    /// Largest worked example; checked under both visiting rules.
    #[test]
    fn test2() {
        let input = [
            "fs-end",
            "he-DX",
            "fs-he",
            "start-DX",
            "pj-DX",
            "end-zg",
            "zg-sl",
            "zg-pj",
            "pj-he",
            "RW-he",
            "fs-DX",
            "pj-RW",
            "zg-RW",
            "start-pj",
            "he-WI",
            "zg-he",
            "pj-fs",
            "start-RW"
        ];
        assert_eq!(paths_through_caves(&input, false), 226);
        assert_eq!(paths_through_caves(&input, true), 3509);
    }
}