Regresión Lineal en Arduino

La regresión lineal es útil cuando tenemos dos variables relacionadas linealmente o cuya relación se puede aproximar a la ecuación de un línea. La forma habitual de verlo es como una nube de puntos donde se calcula una recta que minimiza la distancia a todos los puntos. Se puede ver un ejemplo extraído de la wikipedia debajo:

Linear regression.svg
De SewaquTrabajo propio, Dominio público, Enlace

Nosotros vamos a usarlo para calcular de forma automática la relación entre dos variables, siempre y cuando esta sea lineal. Al final del proceso obtenemos la ecuación de una recta en la forma:

y = mx + b

Para ello se parte de una muestra de valores (x, y). Estos valores no deben de coincidir exactamente con la recta (serian los puntos azules de la imagen).

Las formulas para calcular lo parámetros:

m = σ²(x,y) / σ²(x)

b = µ(y) – m * µ(x)

Siendo σ²(x,y) la covarianza, σ²(x) la varianza y µ(x) la media

El mayor problema que hay en el caso de un arduino es que no hay memoria para guardar todas las muestras así que hemos de encontrar una forma de ir calculando estos valores estadísticos sin guardar todo el histórico de datos, de forma incremental. La solución la encontré aquí.  Así no hace falta guardar todos os valores, esto ademas permite que el algoritmo este siempre aprendiendo, aunque lo veremos más adelante.

También podemos calcular la correlación entre variables. Contra más cercano a a 1 o -1 mejor ya que indica una alta correlación entre ambas variables, sin embargo si es cercano a 0 indica una nula correlación.

r = σ²(x,y) / (σ(x) * σ(y))

Siendo r la correlación y σ(x) la desviación estándar, igual a √σ²(x).

Al final los cálculos en C para el arduino de m y b quedan en:


meanX = meanX + ((x-meanX)/n);
meanX2 = meanX2 + (((x*x)-meanX2)/n);
varX = meanX2 - (meanX*meanX);

meanY = meanY + ((y-meanY)/n);
meanY2 = meanY2 + (((y*y)-meanY2)/n);
varY = meanY2 - (meanY*meanY);

meanXY = meanXY + (((x*y)-meanXY)/n);

covarXY = meanXY - (meanX*meanY);

m = covarXY / varX;
b = meanY-(m*meanX);

Y para la correlación:

double stdX = sqrt(varX);
double stdY = sqrt(varY);
double stdXstdY = stdX*stdY;
double correlation;

if(stdXstdY == 0){
  correlation = 1;
} else {
  correlation = covarXY / stdXstdY;
}

 

Limites

En la vida real lo más probable es que no nos interese aplicar la regresión lineal a todos los posibles valores de la función solo a los que están contenidos entre  dos punto para ello se fijan los limites en los que la linea esta definida en el constructor.

LinearRegression lr = LinearRegression(0,100);

Fijar el valor de N

Una de las ventajas de este algoritmo es que es capaz de ir recalculando su valor de manera dinámica según se añaden más datos. El problema es que según crece n los nuevos datos cambian la media cada vez menos. Si por ejemplo tratamos con un valor que cambia con el tiempo y queremos que la recta se vaya adaptando debemos fijar el valor de n de tal forma que lo permita. Lo complicado esta en acertar con ese valor, si es muy bajo sera muy sensible al ruido y si es muy alto necesitara mucho tiempo para adaptarse a los nuevo valores.

En este caso se puede hacer con la función fixN

lr.fixN(100);

Regresión lineal segmentada

El principal problema de la regresión lineal es que es lineal. ¿Y si necesitamos que se adapte a una distribución que no sea una linea?. Se puede conseguir algo más de flexibilidad definiendo varias lineas de regresión. Podemos dividir una distribución de puntos en varios tramos y en cada tramo calcular su linea de regresión, así en lugar de una solo linea tenemos un conjunto de ellas, para ello podemos usar lo limites que se le pasa en el constructor


LinearRegression lr = LinearRegression(0,20);

LinearRegression lr = LinearRegression(0,40);

LinearRegression lr = LinearRegression(0,60);

LinearRegression lr = LinearRegression(0,80);

LinearRegression lr = LinearRegression(0,100);

Ejemplo


#include 

LinearRegression lr = LinearRegression(0,100);

void setup() {
  lr.learn(1,3);
  lr.learn(2,4);
  lr.learn(3,5);
  lr.learn(4,6);
  lr.learn(5,7);
  lr.learn(6,8);

  Serial.begin(9600);

}

void loop() {
  Serial.print("Result: ");
  Serial.println(lr.calculate(6));

  Serial.print("Correlation: ");
  Serial.println(lr.correlation());

  Serial.print("Values: ");

  lr.getValues(values);
  Serial.print("Y = ");
  Serial.print(values[0]);
  Serial.print("*X + ");
  Serial.println(values[1]);

  delay(2000);
}

La librería esta disponible en github

Responder

Introduce tus datos o haz clic en un icono para iniciar sesión:

Logo de WordPress.com

Estás comentando usando tu cuenta de WordPress.com. Cerrar sesión /  Cambiar )

Google photo

Estás comentando usando tu cuenta de Google. Cerrar sesión /  Cambiar )

Imagen de Twitter

Estás comentando usando tu cuenta de Twitter. Cerrar sesión /  Cambiar )

Foto de Facebook

Estás comentando usando tu cuenta de Facebook. Cerrar sesión /  Cambiar )

Conectando a %s

This site uses Akismet to reduce spam. Learn how your comment data is processed.