class Solution {
    /**
     * Lazily yields the Cartesian product of [xs] and [ys] in row-major order.
     * Every element of [xs] is paired with every element of [ys] before the next
     * element of [xs] is taken.
     *
     * [ys] is traversed once per element of [xs], so it must be re-iterable.
     * A sequence that may only be iterated once will fail on the second pass.
     */
    fun <A, B> product(
        xs: Sequence<A>,
        ys: Sequence<B>,
    ): Sequence<Pair<A, B>> = xs.flatMap { x -> ys.map { y -> x to y } }

    /**
     * [Iterable] convenience overload of `product`.
     * It converts both inputs to sequences and uses the same row-major ordering.
     */
    fun <A, B> product(
        xs: Iterable<A>,
        ys: Iterable<B>,
    ): Sequence<Pair<A, B>> = product(xs.asSequence(), ys.asSequence())

    /**
     * Counts every square submatrix of [matrix] whose cells are all `1`.
     *
     * Row widths are taken from the first row. This presumably assumes a
     * rectangular matrix; confirm with callers if ragged input is possible.
     *
     * @param matrix grid of cells; only cells equal to `1` can be part of a square.
     * @return total number of all-ones squares of any size, or 0 for an empty matrix.
     */
    fun countSquares(matrix: Array<IntArray>): Int {
        val rows = matrix.size
        // An empty matrix has no first row to measure and contains no squares.
        val columns = matrix.firstOrNull()?.size ?: return 0

        // dp[y + 1][x + 1] holds the side length of the largest all-ones square whose
        // bottom-right corner is (y, x). The extra zero row and column at index 0 let the
        // recurrence read the top, left and diagonal neighbours without bounds checks.
        val dp = Array(rows + 1) { IntArray(columns + 1) }

        var answer = 0
        for ((y, x) in product(0..<rows, 0..<columns)) {
            if (matrix[y][x] != 1) continue // its dp entry stays 0
            val side = 1 + minOf(dp[y][x + 1], dp[y + 1][x], dp[y][x])
            dp[y + 1][x + 1] = side
            // A largest square of side `side` ending here means squares of sizes
            // 1..side all end here, so this cell contributes `side` squares.
            answer += side
        }

        return answer
    }
}