Tweet classification in SparkNLP with BERT

pnpy6elp·2022년 10월 24일
0

Distributed System

목록 보기
2/4

1. install SparkNLP

자신의 Spark version에 맞는 버전을 설치해야 한다!!!
나는 pyspark를 실행하면 자동으로 설정된 SparkSession이 실행되어서 spark-default.conf와 pyspark도 수정했다.

spark-default.conf

spark.jars.packages             com.johnsnowlabs.nlp:spark-nlp-spark24_2.11:3.4.4

pyspark (나같은 경우는 pyspark를 사용해서 해당 파일을 수정했다.)

export PYSPARK_SUBMIT_ARGS="--packages com.johnsnowlabs.nlp:spark-nlp-spark24_2.11:3.4.4 pyspark-shell"

2. BERT Sentence Embedding

Pipeline

  • BERT Sentence Embedding의 default model은 sent_bert_based_cased이다. SparkNLP를 사용하면 빠르게 Bert embedding model를 사용할 수 있다. 실제로 embedding 자체는 얼마 안걸렸지만 classifier를 load하는 시간 때문에 오래 걸렸다.
documentAssembler = DocumentAssembler().setInputCol("text").setOutputCol("document")
sentenceDetector = SentenceDetector().setInputCols(["document"]).setOutputCol("sentence")
sentenceEmbeddings = BertSentenceEmbeddings.pretrained("sent_bert_base_cased", "en").setInputCols("sentence").setOutputCol("sentenceEmbeddings")
embeddingsFinisher = EmbeddingsFinisher().setInputCols(["sentenceEmbeddings"]).setOutputCols("finishedEmbeddings").setOutputAsVector(True)

pipeline = Pipeline().setStages([documentAssembler,sentenceDetector,sentenceEmbeddings,embeddingsFinisher])

model = pipeline.fit(train)

Transforming

trainData = model.transform(train)
testData = model.transform(test)

3. Classification using DNN

SparkNLP는 Word Embedding model을 적용할 수 있는 Deep Neural Network 기반 classifier를 제공한다. ClassifierDLModel을 사용하면 다른 user의 pretrained model을 사용할 수 있다. 하지만 fine-tuning은 제공하고 있지 않다.

Train model

classifierDL = ClassifierDLApproach() \
.setInputCols(["sentenceEmbeddings"]) \
.setOutputCol("prediction") \
.setLabelColumn("label") \
.setBatchSize(128
.setMaxEpochs(20) \
.setLr(0.005)
.setDropout(0.5)\
.setEnableOutputLogs(True) # 나 같은 경우, hadoop의 user/root/annotator_logs에 저장됨

classifierDLModel = classifierDL.fit(trainData)

Prediction

prediction = classifierDLModel.transform(testData)
prediction.select("label","prediction.result").show(10, truncate=False)
profile
Distributed b2ng

0개의 댓글