Dr.nabeee3と日々のめも

詰まったところの解決策とか興味のあることを書いていくよてい

openFrameworksでlibSVM

openFrameworks上でlibSVMを実装してみました.
目指すところとしては,あらかじめ学習したデータセットでモデルを作っておき,リアルタイムに入ってくるデータセットのクラスを判別するという感じです.

この方の記事を参考にさせて頂きました.というかほとんどそのままです.
なんだか雲行きの怪しい雑記帖 週アレ(11) C++で始めるLibSVM

addonでofxSVMというのもあったんですが,今回の用途には適していないと感じたので普通のlibSVMのソースを落としてきました.

プロジェクトのsrcフォルダにsvm.hとsvm.cppを追加します.
ヘッダファイルはこんな感じ

- ofApp.h -
#pragma once

#include "ofMain.h"
#include "svm.h"

class node{
public:
    node(int label, double x, double y){
        this->x = x;
        this->y = y;
        this->label =label;
    }
    int label;
    double x, y;
};

class ofApp : public ofBaseApp{
	public:
		void setup();
		void update();
		void draw();
		
		void keyPressed(int key);
		void keyReleased(int key);
		void mouseMoved(int x, int y);
		void mouseDragged(int x, int y, int button);
		void mousePressed(int x, int y, int button);
		void mouseReleased(int x, int y, int button);
		void windowResized(int w, int h);
		void dragEvent(ofDragInfo dragInfo);
		void gotMessage(ofMessage msg);
    
    vector<node> samples;
    int sampleDataNum = 500;
};

cppファイルは抜粋するとこんな感じ

#include "ofApp.h"

//--------------------------------------------------------------
void ofApp::setup(){
    ofSetWindowShape(600, 400);
    
    int class1label = 1;
    int class2label = -1;
    
    for (int i=0; i < sampleDataNum; i++) {
        node n(class1label, ofRandom(100), ofRandom(60));
        samples.push_back(n);
    }
    for (int i=0; i < sampleDataNum; i++) {
        node n(class2label, ofRandom(100), ofRandom(60)+40);
        samples.push_back(n);
    }
}

//--------------------------------------------------------------
void ofApp::draw(){
    float mapX = ofGetWidth()/100;
    float mapY = ofGetHeight()/100;
    ofSetColor(255, 0, 0);
    for (int i=0; i < sampleDataNum; i++) {
        ofCircle(samples[i].x * mapX, ofGetHeight() - samples[i].y * mapY, 5);
    }
    ofSetColor(0, 0, 255);
    for (int i=sampleDataNum; i < sampleDataNum+sampleDataNum; i++) {
        ofCircle(samples[i].x * mapX, ofGetHeight() - samples[i].y * mapY, 5);
    }
}

//--------------------------------------------------------------
void ofApp::mousePressed(int x, int y, int button){
    svm_problem prob;
    svm_node* prob_vec;
    // size of training data
    prob.l = samples.size();
    // label of each training data
    prob.y = new double[prob.l];
    for (int i=0; i<samples.size(); ++i) {
        prob.y[i] = samples[i].label;
    }
    // vector of each training data
    prob_vec = new svm_node[prob.l * (2+1)];
    prob.x = new svm_node*[prob.l];
    for (int i=0; i < samples.size(); ++i) {
        prob.x[i] = prob_vec+i*3;
        prob.x[i][0].index = 1;
        prob.x[i][0].value = samples[i].x;
        prob.x[i][1].index = 2;
        prob.x[i][1].value = samples[i].y;
        prob.x[i][2].index = -1;
    }
    
    // parameter
    svm_parameter param;
    param.svm_type = C_SVC;
    param.kernel_type = LINEAR;
    param.C = 8096;
    param.gamma = 0.1;
    
    param.coef0 = 0;
    param.cache_size = 100;
    param.eps = 1e-3;
    param.shrinking = 1;
    param.probability = 0;
    
    param.degree = 3;
    param.nu = 0.5;
    param.p = 0.1;
    param.nr_weight = 0;
    param.weight_label = NULL;
    param.weight = NULL;
    
    // learning!
    cout << "Ready to train..." << endl;
    svm_model* model = svm_train(&prob, &param);
    cout << "Finished..." << endl;
    
    int correct_count = 0, wrong_count = 0;
    cout << "predict training samples..." << endl;
    for (int i=0; i < samples.size(); i++) {
        svm_node test[3];
        test[0].index = 1;
        test[0].value = samples[i].x;
        test[1].index = 2;
        test[1].value = samples[i].y;
        test[2].index = -1;
        // predict by libsvm
        const auto predict_label = static_cast<int>(svm_predict(model, test));
        // count result
        if (predict_label == samples[i].label) correct_count++;
        else wrong_count++;
    }
    cout << "done" << endl;
    cout << "correct: " << correct_count << endl;
    cout << "wrong: " << wrong_count << endl;
    cout << "accuracy: " << 100 * correct_count/(correct_count+wrong_count) << "%" << endl;
}

とりあえず動作確認したいだけなのでこんなもので.

今回はランダムでサンプルデータを作ってるんですが,ある一回のサンプルデータの分布がこんな感じでした.
f:id:nabeee3:20150212233134p:plain
ちなみにこのときのaccuracyは84%でした.

SVMはパラメータ調整が非常に大事とのことなので,そこら辺もちゃんと実装していかなければ.
けど今回はこのへんで.