ABC207 D - Congruence Points を行列を使って解く

解説に載っている解法より計算量も悪いし実装も重たいのでアレ。

一般には3点の行き先を決めたら線形変換+平行移動が決まる。 そこで最初の3点の行き先を全部試して、それぞれについてその変換を表す行列を求めてやって、実際にその行列でうまくいくか確かめる方針。

失敗実装例

コンテスト中に投稿したのが概ね以下のコードだが、一個のテストケースでWAになってしまった。 原因は分かるだろうか?

#include <iostream>
#include <set>
#include <boost/rational.hpp>
using boost::rational;
using namespace std;
#define rep(i, n) for(int i = 0; i < (n); i ++)

int n;
typedef pair<int, int> P;
P ps[101];
P qs[101];
typedef rational<int> rat;

// 3 次正方行列を扱うクラス
class Matrix { /* 省略 */ }

int pow2(int x) { return x * x;  }

bool ok(Matrix m) {
    set<P> xys;
    set<P> zws;
    rep(i, n) {
        rat w_ = m.m11_ * ps[i].first + m.m12_ * ps[i].second + m.m13_;
        rat z_ = m.m21_ * ps[i].first + m.m22_ * ps[i].second + m.m23_;
        if (w_.denominator() != 1 || z_.denominator() != 1) return false;
        int w = w_.numerator(), z = z_.numerator();
        P p(w, z);
        xys.emplace(p);
    }
    rep(i, n) zws.emplace(qs[i]);
    return xys == zws;
}

bool solve() {
    if (n == 1) {
        return true;
    } else if (n == 2) {
        int d1 = pow2(ps[0].first - ps[1].first) + pow2(ps[0].second - ps[1].second);
        int d2 = pow2(qs[0].first - qs[1].first) + pow2(qs[0].second - qs[1].second);
        return d1 == d2;
    }
    rep(i, n) rep(j, n) rep(k, n) {
        if (i == j || j == k || k == i) continue;
        rat x0, y0, x1, y1, x2, y2;
        rat w0, z0, w1, z1, w2, z2;
        x0 = ps[0].first, y0 = ps[0].second;
        x1 = ps[1].first, y1 = ps[1].second;
        x2 = ps[2].first, y2 = ps[2].second;
        w0 = qs[i].first, z0 = qs[i].second;
        w1 = qs[j].first, z1 = qs[j].second;
        w2 = qs[k].first, z2 = qs[k].second;
        Matrix m1(x0, x1, x2, y0, y1, y2, 1, 1, 1);
        Matrix m2(w0, w1, w2, z0, z1, z2, 1, 1, 1);
        rat d1 = determinant(m1);
        rat d2 = determinant(m2);
        if (d1 != 0 && d1 == d2) {
            Matrix m = m2 * m1.adjugate() / d1;
            if (m.m11_ == m.m22_ && m.m12_ == -m.m21_ && determinant(m) == 1) {
                if (ok(m)) return true;
            }
        }
    }
    return false;
}

int main()  {
    ios_base::sync_with_stdio(false); cin.tie(0);
    cin >> n;
    rep(i, n) {
        int x, y;
        cin >> x >> y;
        ps[i] = make_pair(x, y);
    }
    rep(i, n) {
        int x, y;
        cin >> x >> y;
        qs[i] = make_pair(x, y);
    }
    cout << (solve() ? "Yes" : "No") << endl;
}

成功実装例

先程のコードだとうまくいかないのは「ps[0], ps[1], ps[2]が一直線上に並んだ場合、行列m1が非正則になるから」である。

そこで全部が一直線上に並んでいる場合とそうでない場合を場合分けして、一直線上に並んでないときはそのようなps[0], ps[1], ps[l]をとってきて先程の方法をとる。 これでちゃんと通った。

#include <iostream>
#include <set>
#include <boost/rational.hpp>
using boost::rational;
using namespace std;
#define rep(i, n) for(int i = 0; i < (n); i ++)

int n;
typedef pair<int, int> P;
P ps[101];
P qs[101];
typedef rational<int> rat;

// 3 次正方行列を扱うクラス
class Matrix { /* 省略 */ };
int pow2(int x) { return x * x;  }
rat pow2(rat x) { return x * x;  }

