본문 바로가기

Deep Learning/Paper Review

Self-evolving vision transformer for chest X-ray diagnosis through knowledge distillation (2022, Nature Communications)

320x100
320x100

 

 

 의료 영역에서 주로 사용되는 모델 학습 방식은 지도학습(supervised learning)이다. CXR을 Radiologist가 판독한 결과를 해당 CXR의 정답 데이터(Label)로 학습에 사용한다. 직관적이고 간단한 학습 방법이면서도 모델의 성능 또한 좋게 나와서 의료 영역 뿐만 아니라 여러 영역에서 사용되는 학습 방법이다. 

 

 문제는 지도학습 방식에 꼭 필요한 고품질의 Label을 확보하는 과정이 어렵고 비용도 많이 든다는 것이다. CXR 촬영을 루틴으로 진행하는 건강검진센터만 놓고 생각해봐도 하루에 발생되는 CXR의 양은 쏟아지는 수준이다. 이러한 Large-scale CXR Datset의 모든 판독문을 분석해서 Labelling을 진행하는 것이 가장 이상적이겠지만, 기업 입장에서는 비용효율적인 측면도 같이 고려해야 하다보니 현실적으로 가능한 일은 아니다. 

 

 그래서 최근 연구들을 보면 의료 인공지능 학습에 비지도학습(unsupervised learning)이나 준지도학습(semi-supervised learning) 방식을 적용하는 것을 볼 수 있다.

  • [비지도학습]
    - Self-supervised learning / Contrastive learning
    - Label 없이 학습을 진행해서 feature maps끼리 clustering 한다는 느낌으로 학습 진행
  • [준지도학습]
    - Self-training
    - Large-scale Dataset의 일부만 labelling 진행한 다음 모델 학습 진행

각 학습 방식에 대해선 다른 블로그에서 이미 충분히 잘 다뤘기 때문에 이 정도로만 정리하고, 이제 본 글의 목적이었던 Self-evolving vision transformer for chest X-ray diagnosis through knowledge distillation의 핵심 내용에 대해 알아보자.

 

 


 

 

일단 그 전에!! 정리한 내용에 오류가 있을 수 있음을 먼저 말씀드립니다.

 

  1. CheXpert Dataset으로 pre-trained weights을 생성한 다음 이후 학습에 사용
  2. Small labeled dataset으로 initial model 생성 → first teacher로 사용 / 성능이 잘 나와야 이후 결과도 좋을 듯
  3. Knowledge Distillation의 teacher-student 구조로 모델 학습 진행
    (a) CXR 영상 한 장을 original, global crops, local crops로 생성해서 학습에 사용
            → 이 과정이 실제 junior가 senior로부터 배우는 과정과 비슷하다고 논문에서는 주장함
    (b) Teacher의 inference result를 pseudo-label로써 정답으로 여기고 Student 학습 진행
            → Self-training 방식
                 (Small labeled dataset으로 학습된 초기 모델을 활용하여 unlabeled dataset을 학습하기 때문)
            → Loss 계산은 Image 단위로 하지 않을까 싶음
    (c) Teacher와 Student의 feature maps을 clusterting 느낌으로 학습 진행
            → Self-supervised 방식
            → Loss 계산은 Batch 단위로 하지 않을까 싶음 (Image 단위로는 clustering이 불가능하기 때문)

  4. 중간중간 small labeld dataset으로 student 학습하는 단계(논문에서는 이를 'correction step'이라고 표현)를 추가해서 하나의 System으로 구축

 

 개인적으로는 correction step도 꽤나 중요한 역할을 하지 않았을까 싶다. 이유를 적어보자면 모델의 성능에 영향을 주는 것 중 하나가 weight initialization 방식이다. 어떤 값으로 모델의 initial weights을 초기화하느냐에 따라 성능이 달라지기 때문에 Xavier, He 방식 등이 제안됐었다. 이러한 관점으로 논문을 바라봐보면, Self-training과 Self-supervision을 또 하나의 새로운 weight initialization method로 여길 수 있을 것 같다. 성능에 긍정적인 영향을 줄 수 있게 self-training and self supervision을 통해 weights를 고루고루 섞어주다가, correction step에서 student model을 small labeled dataset으로 transfer learning 해주면, 처음에 CheXpert pre-trained weights을 사용했을 때보다 모델 성능이 좋아지게 된 건 아닐까 하는 생각을 해봤다. 뭐.. 실제로 Supplementary Fig.S3.를 봐보면 correction step을 제외했을 때 성능 저하가 가장 적게 나타나긴 했지만 말이다...ㅋ

 

< 참고 사이트 >

https://www.nature.com/articles/s41467-022-31514-x

https://jayhey.github.io/semi-supervised%20learning/2017/12/07/semisupervised_self_training/