diff --git a/build.gradle.kts b/build.gradle.kts index f62a589..b97fbab 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -5,6 +5,9 @@ plugins { repositories { mavenCentral() } +dependencies { + implementation("org.junit.jupiter:junit-jupiter:5.7.0") +} tasks { sourceSets { diff --git a/src/year2021/day16/Day16.kt b/src/year2021/day16/Day16.kt new file mode 100644 index 0000000..cf70042 --- /dev/null +++ b/src/year2021/day16/Day16.kt @@ -0,0 +1,150 @@ +package year2021.day16 + +import readInput + +enum class PacketType(val id: Byte) { + Sum(0), + Product(1), + Minimum(2), + LiteralValue(4), + Maximum(3), + GreaterThan(5), + LessThan(6), + EqualTo(7), +} + +enum class LengthType(val id: Byte) { + TotalLength(0), + Count(1) +} + +private fun Byte.toPacketType(): PacketType = PacketType.values().first { it.id == this } +private fun Byte.toLengthType(): LengthType = LengthType.values().first { it.id == this } + +data class Packet( + val version: Byte, + val type: PacketType, + val bytes: List = emptyList(), + val packets: List = emptyList() +) { + fun versionSum(): Int = when (type) { + PacketType.LiteralValue -> version.toInt() + else -> version.toInt() + packets.sumOf(Packet::versionSum) + } + + fun eval(): Long = when (type) { + PacketType.Sum -> packets.sumOf(Packet::eval) + PacketType.Product -> packets.fold(1) { result, packet -> result * packet.eval() } + PacketType.Minimum -> packets.minOf(Packet::eval) + PacketType.LiteralValue -> bytes.fold(0) { result, byte -> result.shl(4).or(byte.toLong()) } + PacketType.Maximum -> packets.maxOf(Packet::eval) + PacketType.GreaterThan -> if (packets[0].eval() > packets[1].eval()) 1 else 0 + PacketType.LessThan -> if (packets[0].eval() < packets[1].eval()) 1 else 0 + PacketType.EqualTo -> if (packets[0].eval() == packets[1].eval()) 1 else 0 + } +} + +fun List.chunk(from: Int, to: Int): UInt { + check(to - from <= 32) + + var result = 0u + var usedBits = 0 + var i = from.div(4) + + // prefix + if (from.mod(4) != 0) { + val offset = 4 - from.mod(4) + val offsetMask = 1.shl(offset) - 1 + + result = this[i].and(offsetMask).toUInt() + + usedBits += offset + i++ + } + + while (4 * i < to) { + val fromFirst = this[i].toUInt() + result = result.shl(4).or(fromFirst) + + i++ + usedBits += 4 + } + + return result.shr(usedBits - (to - from)) +} + +fun parseGroups(bytes: List, groups: MutableList, from: Int): Int { + val mask = (1.shl(4) - 1).toUInt() + + var i = 0 + var lastByte = false + while (!lastByte) { + val byte = bytes.chunk(from + 5 * i, from + 5 * (i + 1)) + groups.add(byte.and(mask).toByte()) + + i++ + lastByte = byte.shr(4) == 0u + } + + return from + 5 * i +} + +fun parsePackets(bytes: List, packets: MutableList, from: Int): Int { + var index = from + 1 + when (bytes.chunk(from, from + 1).toByte().toLengthType()) { + LengthType.TotalLength -> { + val packetsSize = bytes.chunk(index, index + 15).toInt() + index += 15 + + var read = 0 + while (read < packetsSize) { + val (packet, newIndex) = parsePacket(bytes, index) + + packets.add(packet) + read += newIndex - index + index = newIndex + } + } + LengthType.Count -> { + val packetCount = bytes.chunk(index, index + 11).toInt() + index += 11 + + repeat(packetCount) { + val (packet, newIndex) = parsePacket(bytes, index) + + packets.add(packet) + index = newIndex + } + } + } + + return index +} + +fun parsePacket(bytes: List, from: Int): Pair { + val version = bytes.chunk(from, from + 3).toByte() + val typeId = bytes.chunk(from + 3, from + 6).toByte().toPacketType() + + val byteGroups = mutableListOf() + val packets = mutableListOf() + + val offset = when (typeId) { + PacketType.LiteralValue -> parseGroups(bytes, byteGroups, from + 6) + else -> parsePackets(bytes, packets, from + 6) + } + + return Pair(Packet(version, typeId, byteGroups, packets), offset) +} + +fun String.toPacket(): Packet = + parsePacket(this.map { it.digitToInt(16) }, 0).first + +fun part1(input: String): Int = input.toPacket().versionSum() +fun part2(input: String): Long = input.toPacket().eval() + +fun main() { + val input = readInput(16, "input").first() + + println(part1(input)) + println(part2(input)) +} diff --git a/src/year2021/day16/Day16KtTest.kt b/src/year2021/day16/Day16KtTest.kt new file mode 100644 index 0000000..ac6fc64 --- /dev/null +++ b/src/year2021/day16/Day16KtTest.kt @@ -0,0 +1,100 @@ +package year2021.day16 + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +internal class Day16KtTest { + + @Test + fun chunkWithOneElement() { + val list = listOf(15) + assertEquals(1u, list.chunk(0, 1)) + assertEquals(3u, list.chunk(0, 2)) + assertEquals(7u, list.chunk(0, 3)) + assertEquals(15u, list.chunk(0, 4)) + } + + @Test + fun chunk() { + val list = listOf(15, 12, 10, 3) + assertEquals(15u.shl(4).or(12u).shl(4).or(10u).shl(4).or(3u), list.chunk(0, 16)) + } + + @Test + fun chunkWholeInt() { + val list = listOf(15, 12, 10, 3, 6) + assertEquals(15u.shl(4).or(12u).shl(4).or(10u).shl(4).or(3u).shl(4).or(6u), list.chunk(0, 20)) + } + + @Test + fun parsing() { + assertEquals( + Packet( + 6.toByte(), + PacketType.LiteralValue, + bytes = listOf(7, 14, 5).map(Int::toByte), + ), + "D2FE28".toPacket() + ) + } + + @Test + fun part1a() { + assertEquals(16, part1("8A004A801A8002F478")) + } + + @Test + fun part1b() { + assertEquals(12, part1("620080001611562C8802118E34")) + } + + @Test + fun part1c() { + assertEquals(23, part1("C0015000016115A2E0802F182340")) + } + + @Test + fun part1d() { + assertEquals(31, part1("A0016C880162017C3686B18A3D4780")) + } + + @Test + fun part2Sum() { + assertEquals(3, part2("C200B40A82")) + } + + @Test + fun part2Product() { + assertEquals(54, part2("04005AC33890")) + } + + @Test + fun part2Minimum() { + assertEquals(7, part2("880086C3E88112")) + } + + @Test + fun part2Maximum() { + assertEquals(9, part2("CE00C43D881120")) + } + + @Test + fun part2LessThan() { + assertEquals(1, part2("D8005AC2A8F0")) + } + + @Test + fun part2GreaterThan() { + assertEquals(0, part2("F600BC2D8F")) + } + + @Test + fun part2EqualTo() { + assertEquals(0, part2("9C005AC2F8F0")) + } + + @Test + fun part2Complex() { + assertEquals(1, part2("9C0141080250320F1802104A08")) + } +} \ No newline at end of file