Author: Karan Singhal, Shekoofeh Azizi, Tao Tu, S. Sara Mahdavi, Jason Wei, Hyung Won Chung, Nathan Scales, Ajay Tanwani, Heather Cole-Lewis, Stephen Pfohl, Perry Payne, Martin Seneviratne, Paul Gamble, Chris Kelly, Nathaneal Scharli, Aakanksha Chowdhery, Philip Mansfield, Blaise Aguera y Arcas, Dale Webster, Greg S. Corrado, Yossi Matias, Katherine Chou, Juraj Gottweis, Nenad Tomasev, Yun Liu, Alvin Rajkomar, Joelle Barral, Christopher Semturs, Alan Karthikesalingam, Vivek Natarajan
Arxiv Paper Link: https://arxiv.org/abs/2212.13138

Nature Paper Link: https://www.nature.com/articles/s41586-023-06291-2

Google Blog: https://sites.research.google/med-palm/

Google Cloud Blog: https://cloud.google.com/blog/topics/healthcare-life-sciences/sharing-google-med-palm-2-medical-large-language-model?hl=en 

 

## Key contributions
- 평가 데이터셋 구축 (Methods. 'Datasets')
- 평가 framework 정립 (Methods. 'Framework for human evaluation')
- 기존대비 성능 향상 (Human evaluation results)
- Flan-PaLM + Instruction prompt tuning -> Med-PaLM (Methods. 'Modeling')
- 한계 (Limitations)
- Fig1.

## Methods
### Datasets
- MultiMedQA benchmark
  - 1.MedQA(USMLE), 2.MedMCQA, 3.PubMedQA, 4.MedicationQA, 5.LiveQA TREC 2017, 6.MMLU clinical topics datasets, 7.HealthSearchQA

### Modeling
- 기본 LLM과 이 LLM을 의료 도메인에 특화시키기위한 방법들 소개
- Models
  - PaLM과 Flan-PaLM의 변형family들을 사용
- PaLM
  - 구글의 LLM 베이스라인
  - 540B (GPT3 175B) , Decoder-only transformer모델
  - 웹문서, 책, 위키피디아, 대화문, 깃헙코드의 7천8백억개의 토큰으로 1 epoch 학습 (cf. GPT3 3천억개, PaLM2 3조6천억개)
  - 256K Vocabs (GPT3 50K)
- Flan-PaLM
  - PaLM의 instruction-tuned 모델
  - 베이스라인 PaLM 대비 QA task 성능 향상
  - 8B, 62B, 540B
- Aligning LLMs to the medical domain
  - 안전성이 중요한 도메인 특성으로 데이터에 대한 모델의 align이 필수적 (cf. 통상 수만~수백만건)
  - 하지만 데이터를 모으기 어려운 도메인이므로 1.prompting과 2.prompt tuning의 data-efficient alignment전략을 사용 
- Prompting strategies
  - Few-shot, chain-of-thought, self-consistency의 세 가지 prompting 사용
- Few-shot prompting
  - 입력 context window 크기에 맞게 예시 개수 설정
  - 특정 크기 이상의 모델에서 창발하는 기능
  - 인증받은 전문가(의료진)와 함께 few-shot프롬프트 예시 설계
  - 데이터셋마다 각각 다른 예시 내용 및 개수 사용 (일반적으론 5개, 프롬프트가 긴 데이터셋은 3개)
- Chain-of-though prompting
  - Few-shot 예시의 reasoning 측면을 보강한 프롬프트 설계로 LLM에서 창발한 기능
  - 사람의 문제풀이 사고과정을 모사하여 단계단계 설명하는 방식으로 수학문제와 같은 논리성을 필요로하는 태스크에서 큰 효과
- Self-consistency prompting
  - 여러 음답을 샘플한 다음 가장 주가되는 응답을 선택하는 방법으로 multiple-choice에서 효과적인 전략
  - reasoning 과정이 복잡한 문제의 경우 잠재적인 정답 도출과정의 갯수가 여러가지일수 있는 점에 근거한 접근
  - 정답의 다양성을 위한 sampling temperature로는 0.7 사용
- Prompt tuing
  - Soft prompt로서 몇개의 학습가능한 token을 입력 앞에 붙이는 방법
  - 적게는 수십개의 데이터만으로도 가능
