部分文字列の数え上げ
まず、部分文字列を数える方法を考えてみます。文字列の長さを文字とすると、各文字に対して選ぶか選ばないかの2択なので、部分文字列の選び方は 通りあります。
しかし、これは同じ部分文字列になる選び方が複数あるとき、それらすべてを重複カウントしています。
重複カウントを避けるために、可能な限り左側の文字から選んでいく動的計画法(DP)を考えます。文字目を最後の文字とする部分文字列の数を と表します。
1文字目の文字として'a'から'z'の各文字が最初に現れる位置でとします。つづいて 文字目の後ろにさらに1文字を追加すると考えて、文字目以降の 'a'から'z'の各文字が最初に現れる位置でと加算します。これをとループしてを確定していきます。最終的に、が重複のない部分文字列の数になります。
繰り返しになる部分文字列の数え上げ
同じ文字列の繰り返しの形の部分文字列を探します。先頭からの部分文字列と途中からの部分文字列が一致する場合の数を求めたいですが、「途中」がどの位置からなのかはいろいろありえます。
そこで、最初の の最後の文字が文字目である場合に限定します。このとき、2つ目のは文字目からと、探し始める位置を固定できます。文字列の最初から文字目までを、文字目から最後までをとして、の部分文字列のうちの部分文字列にもなる文字列の数を探します。
例えば、
のケースを考えてみます。の は3文字目の"c"を最後に選ぶ部分文字列の数で、 "c", "ac", "bc", "abc" の4つがあることから になります。これに対応する の"c" の文字の位置は、"c" については 1文字目、"ac", "bc" については 4文字目、"abc"については7文字目になります。これらは数える上で区別する必要があるので、との共通部分文字列ででは文字目が最後の文字、では文字目が最後の文字になる文字列の数を動的計画法のと表すことにします。
文字列 "c" に対応して、文字列"ac","bc" に対応して、文字列"abc"に対応して になります。これにさらに一文字追加することを考えます。
"c" に "a"を追加すると "ca" になり、これは に対応するため、 と加算します。"b"を追加すると と加算します。"c"を追加すとと加算します。
"ac"または"bc"に一文字追加するケースは、"a"を追加すると、"b"を追加すると、"c"を追加するとと加算します。
"abc" は であり、 の最後の文字に達しているため文字を追加することができません。
このように計算して、 の和が最初の の最後の文字が文字目である場合の数になります。
のループで和を計算することで答えが得られます。
コード例 (Julia)
_MOD = 998244353 function solve() S = readline() N = length(S) ans = 0 for K = 1:N S1 = SubString(S,1,K) S2 = SubString(S,K+1,N) # 'a'から'z'が文字列中の何文字目か d1 = Dict(k => Int[] for k = 'a':'z') d2 = Dict(k => Int[] for k = 'a':'z') for j = 1:K c1 = S1[j] push!(d1[c1], j) end for j = 1:N-K c2 = S2[j] push!(d2[c2], j) end # 各文字の数 n1 = Dict() n2 = Dict() for c = 'a':'z' n1[c] = length(d1[c]) n2[c] = length(d2[c]) end dp = zeros(Int,K,N-K) # 一文字目 for c = 'a':'z' if n1[c] == 0 || n2[c] == 0 continue end dp[ d1[c][1], d2[c][1] ] = 1 end # 二文字目以降 for i = 1:K-1 for j = 1:N-K-1 if dp[i,j] == 0 continue end for c = 'a':'z' p1 = -1 p2 = -1 for x = 1:n1[c] if d1[c][x] > i p1 = d1[c][x] break end end for x = 1:n2[c] if d2[c][x] > j p2 = d2[c][x] break end end if p1 > 0 && p2 > 0 dp[p1,p2] += dp[i,j] dp[p1,p2] %= _MOD end end end end for j = 1:N-K ans += dp[K,j] ans %= _MOD end end println(ans) end # function solve # main solve()