Solution
写在前面
本文结论没有证明,想要详细证明的请看_Vix_大佬的文章。本文的基本思路与大佬相同,但偏向于思路分析,还有对分类简化的尝试。
题意
原题链接
给定一棵树,每个点的点权为经过该点的直径数量,求树的点权和或点权平方和。
分析&Code 1
先放一张图,明确一下叫法:
深度为 2 的节点,本文中称为 top 节点,top 节点的子树简称为 top 子树,文中提到的“最长链”“次长链”的一段为根节点。
题目很直接,我们也有个直接的想法:我只要求出所有的直径,显示地加,一定可以做(废话,\(O(n^2)\),滚吧)。所以说问题就转变为:怎么确定直径和怎么方便地加。
第二问我会!树上差分!只要找到直径的两个端点,他们之间的链必然是直径,拿下!
所以其实只有一个问题:怎么确定直径。
我们可以想到直径的一个性质:直径必然过中点(如果有两个中点,则都经过)。所以我们考虑换根,以(其中一个)中心为根建树。这样有两个好处:直径必然过根,不用写 LCA 了;直径必然由从根出发的两条最长链(一个中心时)或一条最长链加上一条次长链(两个中心时)组成。于是我们也得到了确定直径端点的方法:换根后,深度等于直径一半或直径一半减一(分别对应最长链端点、次长链端点)的叶节点。
那么现在我们有了一个朴素的方法:换根,dfs 算一个深度,\(O(n^2)\) 扫,差分算贡献。在此基础上,我们发现:一个中心时,过某个直径端点,可以与所有不在和同一 top 子树下的其他直径端点形成一条直径。其实很直观,这其实就是两条最长链;两个中心时,有且仅有一个 top 子树下会有最长链(否则直径由两条最长链组成,应只有一个中心),直径由该 top 子树中的最长链端点和所有其他 top 子树的次长链端点组成。
于是,我们可以统计下所有 top 子树中的最长链端点数 \(cnt_1(u)\) 和次长链端点数 \(cnt_2(u)\)(下文中 \(cnt_1(rt)\) 表示 \(\sum cnt_1(u)\),\(cnt_2(rt)\) 同理),进行如下操作:
- 若只有一个中心,在所有最长链端点 \(u\) 对应的差分数组上加上 \(cnt_1(rt)-cnt_1(u)\);
- 若有两个中心,在所有最长链端点 \(u\) 对应的差分数组上加上 \(cnt_2(rt)\),次长链加 \(cnt_1(rt)\)。
然后差分统计答案即可。
最后,由于根节点在每条直径都被加了两次,最后记得乘上 2 的乘法逆元(致直接除调了一下午的我)。
Code
(以下是合并 \(cnt_1(u),cnt_2(u)\) 的方法)
但是但是,还要判断叶子是在是太麻烦了,一堆数组给自己都整晕了!!!有没有什么简单又强势的办法?有的兄弟有的。
我们发现,有两个中点时,由于两个中点都必然经过,所以它们就是一个点!用兄弟儿子法存树,直接删一个点,全部按只有一个中心做!
但是,我写的前向星怎么办(比如本蒟蒻)?好办!
把两个中心之间的边断了(不让走就行了,并且只有一个中心时就不存在这条边),整成两棵树,这时候两树的最长链长度相等,遍历所有点,在最长链端点对应的差分数组上加上另一棵树的 \(cnt(rt)\) 即可。需要注意:这么写只有在有一个中心时根会重复加,需要判断一下。
复杂度 \(O(n)\)。
Code 2(对应上面的第二种方法)
#include <iostream>
#include <cctype>
#include <cstdio>
#include <climits>typedef long long ll;int fr() {int x=0,f=1;char c=getchar();while(!isdigit(c)) {if(c=='-') f=-1;c=getchar();} while(isdigit(c)) {x=(x<<3)+(x<<1)+(c^48);c=getchar();}return x*f;
}const int maxn=5e6+100;
const int M=998244353;int head[maxn],tot;struct edge{int v,nxt;
}e[maxn*2];void ade(int u,int v) {e[++tot]={v,head[u]};head[u]=tot;
};int n,k;int d1[maxn],d2[maxn];void dfs1(int u,int f) {for(int i = head[u]; i; i=e[i].nxt) {int v=e[i].v;if(v==f) continue;dfs1(v,u);if(d1[v]+1>d1[u]) {d2[u]=d1[u];d1[u]=d1[v]+1;}else if(d1[v]+1>d2[u]) d2[u]=d1[v]+1;}
}int up[maxn];void dfs2(int u,int f) {for(int i = head[u]; i; i=e[i].nxt) {int v=e[i].v;if(f==v) continue;up[v]=up[u]+1;if(d1[v]+1!=d1[u]) up[v]=std::max(up[v],d1[u]+1);else up[v]=std::max(up[v],d2[u]+1);dfs2(v,u);}
}int rt[2],min_l=INT_MAX;void getr() {dfs1(1,0);dfs2(1,0);for(int i = 1; i <= n; i++) {if(std::max(d1[i],up[i])<min_l) {min_l=std::max(d1[i],up[i]);rt[0]=rt[1]=i;}else if(std::max(d1[i],up[i])==min_l) rt[1]=i;}
}
//就是和求重心类似的方式,但是不知道为什么机房佬都不认可int cnt[maxn],top[maxn];
int dep[maxn],fa[maxn];
//fa:每个top节点对应的树的编号void getd(int u,int f) {dep[u]=dep[f]+1;for(int i = head[u]; i; i=e[i].nxt) {int v=e[i].v;if(v==f||v==rt[0]||v==rt[1]) continue;//断边就是这个意思getd(v,u);}
}void count(int u,int f,int tp,int root) {top[u]=tp;if(dep[u]==min_l) cnt[tp]++,cnt[root]++;for(int i = head[u]; i; i=e[i].nxt) {int v=e[i].v;if(v==f||v==rt[0]||v==rt[1]) continue;count(v,u,tp,root);}
}void init() {for(int i=head[rt[0]]; i ; i=e[i].nxt) {int v=e[i].v;if(v==rt[1]) continue;count(v,rt[0],v,rt[0]);fa[v]=0;}if(rt[0]==rt[1]) return;for(int i=head[rt[1]]; i ; i=e[i].nxt) {int v=e[i].v;if(v==rt[0]) continue;count(v,rt[1],v,rt[1]);fa[v]=1;}
}ll d[maxn];
ll ans;void sol(int u,int f) {for(int i = head[u]; i; i=e[i].nxt) {int v=e[i].v;if(v==f||v==rt[0]||v==rt[1]) continue;sol(v,u);d[u]+=d[v];if(d[u]>M||d[u]<0) d[u]%=M;if(d[u]<0) d[u]+=M;}if(rt[0]==rt[1]&&u==rt[1]) d[u]*=(M+1)>>1,d[u]%=M;//2在模99824353意义下的乘法逆元比较特殊,其他的不知道可不可以ans+=d[u]*(k==1?1:d[u]);if(ans>M||ans<0) ans%=M;if(ans<0) ans+=M;
}int main() {n=fr(),k=fr();if(n<=2) {printf("%d\n",n);return 0;}//注意:这样写小于等于二会输出0,虽然没有这个数据,但是写一下for(int i = 1; i <= n-1; i++) {int u=fr(),v=fr();ade(u,v);ade(v,u);}getr();if(rt[0]==rt[1]) min_l++;getd(rt[0],0);getd(rt[1],0);init();for(int i = 1; i <= n; i++) {if(dep[i]==min_l) {d[i]+=cnt[rt[fa[top[i]]^1]];if(rt[0]==rt[1]) d[i]-=cnt[top[i]];}}sol(rt[0],0);if(rt[1]!=rt[0]) sol(rt[1],0);printf("%lld\n",ans);return 0;
}
闲话
蒟蒻经历了漫长的鏖战终于胜利,本篇算是写一些自己的理解了。
如果觉得有用,点个赞吧!