- Instruction prompt tuning
  - Flan-PaLM의 경우 few-shot과 COT만으로는 consumer medical question-answering datasets에서 낮은 
  - 데이터가 부족하므로 prompt tuning접근을 의료QA데이터를 사용한 instruction tuning에 사용하고 이를 'Instruction prompt tuning'이라 지칭
  - 이때, soft promt로 human-engineered prompt를 대신하는게 아니라 앞에 붙이는 접근으로 전체 도메인 데이터셋에 공통적으로 적용 (cf. 따라서 하나 이상의 여러 도메인에도 적용가능한 방법)
  - 여기선, 길이 100의 soft prompt를 사용했으며 PaLM의 embedding dimension이 18432이므로 학습해야할 파라메터 수는 1.84M
  - 파라메터는 [-0.5, 0.5]에서 uniform하게 초기화
  - AdamW로 {0.001, 0.003, 0.01}의 learning rate와 {0.01, 0.00001}의 weight decay에서 grid search
  - Batch size 32로 200 step 학습
  - held-out 데이터셋에 대한 응답을 의료진에 보여주고 최고의 checkpoint를 선택
  - 출력 범위가 넓은 생성형 모델의 특성상 metric으로 자동 선택할경우 사람의 판단과 다를경우가 많기 때문
  - 최종 선정된 hyperparameter는 0.003 learning rate에 0.00001 
- Putting it all together: Med-PaLM
  - 환자에 해를 끼치지 않으면서도 의료적 이해, 지식의 연결, 추론을 위해선 instruction prompt tuning의 적용에 특히나 좋은 데이터를 선별해야함
  - 미국과 영국은 전문 의료진으로 하여금 response-free 데이터에 레이블을 달고 애매한 질문-응답 페어는 필터링 후 최종 65개 예시 생성
 


## Model development and evaluation of performance
- Flan-PaLM만으로도 (MedQA(USMLE), MedMCQA, PubMedQA, MMLU)에서 기존 SOTA모델 대비 성능 크게 상승
- Fig2.

 


## Ablation
- 도메인 데이터(MedQA, MedMCQA, and PubMedQA)에 대한 Flan-PaLM성능 분석
### Instruction tuning이 도메인QA에 도움 (few-shot COT)

### Scaling이 도메인QA에 도움
- 8B보다 540B가 성능 두배
  - 결과: SI table6, SI fig1~2
  - 
###  COT(explanation) prompting
- 기본 few-shot대비 대비 대체로 성능 감소
  - 결과: SI table2, 질문예시: SI table23 ~ 28
- 정답까지의 다양한 reasoning이 가능할 뿐더러, 생성된 하나의 COT path가 가장 정확할 확률이 낮은것이 원인으로 추정
- 도메인 특화 prompt 사용시 일부만 효과 --> COT가 특정 문제 해결엔 도움이 될 수 있으나 도메인 지식주입엔 효과 없음
  - 결과: SI table5, 질문예시: SI table27~28

###  Self-consistency
- COT 11개 생성 후 정답 고르도록 했을때 기존 few-shot 대비 대체로 성능 증가 (오히려 성능 떨어지는 데이터셋도 있음)
  - 결과: SI table3,  응답예시: ED table6

## Key contributions 
- 평가 데이터셋 구축 (Methods. 'Datasets')
- 평가 framework 정립 (Methods. 'Framework for human evaluation')
- 기존대비 성능 향상 (Human evaluation results)
- Flan-PaLM + Instruction prompt tuning -> Med-PaLM (Methods. 'Modeling')
- 한계 (Limitations)
- (Fig1)

## Model development and evaluation of performance
- Flan-PaLM만으로도 (MedQA(USMLE), MedMCQA, PubMedQA, MMLU)에서 기존 SOTA모델 대비 성능 크게 상승 (Fig2)

## Ablation
- 도메인 데이터(MedQA, MedMCQA, and PubMedQA)에 대한 Flan-PaLM성능 분석
### Instruction tuning이 도메인QA에 도움 (few-shot COT)

### Scaling이 도메인QA에 도움
- 8B보다 540B가 성능 두배 (결과: SI table6, SI fig1~2)

