bzoj 3926
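Approach: insert every tree path that starts at a leaf into a trie, build a generalized suffix automaton (SAM) over that trie, and read the answer (the number of distinct substrings) directly from the val fields: sum of val[s] - val[link(s)] over all non-root states.
Pitfalls: watch the details — do not mix up p and q in the clone step, and remember to remove debugging output before submitting.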
// BZOJ 3926: count the distinct colour sequences read along simple paths of a
// tree.  Every simple path is a sub-path of a path starting at some leaf, so
// the DFS walks rooted at each leaf are merged into one trie, a generalized
// suffix automaton (SAM) is built over that trie, and the answer is the number
// of distinct substrings: the sum of val[s] - val[link(s)] over non-root states.
#include <cstdio>
#include <cstring>
using namespace std;

// Array bounds assume n <= 100000 and at most 20 leaves (20 * 100000 trie
// nodes), as the original constants imply -- NOTE(review): confirm against
// the problem statement; there is no overflow guard on the trie/SAM pools.
enum {
    MAXN = 2000020, // trie node pool (index 0 = trie root)
    MAXM = 100020,  // tree vertex pool (vertices are numbered 1..n)
    SIGMA = 12      // alphabet width; colours must lie in [0, SIGMA)
};

struct node  { int next[SIGMA]; } trie[MAXN];
struct node1 { int next[SIGMA], val, pnt; } sam[MAXN * 2]; // val = longest length, pnt = suffix link
struct node2 { int to, next; } e[MAXM * 2];

int trie_tot;   // trie nodes allocated so far; node 0 is the trie root
int sam_tot;    // SAM states allocated so far; state 0 is the SAM root
int cnt, head[MAXM], root;  // root == 0: the SAM root state
long long sum;
int n, col[MAXM], c, degree[MAXM];

// Append the directed edge x -> y to x's adjacency list and bump x's degree.
static inline void addedge(int x, int y) {
    e[++cnt].to = y;
    e[cnt].next = head[x];
    head[x] = cnt;
    degree[x]++;
}

// Extend the SAM by character x from state `last` (the state reached by the
// trie parent's path) and return the newly created state.
// If `last` already has an x-transition, the new state ends up with the same
// val as its suffix link, so it contributes 0 to the final sum.
static int add_sam(int x, int last) {
    int np = ++sam_tot;
    int p = last;
    sam[np].val = sam[p].val + 1;
    while (p && !sam[p].next[x]) {
        sam[p].next[x] = np;
        p = sam[p].pnt;
    }
    // p is now either the root (0) or the first suffix state with an x-edge.
    int q = sam[p].next[x];
    if (!q) {
        // Even the root had no x-edge.
        sam[p].next[x] = np;
        sam[np].pnt = p;
    } else if (sam[p].val + 1 == sam[q].val) {
        sam[np].pnt = q;
    } else {
        // Clone q into nq with val[nq] == val[p] + 1 (careful: p vs q).
        int nq = ++sam_tot;
        memcpy(sam[nq].next, sam[q].next, sizeof sam[q].next);
        sam[nq].val = sam[p].val + 1;
        sam[nq].pnt = sam[q].pnt;
        sam[np].pnt = sam[q].pnt = nq;
        while (p && sam[p].next[x] == q) {
            sam[p].next[x] = nq;
            p = sam[p].pnt;
        }
        // The loop stops before touching the root; redirect it separately.
        if (sam[p].next[x] == q) sam[p].next[x] = nq;
    }
    return np;
}

// Depth-first walk of the trie: each child of trie node `now` is appended to
// the SAM from `cur`, the SAM state reached by the path spelling `now`.
// Scans the whole alphabet array so it does not depend on the input value c.
static void dfs_trie(int now, int cur) {
    for (int ch = 0; ch < SIGMA; ch++) {
        int child = trie[now].next[ch];
        if (child) dfs_trie(child, add_sam(ch, cur));
    }
}

// Return the child of trie node u on colour ch, allocating it on first use.
static int trie_child(int u, int ch) {
    if (!trie[u].next[ch]) trie[u].next[ch] = ++trie_tot;
    return trie[u].next[ch];
}

// Walk the tree from vertex v (entered from `fa`, 0 for none) and mirror the
// walk into the trie; `node` is the trie node spelling the path ending at v.
// NOTE(review): recursion depth can reach n here and in dfs_trie.
static void dfs(int v, int fa, int node) {
    for (int i = head[v]; i; i = e[i].next) {
        int w = e[i].to;
        if (w == fa) continue;
        dfs(w, v, trie_child(node, col[w]));
    }
}

int main() {
    if (scanf("%d %d", &n, &c) != 2 || n < 0 || n >= MAXM) return 1;
    for (int i = 1; i <= n; i++) {
        if (scanf("%d", &col[i]) != 1 || col[i] < 0 || col[i] >= SIGMA) return 1;
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        if (scanf("%d %d", &u, &v) != 2) return 1;
        if (u < 1 || u > n || v < 1 || v > n) return 1;
        addedge(u, v);
        addedge(v, u);
    }
    // Root a walk at every leaf.  `<= 1` rather than `== 1` so that a
    // single-vertex tree (its only vertex has degree 0) still yields its
    // one-colour path instead of an empty trie and an answer of 0.
    for (int i = 1; i <= n; i++) {
        if (degree[i] <= 1) dfs(i, 0, trie_child(0, col[i]));
    }
    dfs_trie(0, root);
    // Each non-root state s accounts for val[s] - val[link(s)] distinct substrings.
    for (int i = 1; i <= sam_tot; i++) {
        sum += (long long)sam[i].val - (long long)sam[sam[i].pnt].val;
    }
    printf("%lld\n", sum);
    return 0;
}