bool ok(Matrix m) {
    set<P> xys;
    set<P> zws;
    rep(i, n) {
        rat w_ = m.m11_ * ps[i].first + m.m12_ * ps[i].second + m.m13_;
        rat z_ = m.m21_ * ps[i].first + m.m22_ * ps[i].second + m.m23_;
        if (w_.denominator() != 1 || z_.denominator() != 1) return false;
        int w = w_.numerator(), z = z_.numerator();
        P p(w, z);
        xys.emplace(p);
    }
    rep(i, n) zws.emplace(qs[i]);
    return xys == zws;
}

// ps[0], ps[1], ps[l]が一直線上にならないようなlを返す
// 全部一直線上に並んでいるなら-1を返す
int findindex(P ps[101]) {
    for (int l = 2; l < n; l ++) {
        rat x0, y0, x1, y1, x2, y2;
        x0 = ps[0].first, y0 = ps[0].second;
        x1 = ps[1].first, y1 = ps[1].second;
        x2 = ps[l].first, y2 = ps[l].second;
        Matrix m1(x0, x1, x2, y0, y1, y2, 1, 1, 1);
        if (determinant(m1) != 0) return l;
    }
    return -1;
}

bool solve_if_all_lie_a_line() {
    if (findindex(qs) >= 0) return false;
    rat ts[101];
    rat ss[101];
    int d = pow2(ps[1].first - ps[0].first) + pow2(ps[1].second - ps[0].second);
    int e = pow2(qs[1].first - qs[0].first) + pow2(qs[1].second - qs[0].second);
    rep(i, n) {
        if (ps[1].first - ps[0].first != 0) {
            ts[i] = (rat)(ps[i].first - ps[0].first) / (ps[1].first - ps[0].first);
        } else {
            ts[i] = (rat)(ps[i].second - ps[0].second) / (ps[1].second - ps[0].second);
        }
        if (qs[1].first - qs[0].first != 0) {
            ss[i] = (rat)(qs[i].first - qs[0].first) / (qs[1].first - qs[0].first);
        } else {
            ss[i] = (rat)(qs[i].second - qs[0].second) / (qs[1].second - qs[0].second);
        }
    }
    sort(ts, ts + n);
    sort(ss, ss + n);
    bool ok1 = true, ok2 = true;
    rep(i, n) {
        if (d * pow2(ts[i] - ts[0])  != e * pow2(ss[i] - ss[0])) ok1 = false;
    }
    reverse(ss, ss + n);
    rep(i, n) {
        if (d * pow2(ts[i] - ts[0]) != e * pow2(ss[i] - ss[0])) ok2 = false;
    }
    return ok1 || ok2;
}

bool solve() {
    if (n == 1) {
        return true;
    } else if (n == 2) {
        int d1 = pow2(ps[0].first - ps[1].first) + pow2(ps[0].second - ps[1].second);
        int d2 = pow2(qs[0].first - qs[1].first) + pow2(qs[0].second - qs[1].second);
        return d1 == d2;
    }
    int l = findindex(ps);
    if (l == -1) return solve_if_all_lie_a_line();
    rep(i, n) rep(j, n) rep(k, n) {
        if (i == j || j == k || k == i) continue;
        rat x0, y0, x1, y1, x2, y2;
        rat w0, z0, w1, z1, w2, z2;
        x0 = ps[0].first, y0 = ps[0].second;
        x1 = ps[1].first, y1 = ps[1].second;
        x2 = ps[l].first, y2 = ps[l].second;
        w0 = qs[i].first, z0 = qs[i].second;
        w1 = qs[j].first, z1 = qs[j].second;
        w2 = qs[k].first, z2 = qs[k].second;
        Matrix m1(x0, x1, x2, y0, y1, y2, 1, 1, 1);
        Matrix m2(w0, w1, w2, z0, z1, z2, 1, 1, 1);
        rat d1 = determinant(m1);
        rat d2 = determinant(m2);
        if (d1 == d2) {
            Matrix m = m2 * m1.adjugate() / d1;
            if (m.m11_ == m.m22_ && m.m12_ == -m.m21_ && determinant(m) == 1) {
                if (ok(m)) return true;
            }
        }
    }
    return false;
}

int main()  {
    /* 省略 */
}

計算量はnについて4重ループしているのでO(n^4)のはず。

通った提出は以下。

上記で省略したMatrixクラスの実装も上のリンクを参照。 Matrixクラスについてはmntone氏によるclass 3x3 Matrix for Visual C++ 2012 · GitHubのものを編集して使った。