###  COT(explanation) prompting
- 기본 few-shot대비 대비 대체로 성능 감소 (결과: SI table2, 질문예시: SI table23 ~ 28)
- 정답까지의 다양한 reasoning이 가능할 뿐더러, 생성된 하나의 COT path가 가장 정확할 확률이 낮은것이 원인으로 추정
- 도메인 특화 prompt 사용시 일부만 효과 --> COT가 특정 문제 해결엔 도움이 될 수 있으나 도메인 지식주입엔 효과 없음 (결과: SI table5, 질문예시: SI table27~28)

###  Self-consistency
- COT 11개 생성 후 정답 고르도록 했을때 기존 few-shot 대비 대체로 성능 증가 (오히려 성능 떨어지는 데이터셋도 있음) (결과: SI table3,  응답예시: ED table6)

### Uncertainty and selective prediction
- Self-concistency의 비율을 uncertainty로 정하고 threshold를 넘지 못하면 보류
- 보류비율이 증가할수록 정확도가 상승 (fig.3)
- 정확도가 중요한 의료분야에서 합리적인 접근이나 더 연구될 필요 있음

### Human evaluation results
- 일반이의 질문에 길게 응답하는 데이터셋에서 랜덤하게 질문을 샘플 (HealthSearchQA 100개, LiveQA 20개, MedicationQA 20개)
- 질문에 대한 1.의료진의 응답, 2.Flan-PaLM의 응답, 3.Med-PaLM의 응답을 다른 9명의 의료진이 합의된 기준들로 대해 평가 (Fig4~6)
- **Scientific consensus**
  - Flan-PaLM 61.9% 대비 Med-PaLM 92.9%로 intruction prompt tuning이 과학적 근거를 가진 대답을 하는데 효과적인 alignment 방법임을 확인 (fig4)
  - 하지만 학습 데이터가 과거의 과학적 합의에 기반하여 현재의 바뀐 내용을 반영못한다는 한계가 있어, 후속 연구에선 continual learning이나 RAG를 사용할 필요
- **Comprehension, retrieval and reasoning capabilities**
  - Instruction prompt tuning후 comprehension, retrieval, reasoning의 세가지 측면에서 모두 의료진의 답변과의 격차가 줄어드는것을 확인 (fig5)
- **Incorrect or missing content**
  - 응답에 잘못되거나 부적절한 내용이 있는지에 대한 평가에서 Med-PaLM이 18.7%로 Flan-PaLM의 16.1%보다 더 부정확
  - 중요한 정보 중 빠진 내용이 있는지에 대해선 Flan-PaLM이 47.6%로 매우 높은데 반해 Med-PaLM은 15.3%로 의료진의 11.1%과의 격차가 크게 감소 (fig4, ED table8)
  - Instruction prompt tuning이 보다 자세한 답변을 하도록하여 빠진 내용은 줄어드나 이로인해 부정확한 정보를 생성할 가능성은 상승하는것을 확인
- **Possible extent and likelihood of harm**
  - 응답대로 행했을 경우 위험해질 가능성에 대한 평가에선 Flan-PaLM의 29.7%대비 Med-PaLM은 5.9%로 의료진의 5.7%와 거의 유사한 결과 (fig4)
- **Bias for medical demographics**
  - 응답이 편향된 인구 집단에만 적절한지에 대해 Flan-PaLM은 7.9%인데 반해 Med-PaLM은 0.8%로 의료진의 1.4%보다도 더 낮은 결과
  - 단, 질문 자체가 공정성을 고려하여 만들어졌음을 주의
- **Lay user assessment**
  - 일반인으로 하여금 응답이 도움이 되는지에 대해선 Med-PaLM의 응답이 Flan-PaLM의 60.6%보단 크게 높은 80.3%를 보였으나 의료진의 91.1%보단 낮은 평가
  - 응답이 궁금한 부분을 다루었는지에 대해선 Med-PaLM의 응답이 의료진의 95.9%에 근접한 94.4% 
  - 결과적으로 instruction prompt tuning이 일반 사용자로 하여금 만족스러운 답변을 생성하도록 하는데 도움이 되나 여전히 의료진의 응답에 가까워지려면 해결해야할 작업이 많이 남은것을 알 수 있음
  
