1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
| #include<iostream> #include<cstdio> #define ll long long using namespace std; const int N=2e3+10,M=4e3+10; template<class T> inline void read(T &x) { x=0;bool f=0; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();} while(ch>='0'&&ch<='9')x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); if(f)x=-x; } template<class T,class ...T1> inline void read(T &x,T1 &...x1) { read(x),read(x1...); } int n,m; int tot,head[N],ver[M],e[M],ne[M]; int si[N]; ll f[N][N]; inline void add(int u,int v,int w) { ver[++tot]=v; e[tot]=w; ne[tot]=head[u]; head[u]=tot; } void dfs(int x,int fa) { si[x]=1; for(int i=head[x];i;i=ne[i]) { int y=ver[i]; if(y==fa)continue; dfs(y,x); for(int j=min(si[x]+si[y],m);~j;j--) for(int k=max(j-si[x],0);k<=min(si[y],j);k++) { ll cnt=k*(m-k)+(si[y]-k)*(n-m+k-si[y]); f[x][j]=max(f[x][j],f[x][j-k]+f[y][k]+cnt*e[i]); } si[x]+=si[y]; } } int main() { read(n,m); m=min(m,n-m); for(int i=1;i<n;i++) { int u,v,w; read(u,v,w); add(u,v,w); add(v,u,w); } dfs(1,0); printf("%lld\n",f[1][m]); return 0; }
|