#include #include #include #include #include #include #include using namespace std; typedef long long ll; typedef pair pll; const ll mod=1e9+7; inline ll read() { ll x=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } inline char getc() { char ch=getchar(); while(ch<'0'||ch>'1')ch=getchar(); return ch; } ll power(ll a,ll b,ll p) { ll r=1,w=a; while(b) { if(b&1)r=r*w%p; w=w*w%p,b>>=1; } return r%p; } ll to[400005],head[200005],nxt[400005],wei[400005],cnte; void add_edge(ll u,ll v) { to[++cnte]=v,nxt[cnte]=head[u],head[u]=cnte; to[++cnte]=u,nxt[cnte]=head[v],head[v]=cnte; } ll du[200005]; ll inv[200005],frac[200005]; ll dp[200005][2]; ll ans=0; void dfs(ll u,ll fa) { if(fa&&du[u]==1) { dp[u][0]=1; return; } ll sum0=0,sum1=0; for(int i=head[u];i;i=nxt[i]) { ll v=to[i],w=wei[i]; if(v==fa)continue; dfs(v,u); ll r0=dp[v][0],r1=dp[v][1]; if(w)r1+=r0,r0=0,r1%=mod; r0=r0*inv[du[v]-1]%mod,r1=r1*inv[du[v]-1]%mod; ans+=(r0*sum1+r1*sum0+r1*sum1)%mod*inv[du[u]-1]%mod; ans%=mod; sum0+=r0,sum1+=r1,sum0%=mod,sum1%=mod; } dp[u][0]=sum0,dp[u][1]=sum1; if(du[u]==1)ans+=sum1,ans%=mod; } void solve() { ll n=read(),k=read(); inv[0]=1,frac[0]=1; for(int i=1;i<=n;i++)inv[i]=power(i,mod-2,mod); for(int i=1;i<=n;i++)frac[i]=frac[i-1]*i%mod; for(int i=1;i