読者です 読者をやめる 読者になる 読者になる

mallet CRF の確率出力

Java

状態の番号がよくわからんかったという話。
参考は http://mallet.cs.umass.edu/fst.php

// 各クラスの import や crfの宣言は済んでいるということにして...
CRFTrainerByLabelLikelihood trainer = new CRFTrainerByLabelLikelihood(crf);
trainer.setGaussianPriorVariance(10.0);
trainer.train(trainingData);

// transducer は何か便利なwrapperらしい
transducer = trainer.getTransducer();

// SumLattice は確率出力用のクラス(?)
// input はデータの系列、本例ではinput末尾のラベルが予測したいとする
// もちろん番号を弄れば好きな位置の予測分布が得られます
SumLattice latt = new SumLatticeDefault(transducer, input);
Sequence result = transducer.transduce(input);

// stnum は状態数(候補となるラベルの数)
int stnum = transducer.numStates();
String names[] = new String[stnum];
double probs[] = new double[stnum];
for (int i = 0; i < stnum; i++) {
    names[i] = transducer.getState(i).getName();
    // 次に予想される周辺確率の出力(resultの末尾で予測される各ラベルの確率)
    // result.size() の位置を知ればinput末尾の予測分布を知ることができる、のが大事
    probs[i] = latt.getGammaProbability(result.size(), transducer.getState(i));
}

// input末尾における各ラベルの予測確率
System.out.println("[Labels]" + Arrays.toString(names));
System.out.println("[Probs]" + Arrays.toString(probs));

// input末尾について予測されるラベル
System.out.println("[Predict]" + result.get(result.size()-1).toString());

めでたしめでたし。