`
dawuafang
  • 浏览: 1105045 次
文章分类
社区版块
存档分类
最新评论

java实现一元线性回归算法

 
阅读更多

在网上看到一位高手用Java写的一元线性回归实现,觉得挺有用。一些企业做数据挖掘时,不是需要预测营业收入吗?采用一元线性回归算法,就可以实现类似的预测功能。直接上代码吧:

1、定义一个DataPoint类,对X和Y坐标点进行封装:

/**
 * File        : DataPoint.java
 * Author      : zhouyujie
 * Date        : 2012-01-11 16:00:00
 * Description : Java实现一元线性回归的算法,坐标点实体类,(可实现统计指标的预测)
 */
package com.zyujie.dm;

public class DataPoint {

	/** Horizontal (x) and vertical (y) coordinates of this point. */
	public float x, y;

	/**
	 * Creates a point located at ({@code x}, {@code y}).
	 *
	 * @param x
	 *            horizontal coordinate
	 * @param y
	 *            vertical coordinate
	 */
	public DataPoint(float x, float y) {
		this.y = y;
		this.x = x;
	}
}
2、下面是回归线算法的实现类:

/**
 * File        : RegressionLine.java
 * Author      : zhouyujie
 * Date        : 2012-01-11 16:00:00
 * Description : Java实现一元线性回归的算法,回归线实现类,(可实现统计指标的预测)
 */
package com.zyujie.dm;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;

/**
 * Least-squares fit of a straight line {@code y = a1 * x + a0} to a set of
 * {@link DataPoint}s, plus the coefficient of determination (R^2) of the fit.
 * <p>
 * Points are added incrementally: the running sums are updated on every
 * insertion and the coefficients are recomputed lazily the next time they are
 * requested. Every point's coordinates are also kept so that {@link #getR()}
 * can evaluate the residuals over all points.
 * <p>
 * Instances are not synchronized.
 */
public class RegressionLine {

	/** sum of x */
	private double sumX;

	/** sum of y */
	private double sumY;

	/** sum of x*x */
	private double sumXX;

	/** sum of x*y */
	private double sumXY;

	/** sum of y*y */
	private double sumYY;

	/** x values of every added point, in insertion order */
	private final List<Float> listX = new ArrayList<Float>();

	/** y values of every added point, in insertion order */
	private final List<Float> listY = new ArrayList<Float>();

	/** Extremes of the added x and y values; all 0 while no point was added. */
	private float xMin, xMax, yMin, yMax;

	/** Slope (a1) in double precision; valid only while coefsValid is true. */
	private double slope;

	/** Intercept (a0) in double precision; valid only while coefsValid is true. */
	private double intercept;

	/** number of data points */
	private int pn;

	/** true if slope/intercept reflect the current sums */
	private boolean coefsValid;

	/**
	 * Creates an empty regression line with no data points.
	 */
	public RegressionLine() {
		// All accumulators start at zero; nothing else to initialise.
	}

	/**
	 * Creates a regression line from an array of data points.
	 *
	 * @param data
	 *            the array of data points (must not be null)
	 */
	public RegressionLine(DataPoint data[]) {
		for (DataPoint point : data) {
			addDataPoint(point);
		}
	}

	/**
	 * Return the current number of data points.
	 *
	 * @return the count
	 */
	public int getDataPointCount() {
		return pn;
	}

	/**
	 * Return the coefficient a0 (the intercept).
	 *
	 * @return the value of a0, or NaN if fewer than two points were added or
	 *         all x values are equal
	 */
	public float getA0() {
		validateCoefficients();
		return (float) intercept;
	}

	/**
	 * Return the coefficient a1 (the slope).
	 *
	 * @return the value of a1, or NaN if fewer than two points were added or
	 *         all x values are equal
	 */
	public float getA1() {
		validateCoefficients();
		return (float) slope;
	}

	/**
	 * Return the sum of the x values.
	 *
	 * @return the sum
	 */
	public double getSumX() {
		return sumX;
	}

	/**
	 * Return the sum of the y values.
	 *
	 * @return the sum
	 */
	public double getSumY() {
		return sumY;
	}

	/**
	 * Return the sum of the x*x values.
	 *
	 * @return the sum
	 */
	public double getSumXX() {
		return sumXX;
	}

	/**
	 * Return the sum of the x*y values.
	 *
	 * @return the sum
	 */
	public double getSumXY() {
		return sumXY;
	}

	/**
	 * Return the sum of the y*y values.
	 *
	 * @return the sum
	 */
	public double getSumYY() {
		return sumYY;
	}

	/**
	 * @return the smallest added x, truncated to int (0 if no points)
	 */
	public int getXMin() {
		return (int) xMin;
	}

	/**
	 * @return the largest added x, truncated to int (0 if no points)
	 */
	public int getXMax() {
		return (int) xMax;
	}

	/**
	 * @return the smallest added y, truncated to int (0 if no points)
	 */
	public int getYMin() {
		return (int) yMin;
	}

	/**
	 * @return the largest added y, truncated to int (0 if no points)
	 */
	public int getYMax() {
		return (int) yMax;
	}

	/**
	 * Add a new data point: update the sums, the min/max values and the stored
	 * points, and invalidate the cached coefficients.
	 *
	 * @param dataPoint
	 *            the new data point (must not be null)
	 */
	public void addDataPoint(DataPoint dataPoint) {
		// Widen before multiplying so products are not rounded to float.
		final double x = dataPoint.x;
		final double y = dataPoint.y;

		sumX += x;
		sumY += y;
		sumXX += x * x;
		sumXY += x * y;
		sumYY += y * y;

		if (pn == 0) {
			xMin = xMax = dataPoint.x;
			yMin = yMax = dataPoint.y;
		} else {
			xMin = Math.min(xMin, dataPoint.x);
			xMax = Math.max(xMax, dataPoint.x);
			yMin = Math.min(yMin, dataPoint.y);
			yMax = Math.max(yMax, dataPoint.y);
		}

		// Echo "x,y" (truncated to int) for points with two non-zero
		// coordinates; kept so the program output stays the same.
		if (dataPoint.x != 0 && dataPoint.y != 0) {
			System.out.println((int) dataPoint.x + "," + (int) dataPoint.y);
		}

		// Store every point, zero coordinates included, so that getR() sees
		// exactly the points the line was fitted to.
		listX.add(dataPoint.x);
		listY.add(dataPoint.y);

		++pn;
		coefsValid = false;
	}

	/**
	 * Return the value of the regression line function at x.
	 *
	 * @param x
	 *            the value of x
	 * @return the value of the function at x, or NaN if fewer than two points
	 *         were added
	 */
	public float at(int x) {
		return (float) evaluate(x);
	}

	/**
	 * Return the value of the regression line function at a non-integer x.
	 *
	 * @param x
	 *            the value of x
	 * @return the value of the function at x, or NaN if fewer than two points
	 *         were added
	 */
	public float at(float x) {
		return (float) evaluate(x);
	}

	/** Evaluates a0 + a1 * x in double precision (NaN when no line exists). */
	private double evaluate(double x) {
		validateCoefficients();
		return intercept + slope * x;
	}

	/**
	 * Reset: discard every data point and all accumulated statistics.
	 */
	public void reset() {
		pn = 0;
		sumX = sumY = sumXX = sumXY = sumYY = 0;
		xMin = xMax = yMin = yMax = 0;
		listX.clear();
		listY.clear();
		coefsValid = false;
	}

	/**
	 * Validate the coefficients: compute slope a1 and intercept a0 of
	 * y = a1 * x + a0 from the running sums, if they are stale.
	 */
	private void validateCoefficients() {
		if (coefsValid) {
			return;
		}

		final double denominator = pn * sumXX - sumX * sumX;
		if (pn >= 2 && denominator != 0) {
			slope = (pn * sumXY - sumX * sumY) / denominator;
			intercept = (sumY - slope * sumX) / pn;
		} else {
			// Fewer than two points, or all x values equal: no unique line.
			slope = intercept = Double.NaN;
		}

		coefsValid = true;
	}

	/**
	 * Return the goodness of fit R^2 = 1 - SSE / SST, rounded half-up to four
	 * decimal places.
	 * <p>
	 * SSE = sum((yi - f(xi))^2) over all points, SST = sum((yi - mean(y))^2).
	 * SST is computed around the mean rather than as sumYY - sumY^2 / n to
	 * avoid cancellation. The result does not change between calls.
	 *
	 * @return R^2, or NaN if fewer than two points were added, all x values are
	 *         equal, or all y values are equal (R^2 undefined)
	 */
	public double getR() {
		if (pn < 2) {
			return Double.NaN;
		}
		validateCoefficients();

		final double meanY = sumY / pn;
		double sse = 0;
		double sst = 0;
		for (int i = 0; i < pn; i++) {
			final double yi = listY.get(i);
			final double xi = listX.get(i);
			final double residual = yi - (intercept + slope * xi);
			final double deviation = yi - meanY;
			sse += residual * residual;
			sst += deviation * deviation;
		}

		if (sst == 0) {
			return Double.NaN;
		}
		return round(1 - sse / sst, 4);
	}

	/**
	 * Round a double half-up to the given number of decimal places, based on
	 * its shortest decimal representation ({@link Double#toString(double)}).
	 *
	 * @param v
	 *            the value to round; NaN and infinities are returned unchanged
	 * @param scale
	 *            number of decimal places, must be &gt;= 0
	 * @return the rounded value
	 * @throws IllegalArgumentException
	 *             if scale is negative
	 */
	public double round(double v, int scale) {
		if (scale < 0) {
			throw new IllegalArgumentException(
					"The scale must be a positive integer or zero");
		}
		if (Double.isNaN(v) || Double.isInfinite(v)) {
			return v; // BigDecimal cannot represent these
		}
		return new BigDecimal(Double.toString(v))
				.setScale(scale, RoundingMode.HALF_UP).doubleValue();
	}

	/**
	 * Round a float half-up to the given number of decimal places, based on
	 * its shortest decimal representation ({@link Float#toString(float)}).
	 *
	 * @param v
	 *            the value to round; NaN and infinities are returned unchanged
	 * @param scale
	 *            number of decimal places, must be &gt;= 0
	 * @return the rounded value
	 * @throws IllegalArgumentException
	 *             if scale is negative
	 */
	public float round(float v, int scale) {
		if (scale < 0) {
			throw new IllegalArgumentException(
					"The scale must be a positive integer or zero");
		}
		if (Float.isNaN(v) || Float.isInfinite(v)) {
			return v; // BigDecimal cannot represent these
		}
		// Float.toString (not Double.toString of the widened value) so that
		// e.g. 2.675f rounds like the decimal 2.675 it prints as.
		return new BigDecimal(Float.toString(v))
				.setScale(scale, RoundingMode.HALF_UP).floatValue();
	}
}
3、线性回归测试类:

/**
 * File        : LinearRegression.java
 * Author      : zhouyujie
 * Date        : 2012-01-11 16:00:00
 * Description : Java实现一元线性回归的算法,线性回归测试类,(可实现统计指标的预测)
 */
package com.zyujie.dm;

/**
 * <p>
 * <b>Linear Regression</b> <br>
 * Demonstrates linear regression by constructing the regression line for a
 * set of data points.
 *
 * <p>
 * Requires DataPoint.java and RegressionLine.java.
 *
 * <p>
 * To compute the least-squares regression line for the given data points, the
 * sums SumX, SumY, SumXX and SumXY are needed (note: SumXX = Sum(X^2)).
 * <p>
 * <b>Regression line: f(x) = a1*x + a0</b>
 * <p>
 * <b>Slope and intercept:</b> <br>
 * n: number of data points
 * <p>
 * a1 = (n*SumXY - SumX*SumY) / (n*SumXX - (SumX)^2) <br>
 * a0 = (SumY - SumX*a1) / n <br>
 * (equivalently a0 = averageY - a1*averageX)
 *
 * <p>
 * <b>Drawing the line: two points determine a straight line.</b><br>
 * First point: (0, a0). Substitute any other x1 into the equation to get y1,
 * then join (0, a0) and (x1, y1). To make the line span the whole chart, take
 * x1 = Xmax (the largest value on the horizontal axis), giving the points
 * (0, a0) and (Xmax, Y). If Y = a1*Xmax + a0 is larger than the largest value
 * on the vertical axis, Ymax, do not use that point; instead set y = Ymax,
 * solve for X, and join (0, a0) and (X, Ymax).
 *
 * <p>
 * <b>Goodness of fit (the R^2 shown by Excel):</b>
 * <p>
 * R^2 = 1 - E
 * <p>
 * where the error ratio E = SSE / SST,
 * <p>
 * SSE = sum((Yi - f(Xi))^2) and SST = sum((Yi - averageY)^2) = SumYY - (SumY*SumY)/n
 * <p>
 */
public class LinearRegression {

	/**
	 * Main program: fits a regression line to five sample points and prints
	 * the sums, the line equation and its R^2 to standard output.
	 *
	 * @param args
	 *            the array of runtime arguments (unused)
	 */
	public static void main(String args[]) {
		RegressionLine line = new RegressionLine();

		line.addDataPoint(new DataPoint(1, 136));
		line.addDataPoint(new DataPoint(2, 143));
		line.addDataPoint(new DataPoint(3, 132));
		line.addDataPoint(new DataPoint(4, 142));
		line.addDataPoint(new DataPoint(5, 147));

		printSums(line);
		printLine(line);
	}

	/**
	 * Print the point count and the computed sums.
	 *
	 * @param line
	 *            the regression line
	 */
	private static void printSums(RegressionLine line) {
		System.out.println("\n数据点个数 n = " + line.getDataPointCount());
		System.out.println("\nSum x  = " + line.getSumX());
		System.out.println("Sum y  = " + line.getSumY());
		System.out.println("Sum xx = " + line.getSumXX());
		System.out.println("Sum xy = " + line.getSumXY());
		System.out.println("Sum yy = " + line.getSumYY());

	}

	/**
	 * Print the regression line function and its R^2.
	 * <p>
	 * The output label "误差" ("error") is kept unchanged, but the value printed
	 * is R^2, a goodness-of-fit measure, not an error.
	 *
	 * @param line
	 *            the regression line
	 */
	private static void printLine(RegressionLine line) {
		System.out.println("\n回归线公式:  y = " + line.getA1() + "x + "
				+ line.getA0());
		System.out.println("误差:     R^2 = " + line.getR());
	}

	// Example predictions with y = 2.1x + 133.7:
	// x = 6: 2.1 * 6 + 133.7 = 12.6 + 133.7 = 146.3
	// x = 7: 2.1 * 7 + 133.7 = 14.7 + 133.7 = 148.4

}

我们运行测试类,得到运行结果:

1,136
2,143
3,132
4,142
5,147

数据点个数 n = 5

Sum x = 15.0
Sum y = 700.0
Sum xx = 55.0
Sum xy = 2121.0
Sum yy = 98142.0

回归线公式: y = 2.1x + 133.7
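误差: R^2 = 0.3658

(注:0.3658 这个结果漏算了最后一个数据点,因为 getR() 中的循环条件写成了 i < pn - 1。按全部 5 个点计算,SSE = 97.9,SST = 98142 - 700*700/5 = 142,正确的拟合度应为 R^2 = 1 - 97.9/142 ≈ 0.3106。另外,R^2 表示拟合度(决定系数),并不是"误差"。)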

假如某公司:

1月收入,136万元
2月收入,143万元
3月收入,132万元
4月收入,142万元
5月收入,147万元

我们可以根据回归线公式:y = 2.1x + 133.7,预测出6月份收入:

y = 2.1 * 6 + 133.7 = 12.6 + 133.7 = 146.3(万元)

需要注意的是,这组数据的拟合度 R^2 较低(只有约 0.3),说明这条直线只能解释收入波动中的一小部分,所以该预测值只能作为粗略参考。

分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics