This post builds a model on housing price data for the Boston area collected by the US Census Service in the 1970s.
Training Set
-> The target is not the price of an individual house but the median house price of each area.
Use a Regression algorithm
-> House price is a continuous value, so a Classification algorithm is not a good fit.
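| No. | Field | Description |
|---|---|---|
| 1 | CRIM | Per-capita crime rate of the area where the housing is located |
| 2 | ZN | Proportion of residential land zoned for lots over 25,000 sq ft (about 700 pyeong) |
| 3 | INDUS | Proportion of non-retail business (industrial) acres in the area |
| 4 | CHAS | Charles River dummy variable (1 if the area borders the river, 0 otherwise) |
| 5 | NOX | Nitric oxide concentration, used as a measure of pollution |
| 6 | RM | Average number of rooms per dwelling |
| 7 | AGE | Proportion of homes built before 1940 |
| 8 | DIS | Weighted distance to the Boston employment centers |
| 9 | RAD | Index of accessibility to highways |
| 10 | TAX | Property tax rate (per $10,000 of home value) |
| 11 | PTRATIO | Pupil-teacher ratio in schools |
| 12 | B | 1000(Bk - 0.63)^2, where Bk is the proportion of Black residents |
| 13 | LSTAT | Percentage of the population with lower income status |
| 14 | MEDV | Median home value, in units of $1,000 |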
!pip install pyspark==3.3.1 py4j==0.10.9.5
from pyspark.sql import SparkSession
spark = SparkSession \
.builder \
.appName("Boston Housing Linear Regression") \
.getOrCreate()

data = spark.read.csv('./boston_housing.csv', header=True, inferSchema=True)
data.printSchema()
root
|-- crim: double (nullable = true)
|-- zn: double (nullable = true)
|-- indus: double (nullable = true)
|-- chas: integer (nullable = true)
|-- nox: double (nullable = true)
|-- rm: double (nullable = true)
|-- age: double (nullable = true)
|-- dis: double (nullable = true)
|-- rad: integer (nullable = true)
|-- tax: integer (nullable = true)
|-- ptratio: double (nullable = true)
|-- b: double (nullable = true)
|-- lstat: double (nullable = true)
|-- medv: double (nullable = true)
data.show()
+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+
| crim| zn|indus|chas| nox| rm| age| dis|rad|tax|ptratio| b|lstat|medv|
+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+
|0.00632|18.0| 2.31| 0|0.538|6.575| 65.2| 4.09| 1|296| 15.3| 396.9| 4.98|24.0|
|0.02731| 0.0| 7.07| 0|0.469|6.421| 78.9|4.9671| 2|242| 17.8| 396.9| 9.14|21.6|
|0.02729| 0.0| 7.07| 0|0.469|7.185| 61.1|4.9671| 2|242| 17.8|392.83| 4.03|34.7|
|0.03237| 0.0| 2.18| 0|0.458|6.998| 45.8|6.0622| 3|222| 18.7|394.63| 2.94|33.4|
|0.06905| 0.0| 2.18| 0|0.458|7.147| 54.2|6.0622| 3|222| 18.7| 396.9| 5.33|36.2|
|0.02985| 0.0| 2.18| 0|0.458| 6.43| 58.7|6.0622| 3|222| 18.7|394.12| 5.21|28.7|
|0.08829|12.5| 7.87| 0|0.524|6.012| 66.6|5.5605| 5|311| 15.2| 395.6|12.43|22.9|
|0.14455|12.5| 7.87| 0|0.524|6.172| 96.1|5.9505| 5|311| 15.2| 396.9|19.15|27.1|
|0.21124|12.5| 7.87| 0|0.524|5.631|100.0|6.0821| 5|311| 15.2|386.63|29.93|16.5|
|0.17004|12.5| 7.87| 0|0.524|6.004| 85.9|6.5921| 5|311| 15.2|386.71| 17.1|18.9|
|0.22489|12.5| 7.87| 0|0.524|6.377| 94.3|6.3467| 5|311| 15.2|392.52|20.45|15.0|
|0.11747|12.5| 7.87| 0|0.524|6.009| 82.9|6.2267| 5|311| 15.2| 396.9|13.27|18.9|
|0.09378|12.5| 7.87| 0|0.524|5.889| 39.0|5.4509| 5|311| 15.2| 390.5|15.71|21.7|
|0.62976| 0.0| 8.14| 0|0.538|5.949| 61.8|4.7075| 4|307| 21.0| 396.9| 8.26|20.4|
|0.63796| 0.0| 8.14| 0|0.538|6.096| 84.5|4.4619| 4|307| 21.0|380.02|10.26|18.2|
|0.62739| 0.0| 8.14| 0|0.538|5.834| 56.5|4.4986| 4|307| 21.0|395.62| 8.47|19.9|
|1.05393| 0.0| 8.14| 0|0.538|5.935| 29.3|4.4986| 4|307| 21.0|386.85| 6.58|23.1|
| 0.7842| 0.0| 8.14| 0|0.538| 5.99| 81.7|4.2579| 4|307| 21.0|386.75|14.67|17.5|
|0.80271| 0.0| 8.14| 0|0.538|5.456| 36.6|3.7965| 4|307| 21.0|288.99|11.69|20.2|
| 0.7258| 0.0| 8.14| 0|0.538|5.727| 69.5|3.7965| 4|307| 21.0|390.95|11.28|18.2|
+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+
only showing top 20 rows
We use VectorAssembler for the next step.
-> It combines multiple columns into a single vector column.
inputCols takes the list of names of the columns to combine.
outputCol takes the name of the new column that will hold the combined vector.
from pyspark.ml.feature import VectorAssembler
feature_columns = data.columns[:-1] # every column of data except the last one (medv, the label)
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
vectorized_data = assembler.transform(data)
vectorized_data.show()
+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+--------------------+
| crim| zn|indus|chas| nox| rm| age| dis|rad|tax|ptratio| b|lstat|medv| features|
+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+--------------------+
|0.00632|18.0| 2.31| 0|0.538|6.575| 65.2| 4.09| 1|296| 15.3| 396.9| 4.98|24.0|[0.00632,18.0,2.3...|
|0.02731| 0.0| 7.07| 0|0.469|6.421| 78.9|4.9671| 2|242| 17.8| 396.9| 9.14|21.6|[0.02731,0.0,7.07...|
|0.02729| 0.0| 7.07| 0|0.469|7.185| 61.1|4.9671| 2|242| 17.8|392.83| 4.03|34.7|[0.02729,0.0,7.07...|
|0.03237| 0.0| 2.18| 0|0.458|6.998| 45.8|6.0622| 3|222| 18.7|394.63| 2.94|33.4|[0.03237,0.0,2.18...|
|0.06905| 0.0| 2.18| 0|0.458|7.147| 54.2|6.0622| 3|222| 18.7| 396.9| 5.33|36.2|[0.06905,0.0,2.18...|
|0.02985| 0.0| 2.18| 0|0.458| 6.43| 58.7|6.0622| 3|222| 18.7|394.12| 5.21|28.7|[0.02985,0.0,2.18...|
|0.08829|12.5| 7.87| 0|0.524|6.012| 66.6|5.5605| 5|311| 15.2| 395.6|12.43|22.9|[0.08829,12.5,7.8...|
|0.14455|12.5| 7.87| 0|0.524|6.172| 96.1|5.9505| 5|311| 15.2| 396.9|19.15|27.1|[0.14455,12.5,7.8...|
|0.21124|12.5| 7.87| 0|0.524|5.631|100.0|6.0821| 5|311| 15.2|386.63|29.93|16.5|[0.21124,12.5,7.8...|
|0.17004|12.5| 7.87| 0|0.524|6.004| 85.9|6.5921| 5|311| 15.2|386.71| 17.1|18.9|[0.17004,12.5,7.8...|
|0.22489|12.5| 7.87| 0|0.524|6.377| 94.3|6.3467| 5|311| 15.2|392.52|20.45|15.0|[0.22489,12.5,7.8...|
|0.11747|12.5| 7.87| 0|0.524|6.009| 82.9|6.2267| 5|311| 15.2| 396.9|13.27|18.9|[0.11747,12.5,7.8...|
|0.09378|12.5| 7.87| 0|0.524|5.889| 39.0|5.4509| 5|311| 15.2| 390.5|15.71|21.7|[0.09378,12.5,7.8...|
|0.62976| 0.0| 8.14| 0|0.538|5.949| 61.8|4.7075| 4|307| 21.0| 396.9| 8.26|20.4|[0.62976,0.0,8.14...|
|0.63796| 0.0| 8.14| 0|0.538|6.096| 84.5|4.4619| 4|307| 21.0|380.02|10.26|18.2|[0.63796,0.0,8.14...|
|0.62739| 0.0| 8.14| 0|0.538|5.834| 56.5|4.4986| 4|307| 21.0|395.62| 8.47|19.9|[0.62739,0.0,8.14...|
|1.05393| 0.0| 8.14| 0|0.538|5.935| 29.3|4.4986| 4|307| 21.0|386.85| 6.58|23.1|[1.05393,0.0,8.14...|
| 0.7842| 0.0| 8.14| 0|0.538| 5.99| 81.7|4.2579| 4|307| 21.0|386.75|14.67|17.5|[0.7842,0.0,8.14,...|
|0.80271| 0.0| 8.14| 0|0.538|5.456| 36.6|3.7965| 4|307| 21.0|288.99|11.69|20.2|[0.80271,0.0,8.14...|
| 0.7258| 0.0| 8.14| 0|0.538|5.727| 69.5|3.7965| 4|307| 21.0|390.95|11.28|18.2|[0.7258,0.0,8.14,...|
+-------+----+-----+----+-----+-----+-----+------+---+---+-------+------+-----+----+--------------------+
only showing top 20 rows
train, test = vectorized_data.randomSplit([0.8, 0.2])
from pyspark.ml.regression import LinearRegression
model = LinearRegression(featuresCol="features", labelCol="medv").fit(train)
evaluation_summary = model.evaluate(test)
evaluation_summary
<pyspark.ml.regression.LinearRegressionSummary at 0x7cc1242a2590>
evaluation_summary.meanAbsoluteError
3.751751468553107
evaluation_summary.rootMeanSquaredError
5.266467730878813
evaluation_summary.r2
0.7002142469300772
predictions = model.transform(test)
predictions.select(predictions.columns[13:]).show()
+----+--------------------+------------------+
|medv| features| prediction|
+----+--------------------+------------------+
|32.7|[0.01301,35.0,1.5...|30.515144998498805|
|35.4|[0.01311,90.0,1.2...|30.631754822999135|
|24.5|[0.01501,80.0,2.0...|27.128954091991645|
|44.0|[0.01538,90.0,3.7...| 36.96650331489955|
|30.1|[0.01709,90.0,2.0...| 24.6221945758976|
|23.1|[0.0187,85.0,4.15...|24.785644795798774|
|42.3|[0.02177,82.5,2.0...|37.095922515731644|
|16.5|[0.02498,0.0,1.89...| 22.67389351259429|
|34.9|[0.0315,95.0,1.47...|29.789848917775686|
|33.4|[0.03237,0.0,2.18...|28.764672253072774|
|19.4|[0.03466,35.0,6.0...|23.414721268194157|
|45.4|[0.03578,20.0,3.3...| 38.85772413090571|
|23.5|[0.03584,80.0,3.3...|30.000658140048248|
|20.7|[0.03738,0.0,5.19...| 21.66209468225044|
|18.2|[0.04301,80.0,1.9...|13.568129819698605|
|20.5|[0.04337,21.0,5.6...|23.829695990158015|
|19.4|[0.04379,80.0,3.3...|25.147443347186538|
|20.6|[0.04527,0.0,11.9...|22.249892769719427|
|30.3|[0.04666,80.0,1.5...| 32.62958993192931|
|22.6|[0.04684,0.0,3.41...|27.323120358774855|
+----+--------------------+------------------+
only showing top 20 rows
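To make the three evaluation numbers less opaque, here is a minimal plain-Python sketch of the standard definitions of MAE, RMSE, and R². The helper name `regression_metrics` is hypothetical and not part of PySpark, and it is an assumption that these textbook formulas match what the summary object reports. The sample pairs are the first five `medv`/`prediction` rows from the output above, so the printed values will differ from the metrics computed over the full test set.

```python
import math


def regression_metrics(labels, predictions):
    """Return (MAE, RMSE, R^2) for paired lists of labels and predictions."""
    n = len(labels)
    errors = [y - p for y, p in zip(labels, predictions)]

    # Mean absolute error: average size of the miss, in label units ($1,000).
    mae = sum(abs(e) for e in errors) / n

    # Root mean squared error: penalizes large misses more heavily than MAE.
    rmse = math.sqrt(sum(e * e for e in errors) / n)

    # R^2: share of label variance explained, 1 - SS_res / SS_tot.
    mean_label = sum(labels) / n
    ss_tot = sum((y - mean_label) ** 2 for y in labels)
    ss_res = sum(e * e for e in errors)
    r2 = 1.0 - ss_res / ss_tot

    return mae, rmse, r2


# First five (medv, prediction) pairs copied from the output above.
labels = [32.7, 35.4, 24.5, 44.0, 30.1]
predictions = [
    30.515144998498805,
    30.631754822999135,
    27.128954091991645,
    36.96650331489955,
    24.6221945758976,
]

mae, rmse, r2 = regression_metrics(labels, predictions)
print(f"MAE={mae:.3f} RMSE={rmse:.3f} R2={r2:.3f}")
```

Read this way, the reported MAE of about 3.75 means the predicted median value misses by roughly $3,750 on average, since `medv` is expressed in units of $1,000.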
That concludes this demo of using a Linear Regression model.
It is only a basic guideline, so it would run into problems in real-world use.