## Discussion
- 모델의 스케일을 키우고 instruction tuning을 하는것은 세부 도메인의 테스크를 하는데 크게 도움
- 사전학습 corpus에 테스트 데이터셋이 소량 포함되어있긴하나, 스케일링에 의한 암기 이상의 뛰어난 성능 확인
- Flan-PaLM이 도메인 데이터로만 학습시킨 모델들보다 더 나은 성능을 보여줌에 따라, 의료 도메인에서의 QA 능력은 스케일에 의해 향상된다고 결론
- 하지만 일반인의 의료 질의에 대한 QA테스크에서는 단순한 스케일링만으로는 부족함
- 이에, Instruction prompt tuning을 사용하여 정확성, 일관성, 안정성, 위해성, 편향성, 유용성의 모든 측면에서 도메인 전문가(의료진)에 보다 가까운 성능을 확보

## Limiations
### Key LLM capabilities for this setting
- 도메인 전문가 수준까지 가기위해 필요한 것들
  1. 권위 있는 의료 소스에 grounding되면서도 시간에 따라 변경되는 컨센서스를 반영
  2. 불확실성을 탐지하고 이에 대해 사용자와 소통
  3. 다중언어 사용
  4. 안전성을 고려한 alignment

## Methods
### Datasets
- MultiMedQA benchmark
  - medical knowledge를 필요로 하는 질문, medical research comprehenshion skill을 필요로 하는 질문, 사용자의 의도를 파악하고 이에 필요한 정보를 요구하는 질문들로 구성
  - 1.MedQA(USMLE), 2.MedMCQA, 3.PubMedQA, 4.MedicationQA, 5.LiveQA TREC 2017, 6.MMLU clinical topics datasets, 7.HealthSearchQA
  - 데이터셋 마다 서로 다른 형태
    1. format: 4~5지선다 vs 긴 답변
    2. capabilities tested: 지식 평가 vs 추론 평가
    3. domain: open vs closed
    4. question source: 전문 시험 vs 연구자 vs 일반인
    5. labels and metadata: label, 설명, 데이터 소스의 포함 여부
  - Table1
- 긴 답변이 필요한 데이터셋(MedMCQA, PubMedQA, LiveQA, MedicationQA)의 경우 답변이 전문적이지 않고 일관적이지 않으므로 삭제
  - BLEU와 같은 수치평가는 안정성이 중요한 도메인에서의 긴 답변을 평가하는데 적합하지 않으므로 기존 답변을 BLEU평가에 대한 ground truth로 사용할 필요없음
  - 인증된 전문가로부터 새롭게 레이블링
- 완전하지않으므로 후속 연구에선 EMR이나 전임상 지식으로도 확대 예정

### Framework for human evaluation
- Clinician evaluation
  - 일반인 질의 데이터(LiveQA, MedicationQA, HealthSearchQA)의 질문에 대한 LLM의 응답을 9명의 미국, 영국, 인도의 의료진으로 하여금 평가
  - 12 축에 대해 응답 평가 (ED table2)
    1. 과학적 합의사항과의 일치 여부
    2. 육체 및 정신적으로 해로운지에 대한 심각성 
    3. 육체 및 정신적으로 해로운지에 대한 가능성
    4. 문제에 대한 전반적 이해를 했다는 증거 포함 여부
    5. 과학적 인용에 대한 증거 포함 여부
    6. 논리적 추론에 대한 증거 포함 여부
    7. 문제에 대한 잘못된 전반적 이해를 했다는 증거 포함 여부
    8. 과학적으로 잘못된 인용에 대한 증거 포함 여부
    9. 잘못된 논리적 추론에 대한 증거 포함 여부
    10. 적절하지 않은 내용의 포함 여부
    11. 빠트리면 안될 정보를 빠트렸는지의 여부
    12. 특정 인구집단에 편향되어 있는지 여부
   - 평가 가이드라인 합의를 위해 25개 예시에 대해 3명의 의료진이 수렴까지 반복
- Lay user evaluation
  - 일반인 질의 데이터(LiveQA, MedicationQA, HealthSearchQA)의 질문에 대한 LLM의 응답을 5명의 비의료 일반인으로 하여금 평가
  - 2 축에 대해 응답 평가 (ED table3)
    1. 응답의 질의 의도 파악 여부
    2. 응답의 도움 여부
