ABC299 F問題 Square Subsequence

atcoder.jp

部分文字列の数え上げ

まず、部分文字列を数える方法を考えてみます。文字列の長さを N文字とすると、各文字に対して選ぶか選ばないかの2択なので、部分文字列の選び方は 2^N 通りあります。
しかし、これは同じ部分文字列になる選び方が複数あるとき、それらすべてを重複カウントしています。
重複カウントを避けるために、可能な限り左側の文字から選んでいく動的計画法(DP)を考えます。i文字目を最後の文字とする部分文字列の数を dp(i) と表します。
1文字目の文字として'a'から'z'の各文字が最初に現れる位置 i dp(i)=1とします。つづいて i 文字目の後ろにさらに1文字を追加すると考えて、 i+1文字目以降の 'a'から'z'の各文字が最初に現れる位置 k dp(k) += dp(i)と加算します。これを i = 1,2,\cdots,Nとループして dp(i)を確定していきます。最終的に、 \sum dp(i)が重複のない部分文字列の数になります。

繰り返しになる部分文字列の数え上げ

同じ文字列の繰り返し TTの形の部分文字列を探します。先頭からの部分文字列と途中からの部分文字列が一致する場合の数を求めたいですが、「途中」がどの位置からなのかはいろいろありえます。
そこで、最初の T の最後の文字が K文字目である場合に限定します。このとき、2つ目の T K+1文字目からと、探し始める位置を固定できます。文字列 Sの最初から K文字目までを S_1 K+1文字目から最後までを S_2として、 S_1の部分文字列のうち S_2の部分文字列にもなる文字列の数を探します。
例えば、
 S_1 = abcabc
 S_2 = cbacbac
のケースを考えてみます。 S_1 dp(3) は3文字目の"c"を最後に選ぶ部分文字列の数で、 "c", "ac", "bc", "abc" の4つがあることから  dp(3)=4 になります。これに対応する S_2 の"c" の文字の位置は、"c" については 1文字目、"ac", "bc" については 4文字目、"abc"については7文字目になります。これらは数える上で区別する必要があるので、 S_1 S_2の共通部分文字列でS_1ではi文字目が最後の文字、 S_2ではj文字目が最後の文字になる文字列の数を動的計画法 dp(i,j)と表すことにします。
文字列 "c" に対応して dp(3,1) = 1、文字列"ac","bc" に対応して dp(3,4) = 2、文字列"abc"に対応してdp(3,7) = 1 になります。これにさらに一文字追加することを考えます。
"c" に "a"を追加すると "ca" になり、これは dp(4,3) に対応するため、 dp(4,3) += dp(3,1) と加算します。"b"を追加すると dp(5,2) += dp(3,1) と加算します。"c"を追加すと dp(6,4) += dp(3,1)と加算します。
"ac"または"bc"に一文字追加するケースは、"a"を追加すると dp(4,6) += dp(3,4)、"b"を追加すると dp(5,5) += dp(3,4)、"c"を追加すると dp(6,7) += dp(3,4)と加算します。
"abc" はdp(3,7) であり、 S_2 の最後の文字に達しているため文字を追加することができません。
このように計算して、 \sum_j dp(K,j) の和が最初の T の最後の文字が K文字目である場合の数になります。
 K=1,2,\cdots,N のループで和を計算することで答えが得られます。


コード例 (Julia)

_MOD = 998244353

function solve()

  S = readline()
  N = length(S)

  ans = 0

  for K = 1:N
    S1 = SubString(S,1,K)
    S2 = SubString(S,K+1,N)

    # 'a'から'z'が文字列中の何文字目か
    d1 = Dict(k => Int[] for k = 'a':'z')
    d2 = Dict(k => Int[] for k = 'a':'z')

    for j = 1:K
      c1 = S1[j]
      push!(d1[c1], j)
    end
    for j = 1:N-K
      c2 = S2[j]
      push!(d2[c2], j)
    end

    # 各文字の数
    n1 = Dict()
    n2 = Dict()
    for c = 'a':'z'
      n1[c] = length(d1[c])
      n2[c] = length(d2[c])
    end

    dp = zeros(Int,K,N-K)

    # 一文字目
    for c = 'a':'z'
      if n1[c] == 0 || n2[c] == 0
        continue
      end
      dp[ d1[c][1], d2[c][1] ] = 1
    end

    # 二文字目以降
    for i = 1:K-1
      for j = 1:N-K-1
        if dp[i,j] == 0
          continue
        end

        for c = 'a':'z'
          p1 = -1
          p2 = -1
          for x = 1:n1[c]
            if d1[c][x] > i
              p1 = d1[c][x]
              break
            end
          end
          for x = 1:n2[c]
            if d2[c][x] > j
              p2 = d2[c][x]
              break
            end
          end

          if p1 > 0 && p2 > 0
            dp[p1,p2] += dp[i,j]
            dp[p1,p2] %= _MOD
          end
        end
      end
    end

    for j = 1:N-K
      ans += dp[K,j]
      ans %= _MOD
    end

  end

  println(ans)

end  # function solve

# main
solve()