この本を読みながら例題(線形クラス分類、
パーセプトロン、主形式)を
Perlで実装。ノイズを入れてないのに正答率が97%どまりなのは何故に。
use strict;
use warnings;
use Data::Dumper;
my $trainingDataContainer = new DataContainer($ARGV[0]);
my $testDataContainer = new DataContainer($ARGV[1]);
my $perceptron = new Perceptron(1, $trainingDataContainer);
$perceptron->learn();
my @dataSeq = $testDataContainer->dataSeq();
my $testDataCount = @dataSeq;
my $matchCount = 0;
for my $data (@dataSeq){
my @input = $data->input();
my $ans = $data->output();
my $output = $perceptron->output(\@input);
if($output * $ans > 0){
$matchCount++;
}
}
print "$matchCount / $testDataCount\n";
$perceptron->dataSeq(undef);
print Dumper($perceptron);
{
package Data;
sub new{
my ($this, $input, $output) = @_;
my $obj = {
input => $input,
output => $output
};
return bless $obj, $this;
}
sub input{
my $this = shift;
$this->{input} = shift if @_;
return wantarray ? @{$this->{input}} : $this->{input};
}
sub output{
my $this = shift;
$this->{output} = shift if @_;
return $this->{output};
}
sub inputNorm{
my $this = shift;
my $square = 0;
for my $elem (@{$this->{input}}){
$square += ($elem * $elem);
}
return sqrt($square);
}
}
{
package DataContainer;
sub new{
my ($this, $filename) = @_;
my $dataSeq = [];
my $tokenLines = readFile($filename);
for my $lines (@$tokenLines){
my $output = shift @$lines;
my $data = new Data($lines, $output);
push(@$dataSeq, $data)
}
my $obj = {
dataSeq => $dataSeq
};
return bless $obj, $this;
}
sub dataSeq{
my ($this, $indx) = @_;
if(defined($indx)){
return ${$this->{dataSeq}}[$indx]
}
return wantarray ? @{$this->{dataSeq}} : $this->{dataSeq};
}
sub readFile{
my $filename = shift;
open(IN, $filename);
my @lines = <IN>;
close(IN);
my @tokenLines = ();
for my $line (@lines){
my @tokens = split(/,/, $line);
push(@tokenLines, [@tokens]);
}
return [@tokenLines];
}
}
{
package Perceptron;
sub new{
my ($this, $learningRate, $dataContainer) = @_;
my $dataSeq = $dataContainer->dataSeq();
my $rSuqare = $$dataSeq[0]->inputNorm();
for(my $i = 0; $i < @{$dataSeq}; $i++){
my $norm = $$dataSeq[$i]->inputNorm();
$rSuqare = $norm if($rSuqare < $norm);
}
my $obj = {
learningRate => $learningRate,
rSquare => $rSuqare,
dataSeq => $dataSeq,
weight => [0, 0, 0]
};
return bless $obj, $this;
}
sub dataSeq{
my $this = shift;
$this->{dataSeq} = shift if @_;
return $this->{dataSeq};
}
sub output{
my ($this, $input) = @_;
my $sum = ${$this->{weight}}[2];
for(my $i = 0; $i < @$input; $i++){
$sum += ${$input}[$i] * ${$this->{weight}}[$i];
}
return $sum;
}
sub learn{
my ($this) = @_;
for my $data (@{$this->{dataSeq}}){
my @input = $data->input();
my $ans = $data->output();
my $output = $this->output(\@input);
if($output * $ans > 0){
next;
}
for(my $i = 0; $i < @input; $i++){
${$this->{weight}}[$i] += $this->{learningRate} * $input[$i] * $ans;
}
my $lastIndx = $#{$this->{weight}};
${$this->{weight}}[$lastIndx] += $this->{learningRate} * $ans * $this->{rSquare};
}
}
}