最小二乗法

Pythonの勉強がてら、「はじめての機械学習」のcで書かれたコードをPythonに力技で翻訳していく学習記録。

まずは最小2乗法。

式の解説はここがわかりやすい:http://szksrv.isc.chubu.ac.jp/lms/lms1.html

プログラムは標準入力の2値のセットから係数a0, a1を求めるもの。

#!/usr/bin/python
import sys

__TEXTLENGTH__ = 4096
#print __TEXTLENGTH__

text = []
xi = 0
yi = 0
sxi = 0
syi = 0
sxiyi = 0
sxi2 = 0
a0 = 0
a1 = 0
n = 0

line = sys.stdin.readline()
while line:
	#print line.split()
	nums = line.split()
	if len(nums) == 2:
		#print nums[0] + nums[1]
		xi = float(nums[0])
		yi = float(nums[1])
		sxi+= xi
		syi+= yi
		sxiyi+= xi*yi
		sxi2+= xi*xi
		n+=1
		#print str(sxi) + ':' + str(syi) + ':' + str(n)
	else:
		print 'invalid data: ' + line
	line = sys.stdin.readline()
	if len(line.split()) < 1:
		break

print 'end loop'

if n > 1:
	#print n 
	a0 = (sxi2*syi-sxiyi*sxi) / (n*sxi2-sxi*sxi)
	a1 = (n*sxiyi-sxi*syi) / (n* sxi2-sxi*sxi)
	print 'y='+str(a0)+'+'+str(a1)+'x'
else:
	print 'data is not sufficient'