### Modeling
- 기본 LLM과 이 LLM을 의료 도메인에 특화시키기위한 방법들 소개
- Models
  - PaLM과 Flan-PaLM의 변형family들을 사용
- PaLM
  - 구글의 LLM 베이스라인
  - 540B (GPT3 175B) , Decoder-only transformer모델
  - 웹문서, 책, 위키피디아, 대화문, 깃헙코드의 7천8백억개의 토큰으로 1 epoch 학습 (cf. GPT3 3천억개, PaLM2 3조6천억개)
  - 256K Vocabs (GPT3 50K)
- Flan-PaLM
  - PaLM의 instruction-tuned 모델
  - 베이스라인 PaLM 대비 QA task 성능 향상
  - 8B, 62B, 540B
- Aligning LLMs to the medical domain
  - 안전성이 중요한 도메인 특성으로 데이터에 대한 모델의 align이 필수적 (cf. 통상 수만~수백만건)
  - 하지만 데이터를 모으기 어려운 도메인이므로 1.prompting과 2.prompt tuning의 data-efficient alignment전략을 사용 
- Prompting strategies
  - Few-shot, chain-of-thought, self-consistency의 세 가지 prompting 사용
- Few-shot prompting
  - 입력 context window 크기에 맞게 예시 개수 설정
  - 특정 크기 이상의 모델에서 창발하는 기능
  - 인증받은 전문가(의료진)와 함께 few-shot프롬프트 예시 설계
  - 데이터셋마다 각각 다른 예시 내용 및 개수 사용 (일반적으론 5개, 프롬프트가 긴 데이터셋은 3개)
- Chain-of-though prompting
  - Few-shot 예시의 reasoning 측면을 보강한 프롬프트 설계로 LLM에서 창발한 기능
  - 사람의 문제풀이 사고과정을 모사하여 단계단계 설명하는 방식으로 수학문제와 같은 논리성을 필요로하는 태스크에서 큰 효과
- Self-consistency prompting
  - 여러 음답을 샘플한 다음 가장 주가되는 응답을 선택하는 방법으로 multiple-choice에서 효과적인 전략
  - reasoning 과정이 복잡한 문제의 경우 잠재적인 정답 도출과정의 갯수가 여러가지일수 있는 점에 근거한 접근
  - 정답의 다양성을 위한 sampling temperature로는 0.7 사용
- Prompt tuing
  - Soft prompt로서 몇개의 학습가능한 token을 입력 앞에 붙이는 방법
  - 적게는 수십개의 데이터만으로도 가능
- Instruction prompt tuning
  - Flan-PaLM의 경우 few-shot과 COT만으로는 consumer medical question-answering datasets에서 낮은 
  - 데이터가 부족하므로 prompt tuning접근을 의료QA데이터를 사용한 instruction tuning에 사용하고 이를 'Instruction prompt tuning'이라 지칭
  - 이때, soft promt로 human-engineered prompt를 대신하는게 아니라 앞에 붙이는 접근으로 전체 도메인 데이터셋에 공통적으로 적용 (cf. 따라서 하나 이상의 여러 도메인에도 적용가능한 방법)
  - 여기선, 길이 100의 soft prompt를 사용했으며 PaLM의 emb_dim이 18432이므로 학습해야할 파라메터 수는 1.84M
  - 파라메터는 [-0.5, 0.5]에서 uniform하게 초기화
  - AdamW로 {0.001, 0.003, 0.01}의 learning rate와 {0.01, 0.00001}의 weight decay에서 grid search
  - Batch size 32로 200 step 학습
  - held-out 데이터셋에 대한 응답을 의료진에 보여주고 최고의 checkpoint를 선택
  - 출력 범위가 넓은 생성형 모델의 특성상 metric으로 자동 선택할경우 사람의 판단과 다를경우가 많기 때문
  - 최종 선정된 hyperparameter는 0.003 learning rate에 0.00001 
- Putting it all together: Med-PaLM
  - 환자에 해를 끼치지 않으면서도 의료적 이해, 지식의 연결, 추론을 위해선 instruction prompt tuning의 적용에 특히나 좋은 데이터를 선별해야함
  - 미국과 영국의 전문 의료진으로 하여금 response-free 데이터에 레이블을 달도록 하고 애매한 질문-응답 페어는 필터링 후 최종 65개 학습 예시 생성


+ Recent posts