線形クラス分類

サポートベクターマシン入門

サポートベクターマシン入門

この本を読みながら、例題(線形クラス分類、パーセプトロンの主形式)をPerlで実装した。ノイズを入れていないのに、正答率が97%どまりなのはなぜだろうか。

use strict;
use warnings;
use Data::Dumper;

# Driver: train a perceptron on the comma-separated file named by the first
# argument, then report how many examples from the file named by the second
# argument it classifies correctly.
die "Usage: $0 TRAINING_FILE TEST_FILE\n" unless @ARGV >= 2;

# Class->new(...) instead of the indirect-object form "new Class(...)",
# which is ambiguous to the parser and disabled under "use v5.36".
my $trainingDataContainer = DataContainer->new($ARGV[0]);
my $testDataContainer = DataContainer->new($ARGV[1]);
my $perceptron = Perceptron->new(1, $trainingDataContainer);   # learning rate = 1

$perceptron->learn();

my @dataSeq = $testDataContainer->dataSeq();
my $testDataCount = @dataSeq;
my $matchCount = 0;
for my $data (@dataSeq){
 my @input = $data->input();
 my $ans = $data->output();
 my $output = $perceptron->output(\@input);

 # A prediction counts as correct when it has the same sign as the label.
 $matchCount++ if $output * $ans > 0;
}
print "$matchCount / $testDataCount\n";

# Clear the perceptron's reference to the training data before dumping it
# (presumably so the dump shows only the learned parameters).
$perceptron->dataSeq(undef);
print Dumper($perceptron);

####################################################
# Example class: one input vector together with its expected output.
{
 package Data;

 # Constructor: Data->new(\@input, $output).
 # Stores the input array reference and the output value as given.
 sub new{
  my $class = shift;
  my ($features, $label) = @_;

  my $self = bless {}, $class;
  $self->{input} = $features;
  $self->{output} = $label;
  return $self;
 }

 # Accessor for the input vector. With an argument, replaces the stored
 # array reference first. Returns the elements in list context and the
 # array reference itself otherwise.
 sub input{
  my ($self, @replacement) = @_;
  $self->{input} = $replacement[0] if @replacement;

  return $self->{input} unless wantarray;
  return @{$self->{input}};
 }

 # Accessor for the expected output. With an argument, replaces the
 # stored value first.
 sub output{
  my ($self, @replacement) = @_;
  $self->{output} = $replacement[0] if @replacement;
  return $self->{output};
 }

 # Euclidean norm of the input vector: sqrt of the sum of squared elements.
 sub inputNorm{
  my ($self) = @_;
  my $sumOfSquares = 0;
  $sumOfSquares += $_ * $_ for @{$self->{input}};
  return sqrt($sumOfSquares);
 }
}

{
 package DataContainer;

 # Constructor: DataContainer->new($filename).
 # Reads a comma-separated file in which the first column of each line is
 # the expected output and the remaining columns are the input vector, and
 # wraps every line in a Data object.
 # Dies if the file cannot be opened.
 sub new{
  my ($this, $filename) = @_;
  my $dataSeq = [];

  for my $tokens (@{ readFile($filename) }){
   my $output = shift @$tokens;              # first column = expected output
   push(@$dataSeq, Data->new($tokens, $output));
  }

  return bless { dataSeq => $dataSeq }, $this;
 }

 # Accessor for the loaded examples.
 #   $container->dataSeq($i)  returns the example at index $i
 #   $container->dataSeq()    returns all examples: a list in list context,
 #                            the underlying array reference otherwise
 sub dataSeq{
  my ($this, $indx) = @_;
  return $this->{dataSeq}[$indx] if defined($indx);
  return wantarray ? @{$this->{dataSeq}} : $this->{dataSeq};
 }

 # Plain function (not a method): readFile($filename).
 # Returns a reference to an array with one array reference of column
 # strings per non-blank line. Line endings ("\n" or "\r\n") are removed
 # so the last column does not carry them, and blank lines are skipped so
 # they do not turn into examples with an empty input vector.
 # Dies if no file name is given or the file cannot be opened.
 sub readFile{
  my $filename = shift;
  die "DataContainer: no data file name given\n" unless defined($filename);

  # Three-argument open with a lexical handle: the file name can never be
  # interpreted as an open mode, and the handle is scoped to this call.
  open(my $in, '<', $filename)
   or die "DataContainer: cannot open '$filename': $!\n";

  my @tokenLines = ();
  while(my $line = <$in>){
   $line =~ s/\r?\n\z//;
   next if $line =~ /\A\s*\z/;
   # Limit -1 keeps trailing empty columns so every row keeps its width.
   push(@tokenLines, [split(/,/, $line, -1)]);
  }
  close($in);

  return \@tokenLines;
 }
}

