// Package main computes how many edges can be removed from a graph with
// Alice-only, Bob-only and shared edges while both Alice and Bob can still
// reach every node.
package main

import (
	"cmp"
	"slices"
)

// UnionFind is a disjoint-set forest over the elements 0..n-1 that uses
// path compression and union by rank. Create instances with makeUnionFind.
type UnionFind struct {
	parent []int // parent[i] is the parent of i; a root is its own parent
	rank   []int // per-root rank used to decide which tree is attached to which
	count  int   // number of disjoint sets currently tracked
}

// Count returns the current number of disjoint sets.
func (uf *UnionFind) Count() int {
	return uf.count
}

// find returns the representative (root) of the set containing u and
// re-points every node visited on the way directly at that root.
func (uf *UnionFind) find(u int) int {
	if uf.parent[u] != u {
		uf.parent[u] = uf.find(uf.parent[u])
	}
	return uf.parent[u]
}

// Union merges the sets containing u and v. It reports true when the two
// elements were in different sets (so a merge happened) and false when they
// were already connected.
func (uf *UnionFind) Union(u, v int) bool {
	rootU, rootV := uf.find(u), uf.find(v)
	if rootU == rootV {
		return false
	}

	switch {
	case uf.rank[rootU] > uf.rank[rootV]:
		uf.parent[rootV] = rootU
	case uf.rank[rootU] < uf.rank[rootV]:
		uf.parent[rootU] = rootV
	default:
		// Equal ranks: pick rootU as the new root and bump its rank.
		uf.parent[rootV] = rootU
		uf.rank[rootU]++
	}

	uf.count--
	return true
}

// makeUnionFind returns a UnionFind holding n singleton sets {0}, ..., {n-1}.
func makeUnionFind(n int) UnionFind {
	parent := make([]int, n)
	for i := range parent {
		parent[i] = i
	}
	return UnionFind{
		parent: parent,
		rank:   make([]int, n),
		count:  n,
	}
}

// Edge types, stored in the first element of every edge triple.
const (
	Alice int = 1 // edge usable by Alice only
	Bob       = 2 // edge usable by Bob only
	Both      = 3 // edge usable by both Alice and Bob
)

// byEdgeType orders edges by descending type, so that Both edges sort ahead
// of the Bob-only and Alice-only edges.
func byEdgeType(l, r []int) int {
	return -cmp.Compare(l[0], r[0])
}

// maxNumEdgesToRemove returns the largest number of edges that can be dropped
// while the graph stays fully traversable for both Alice and Bob, or -1 when
// the given edges cannot connect all n nodes for both of them.
//
// Each edge is a triple [type, u, v] where type is Alice, Bob or Both and
// u, v are 1-based node numbers. Edges with any other type are never used and
// are therefore counted as removable.
//
// The caller's edges slice is left in its original order: the edges are
// processed from a sorted copy of the outer slice.
func maxNumEdgesToRemove(n int, edges [][]int) int {
	// Shared edges are considered first because one shared edge can connect
	// components for both people at the cost of a single kept edge.
	ordered := slices.Clone(edges)
	slices.SortFunc(ordered, byEdgeType)

	alice := makeUnionFind(n)
	bob := makeUnionFind(n)

	kept := 0
	for _, edge := range ordered {
		t, u, v := edge[0], edge[1]-1, edge[2]-1
		switch t {
		case Alice:
			if alice.Union(u, v) {
				kept++
			}
		case Bob:
			if bob.Union(u, v) {
				kept++
			}
		case Both:
			// Both unions must run, so they are evaluated before the
			// short-circuiting || below.
			connectsAlice := alice.Union(u, v)
			connectsBob := bob.Union(u, v)
			if connectsAlice || connectsBob {
				kept++
			}
		}
	}

	if alice.Count() > 1 || bob.Count() > 1 {
		return -1
	}
	return len(edges) - kept
}