AOJ 1291,PKU 3946 Search of Concatenated Strings

問題概要

n個の文字列が与えられる.
このn個の文字列の順列がm個の文字列を繋げたものに何個含まれるか答えよ.
ただし,全く同じ部分の文字列を2度以上カウントしてはいけない.

解法

動的計画法
説明が難しいが,ある場所まで見たときに,その前までにちょうどどのような選び方が表われたのかを考えると,O(n*2^n*文字の長さ)になる.

実装(C++)

#include <algorithm>
#include <vector>
#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;

typedef long long int lli;

#define REP(i,x) for(int i=0;i<(int)(x);i++)
#define rep(i,s,x) for(int i=s;i<(int)(x);i++)
#define FOR(i,c) for(__typeof((c).begin())i=(c).begin();i!=(c).end();i++)
#define RREP(i,x) for(int i=(x);i>=0;i--)
#define RFOR(i,c) for(__typeof((c).rbegin())i=(c).rbegin();i!=(c).rend();i++)
long long mod_pow(long long a,long long n,long long M){
	if(n==0)return 1%M;
	long long res=mod_pow(a*a%M,n/2,M);
	if(n&1)res=res*a%M;
	return res;
}
const int BASE=293;
const int MOD=1000000009;
int match[5002];
void rabin_karp(const char *exp,const char *find,int matched){
	int m=strlen(find),i,n=strlen(exp);
	int b=mod_pow(BASE,m-1,MOD);
	int hash_f=0,hash_e=0;
	if(n<m)return;
	for(i=0;i<m;i++){
		hash_f=((long long)hash_f*BASE)%MOD;
		hash_f=(hash_f+find[i])%MOD;
		hash_e=((long long)hash_e*BASE)%MOD;
		hash_e=(hash_e+exp[i])%MOD;
	}
	i=0;
	while(1){
		if(hash_f==hash_e){
			match[i]|=(1<<matched);
		}
		if(i>=n-m)break;
		hash_e=(hash_e+MOD-(long long)b*exp[i]%MOD)%MOD;
		hash_e=((long long)hash_e*BASE+exp[i+m])%MOD;
		i++;
	}
}
int n,m;
char find_str[12][100];
char expression[5002];
int exp_len;
int find_len[12];
char dp[5200][1<<12];
void input(int m){
	int k=0;
	int t,cnt=0;
	for(;cnt<=m;){
		t=getchar();
		if(islower(t))expression[k++]=t;
		if(t=='\n')cnt++;
	}
	exp_len=k;
}
int main(){
	for(;~scanf("%d%d",&n,&m);){
		if(n==0&&m==0)break;
		int i;
		for(i=0;i<n;i++){
			scanf("%s",find_str[i]);
		}
		input(m);
		memset(match,0,sizeof(int)*(exp_len+1));
		memset(dp,0,sizeof(dp));
		for(i=0;i<n;i++){
			rabin_karp(expression,find_str[i],i);
			find_len[i]=strlen(find_str[i]);
		}
		for(int i=0;i<exp_len;i++){
			dp[i][0]=1;
			for(int j=0;j<n;j++){
				if((match[i]>>j)&1){
					for(int k=0;k<(1<<n);k++){
						if((k>>j)&1)continue;
						if(dp[i][k]){
							dp[i+find_len[j]][k|(1<<j)]=1;
						}
					}
				}
			}
		}
		int ans=0;
		for(int i=0;i<=exp_len;i++){
			ans+=dp[i][(1<<n)-1];
		}
		cout<<ans<<endl;
	}
}