####################################################
# Perceptron (primal form)
#
# The weight array holds one weight per input dimension followed by the
# bias in its last slot: [w_0, ..., w_{d-1}, b].
{
 package Perceptron;

 # Constructor: Perceptron->new($learningRate, $dataContainer).
 #   $learningRate   step size applied to every correction
 #   $dataContainer  object whose dataSeq() returns (in scalar context) an
 #                   array reference of examples providing input(),
 #                   output() and inputNorm()
 # Stores R^2, the square of the largest input norm in the training set,
 # under the key rSquare; learn() uses it to scale bias corrections.
 sub new{
  my ($this, $learningRate, $dataContainer) = @_;

  my $dataSeq = $dataContainer->dataSeq();

  # R = largest input norm; stays 0 when there are no examples.
  my $maxNorm = 0;
  for my $data (@$dataSeq){
   my $norm = $data->inputNorm();
   $maxNorm = $norm if($maxNorm < $norm);
  }

  # Size the weight vector from the first example. Without any example
  # fall back to two inputs, i.e. the historical [0, 0, 0] layout.
  my $firstInput = @$dataSeq ? $$dataSeq[0]->input() : [0, 0];
  my $dimension = @$firstInput;

  my $obj = {
   learningRate => $learningRate,
   rSquare => $maxNorm * $maxNorm,
   dataSeq => $dataSeq,
   weight => [(0) x ($dimension + 1)]
  };

  return bless $obj, $this;
 }

 # Accessor for the training examples. With an argument (undef included),
 # replaces the stored reference first.
 sub dataSeq{
  my $this = shift;
  $this->{dataSeq} = shift if @_;
  return $this->{dataSeq};
 }

 # Raw activation for one input: output(\@input) returns <w, x> + b.
 # The sign of the result is the predicted class. Input components beyond
 # the number of learned weights are ignored, and missing components
 # contribute nothing.
 sub output{
  my ($this, $input) = @_;
  my $weight = $this->{weight};
  my $biasIndx = $#{$weight};

  # Never let an input component touch the bias slot.
  my $lastIndx = $#{$input} < $biasIndx - 1 ? $#{$input} : $biasIndx - 1;

  my $sum = $$weight[$biasIndx];
  for my $i (0 .. $lastIndx){
   $sum += $$input[$i] * $$weight[$i];
  }

  return $sum;
 }

 # Train on the stored examples: learn() or learn($maxEpochs).
 # Sweeps over the training data repeatedly; every example whose
 # activation does not have the same sign as its label triggers
 #   w <- w + rate * label * x
 #   b <- b + rate * label * R^2
 # Training stops after the first sweep without any correction, or after
 # $maxEpochs sweeps (default 1000) so that data which is not linearly
 # separable cannot loop forever.
 # Returns 1 when a mistake-free sweep was reached, 0 otherwise.
 sub learn{
  my ($this, $maxEpochs) = @_;
  $maxEpochs = 1000 unless defined($maxEpochs);

  my $weight = $this->{weight};
  my $biasIndx = $#{$weight};
  my $dataSeq = $this->{dataSeq} || [];

  for my $epoch (1 .. $maxEpochs){
   my $mistakes = 0;

   for my $data (@$dataSeq){
    my @input = $data->input();
    my $ans = $data->output();
    next if($this->output(\@input) * $ans > 0);

    $mistakes++;
    my $lastIndx = $#input < $biasIndx - 1 ? $#input : $biasIndx - 1;
    for my $i (0 .. $lastIndx){
     $$weight[$i] += $this->{learningRate} * $input[$i] * $ans;
    }
    $$weight[$biasIndx] += $this->{learningRate} * $ans * $this->{rSquare};
   }

   return 1 if($mistakes == 0);
  }

  return 0;
 }
}