import numpy as np weight =[0,0,0] rate = 0.1 def sign(value): """符号函数""" if value>0: return 1 elif value<0: return -1 else: return 0 def crDataSet(datafile): """创建数据集""" result = [] with open(datafile,'r') as file: data = file.read().splitlines() for line in data: temp = [] items = line.split(',') temp.append(float(items[0])) temp.append(float(items[2])) if items[4] =='Iris-setosa': temp.append(1) else: temp.append(-1) result.append(temp) return result def dotProduct(item): """计算点积""" sum = weight[0]*1 for i in range(2): sum += item[i]*weight[i+1] return sum def upWeight(item,estimate): """更新权重""" weight[0] += rate*(item[2]-estimate) for i in range(2): weight[i+1] += rate*(item[2]-estimate)*item[i] def train(): error =1 #是否错误 while error: error = 0 for item in train_data: outvalue = sign(dotProduct(item)) while outvalue !=item[2]: error = 1 upWeight(item,outvalue) outvalue = sign(dotProduct(item)) def test(): error = 0 for item in test_data: #用测试集集数据进行测试 if sign(dotProduct(item)) != item[2]: error +=1 print("错误数:",error) print(weight) #查看权重 def disp(): import matplotlib.pyplot as plt plt.xlabel('花萼长度',fontproperties = 'SimHei',fontsize = 12) plt.ylabel('花瓣长度',fontproperties = 'SimHei',fontsize = 12) positive_x1 = [train_data[i][0] for i in range(60) if train_data[i][2] == 1] positive_x2 = [train_data[i][1] for i in range(60) if train_data[i][2] == 1] negetive_x1 = [train_data[i][0] for i in range(60) if train_data[i][2] == -1] negetive_x2 = [train_data[i][1] for i in range(60) if train_data[i][2] == -1] plt.scatter(positive_x1,positive_x2,s=5,c='red') plt.scatter(negetive_x1,negetive_x2,s=5,c='blue') line_x = np.arange(0,10) line_y = (-weight[1]*line_x -weight[0]) / weight[2] plt.plot(line_x,line_y) plt.show() if __name__ =="__main__": train_data = crDataSet('train.data') test_data = crDataSet('test.data') train() test() disp() print(sign(dotProduct([2.2,3.3])))