Author: Ayush Jain, Norio Kosaka, Kyung-Min Kim, Joseph J Lim
Paper Link: https://openreview.net/forum?id=MljXVdp4A3N 

Site: https://sites.google.com/view/varyingaction

Code: https://github.com/clvrai/agile

 

0. Abstract

  • 지능이 있는 개체는 현재 행할 수 있는 action들의 종류에 따라서 유동적으로 task를 풀 수 있는 반면, 보편적인 RL은 고정된 action set을 가정하고 있음.
  • 예를 들어 목공 수리를 하는 task에서 '못질을 하는 action'은 '망치를 드는 action'이 있을때만 의미가 있음.
  • 본 연구에서는, 이러한 action사이의 상관관계를 활용하기위해 graph attention network를 활용가능한 action들에 적용하는 방법을 제안함.
  • 결과적으로 이 관계성 접근방법을 사용할 경우, value-based 및 pollicy-based RL알고리즘 모두 에서 서로 연관된 action을 활용하는것이 가능함을 확인했으며
  • action sapce가 변하는 추천시스템 및 물리적 reasoning과 같은 문제에서 기존의 비관계성 아키텍처들보다 뛰어난 성능을 보이는것을 확인함.

 

1. Introduction

  • 액자를 벽에 거는 task가 있을때, 망치가 있을땐 못을 사용하면 되지만 후크가 있을땐 접착테잎을 사용해야함.

  • 즉, 최선의 의사결정은 환경 뿐만이 아니라 현재 활용가능한 action에도 의존.
  • 기존 RL은 fixed action space를 가정하고 있기때문에 최근엔 RL에서도 변화하는 action space 혹은 unseen action 문제를 다루고자하는 연구가 발표되고 있으나, 위 예시와 같은 action들 사이의 interdependence에대한 학습을 다루진 않았음. 
  • Varing action sapce에서의 interdepence문제로는 매일 추천해야할 기사의 set이 바뀌는 recommender system 혹은 공구/목적/스킬에 따른 물리적 reasoning이 있음
  • 본 연구에서는 graph attention network (GAT)를 핵심요소로 하는 AGILE, Action Graph for Interdependence Learning 이라는 policy 아키텍처를 제안하며, 
    1. 입력으로서의 action set의 요약과 2. action들 사이의 관계정보 학습을 그 목적으로 함. 

 

2. Related Work

2.1 Stochastic Action Sets

  • 정해진 전체 action pool에서 활용가능한 action set이 랜덤하게 샘플되는 경우를 stochastic action sets라고 하며, 기존 연구에서는 활용불가능한 action의 경우 확률분포의 출력을 masking하는 방식 등을 사용함.
  • 하지만 전체 action pool이 미리 정해져 있는것은 추천시스템과 같이 unseen item을 자주 받는 경우엔 실용성이 떨어짐.
  • 또한 기존 연구와 같이 매 timestep마다 action set이 바뀌는것 역시 실용성이 떨어짐.
  • 이에 본 연구에서는 한 episode에서는 샘플된 action set이 유지되도록하며, unseen action을 마주하는 상황에 대한 한계를 다루고자 action representation을 사용하는 방법을 활용함.

2.2 Action Representations

  • 넓은 action space, transfer learning, shared structural action 등의 문제를 위해 action representation을 활용하는 방법이 기존 논문들에서 사용됨.
  • 본 연구에서는 agent가 base action pool에 대한 사전지식을 활용하는것을 피하고 unseen action 문제를 다루고자 action representation을 사용함.

2.3 List-wise Action Space 

  • action의 선택이 large set $\mathcal{I}$에서 $k$아이템의 subset인, 즉 combinatorial action space $\begin{pmatrix} \mathcal{I} \\ k \end{pmatrix}$에서 최적의 action list를 찾는 강화학습 문제를 list-wise RL 혹은 slate RL이라고 함.
  • 본 연구에서는 제안하는 AGILE policy 아키텍처가 list-wise RL에도 적용가능함을 보여주고자 함.

2.4 Relational Reinforcement Learning

  • GNN은 관계성이 중요한 task를 다루는 기존 RL연구들에서 사용 됨.
  • 본 연구에서는 task를 풀기위해 action들 사이의 interaction이 중요한 문제들을 다루고자하며, 이 과정에서 의미있는 action interaction을 모델링 하기 위해 graph attention network (GAT)를 활용함.

 

3. Problem Formulation

  • 위에서 든 벽에 액자를 거는 예시와 같이, 주어진 action set에서 최선의 행동을 위한 action간의 interdependence를 학습하는것이 여기서 풀고자하는 문제의 핵심.

3.1 Reinforcement Learning with Varying Action Space

  • 강화학습으로 문제를 접근하기위해 다음의 MDP를 정의

$\left\{ \mathcal{S}, \mathbb{A}, \mathcal{T}, \mathcal{R}, \gamma \right\}$

  • 이때 $\mathbb{A}$ 는 countably infinite한 action의 집합.
  • 무한한 action set을 handling하고자, 여기선 추가적으로 $a \in \mathbb{A}$에 대한 $D$-차원의 action representation $c_a \in \mathbb{R}^D $ 를 정의.
  • Action의 subset인 $\mathcal{A} \subset  \mathbb{A}$와 이에 해당하는 representation $\mathcal{C}$ 가 주어졌을 때, 여기서 agent의 목적은 unseen action들또한 포함한 subset에 대해 다음의 retrun 을 최대화하는 policy $\pi\left ( a \mid s, \mathcal{A} \right )$을 학습하는것

$\mathbb{E}_{\mathcal{A}\subset\mathbb{A}}\left [ \sum_t \gamma^{t-1}r_t  \right ]$

 

3.2 Challenges of Varying Action Space

  • Varing action sapce에서의 interdepence문제에는 챌린지가 있으며 그 대응안은 다음과 같음.
    1. 모든 action이 다 주어지는것이 아니므로 policy framework는 유연해야함. 이를 위해 action representation $\mathcal{C}$을 사용.
    2. 현재 환경에서 활용 가능한 action set이 변하는 경우에는 기존의 state space $\mathcal{S}$만으로는 환경의 상태를 완전히 표현할 수 없으므로, $\mathcal{S}$는 Markovian이 아님. 이에, action set representation을 추가한 hyper state space $\mathcal{S}^{'}$를 새롭게 정의. $\mathcal{S}^{'}=\left\{ s\circ \mathcal{C}_{\mathcal{A}} : s \in \mathcal{S, A} \subset \mathbb{A} \right\}$
    3. 활용가능한 action들 사이의 interdependence가 학습되어야함.  구체적으로, 최적의 agent는 미래의 활용 가능한 action들 $c_{a_i} \forall a_i \in \mathcal{A}$  과 현재의 활용가능한 action들 $c_{a_t}$ 의 특성간의 관계를 explicitly 모델링 할 수 있어야 됨.

 

4. Approach

  • 제안하는 접근의 핵심은 GNN을 사용하여 action representation set를 임베딩하는 동시에 action들 사이의 관계를 학습하는 것. 

4.1 AGILE: Action Graph for Independence Learning

  • Action representation $\mathcal{C}$의 리스트와 state encoding이 주어졌을때, 이 를 각각 concat하여 fully-connected action graph를 만듬.
  • 여기에 graph attention network (GAT)를 사용하여 action간의 관계를 attention weight로서 추론함.
    예) Figure 2.에서 대포와 불은 높은 attention weight를 가짐
  • Utility network는 state encoding, GAT의 결과인 relational action representation, 그리고 meal-pooling을 한 action summary를 사용하여 RL알고리즘의 Q나 logit을 계산. 

Action Graph:

  • State에 따라 action들 사이의 관계가 달라짐.
    (예: 스크류드라이버와 전동드릴은 나사못과의 관계가 유사하지만, 가구조립시엔 드라이버가 좀 더 선호되며 벽에 사용할땐 드릴이 좀 더 선호됨)
  • 즉, action representation만을 고려하여 graph를 만드는것이 아닌, state를 결합한 새로운 action representation $c^{'}_{a_i}=\left ( s,c_{a_i} \right )$를 각 노드의 feature로 하여 fully-connected action graph $\mathcal{G}$를 구성함. 
  • 이 후 실험파트에서, 이와같이 state를 추가할 경우에 더 optimal에 가까운 결과를 보여주는것을 다룸. 

Graph Attention Network (GAT):

  • 만들어진 action graph $\mathcal{G}$를 GAT의 입력으로 넣어, 주어진 action set 중에서 서로 연관성이 큰 action들에 더 포커스를 하도록 학습.
  • GAT에 대한 설명은 고려대학교 DMQA 연구실 소속이었던 강현규님의 세미나를 참고.

강현규님의 GAT 발표 슬라이드에서 발췌

  • 충분한 propagation이 가능하도록하기위해 ELU를 사이에 연결한 2개의 graph attention layer를 구성.
  • 이 후 실험파트에서, 두번째 레이어 다음의 residual connection은 중요한 반면 multi-head attention은 영향이 없는것을 결과로서 다룸. 

Action Set Summary:

  • GAT의 출력은 relational action representation $\mathcal{C}^R=\left\{ c^R_{a_0}, \cdots ,c^R_{a_k} \right\}$이며, 각 action representation은 가능한 다른 action과 그 관계에 대한 정보를 내포하고 있음.
  • 앞서 MDP정의에서 다룬바와 같이, 현재 주어진 action set을 state변수로 고려하는 목적에서 action set 정보를 요약하기위해 mean-pooling을 다음과 같이 수행함. $\overline{c}^R=\frac{1}{K}\sum_{i=1}^{K} c^R_{a_i}$ 

Action Utility:

  • 앞서 계산관 값들은 RL네트워크에 전달하기 위한 utility score를 계산하는 utility network 아키텍처 $\pi_u$에 입력으로 전달 됨. 즉, $\pi_a\left ( c^R_a, s, \overline{c}^R \right )$

 

4.2 Training AGILE framework with Reinforcement Learning

  • AGILE 아키텍처의 학습은 PPO, DQN, CDQN을 사용하여 end-to-end로 각각 수행.
    (RL에 대한 자세한 내용은 이 포스팅에선 생략. CQDN은 list-wise action 문제에 대한 RL알고리즘.)

 

5. Environments

  • AGILE알고리즘의 실험은 세가지 서로 다른 특징을 가진 환경에서 진행.
    1) Dig Lava Grid World: 샘플링된 skill을 사용한 최단경로 찾기
    2) CREATE: 물건을 목적지로 옮기기위해 샘플링된 도구들을 physical reasoning하여 선택하기.  
    3) Recommender Systems: 

5.1 Dig Lava Grid Navigation

  • 최단 경로 찾기의 대표적인 toy 환경인 기존의 2D Grid World에서 디폴트 5개 action 외에 4개의 추가 action 중 2개가 랜덤하게 더 주어질 경우, 이를 활용하여 최단 경로를 개선하는 문제.
  • 용암에 들어간 후 다음 스텝에서 용암의 색깔에 맞는 땅파기 스킬을 사용하면 용암이 사라지나, 두 스텝 연속으로 용암에 들어 있으면 해당 에피소드는 실패.
  • RL알고리즘으론 PPO 사용.

5.2 Chain Reaction Tool Environment: CREATE

  • 2차원 공간 상에서 주어진 도구들을 사용하여  빨간색공을 목적지로 옮기는 문제 .
  • 기본도구와 기본도구를 활용하기위한 동력장치가 각각 샘플링되어 주어졌을 때, 이를 사이의 물리적 관계를 파악하고 
  • RL알고리즘으론 PPO사용.

5.3 Recommender Systems

  • RecSys, 즉 추천시스템은 variying action space RL문제로 볼수 있음.
    (예: 매일 새롭운 기사나 유투브 영상이 올라오며 이중에서 추천이 이루어짐)
  • Complementary Product Recommendation (CPR): 추천을 할 때 high level에선 관련이 있지만 low level에서는 다양성을 늘리는것을 말하며 long-term 관점에서 서비스를 운영할 때 매우 중요한 부분.
    (예: primary category는 셔츠 및 바지와 같은것이라면 subcategory는 색깔)

$CPR = \frac{Entropy\;of\;subcategory}{Entropy\;of\;category}$

  • 본 연구에서는 user preference에 더해 현실적인 시나리오를 만들고자 listwise item이 잘 추천된지에 대한 metric으로 CPR을 추가적으로 사용.
  • CPR의 정의에 따라, CPR을 최적화 하려면 추천 agent는 가장 보편적인 primary 아이템을 찾아
  • 이 후 실험에서는, 추천 시뮬레이션 환경에서는 user의 클릭수로 함축적으로 CPR를 최대화하고, 실제 서비스 데이터에서는 reward로서 명시적으로 CPR을 최대화.

5.3.1 Simulated Recommender System: RecSim

  • 사용자 인터렉션을 시뮬레이션하고 강화학습을 적용하기위해 구글에서 개발한 환경인 RecSim (Github)을 listwise recommendation task로 확장. (RecSim에 대한 자세한 내용은 DataBro 님의 포스팅 [1], [2] 참조)
  • train과 test 아이템이 각 250개씩 있으며, 매 에피소드마다 20개의 아이템이 샘플링되어 agent에 주어지고 agent는 매 step마다 6개의 아이템을 추천.
  • user의 state는 각 아이템에 따른 preference를 embedding한 vector로 구성.
  • user는 preference에 따라 추천된 아이템중 하나를 클릭하거나 아무것도 클릭하지않음.
  • 이때 아이템을 클릭을 할 확률은 아래와 같은 기본 user preference에 CPR metric이 추가로 반영된 score를 계산한 뒤, softmax를 거친 categorical distribution을 따름

$\textrm{score}_{item}=\alpha_{user}*\left<e_u,e_i\right>+\alpha_{metric}*m$

$p_{item}=\frac{e^{s_{item}}}{\sum e^{s_{item}}}$

$R = f_{click\,or\,skip}\left ( p_{item} \right )$

  • 즉, 클릭을 많이 한다는것은 추천한 listwise action의 CPR이 높은것을 함축적으로 내포.
  • 클릭을 할 경우와 skip을 할 경우 각각 1과 0의 reward를 반환.
  • 이 user preference embedding과 item embedding은 user의 이전 step response (클릭 여부)를 반영하여 매 step마다 새롭게 업데이트.
  • user preference vector와 action representation이 모두 주어지는 fully observable 조건.
  • 추천한 아이템 리스트의 CPR은 아이템간의 카테고리 coherence가 높은 동시에 subcategory diversity가 클수록 증가한다고 볼 수 있으며, 이러한 추천을 하려면 샘플링되어 주어진 아이템 간의 관계를 reasoning할 수 있어야 함.
  • CDQN을 사용하여 user session내에서 클릭수를 최대화.

5.3.1 Real-Data Recommender System

  • LINE의 온라인 광고추천 서비스에서 2021년 8월 말 2주간 offline 데이터를 수집하고, 이를 학습하여 9월 초 2주간 offline 테스트를 진행함.
  • user는 '지역/나이/직업'으로 representation되며, 아이템은 'text/이미지/보상포인트'를 feature로 포함.
  • 학습 데이터는 68,775명의 user와 57개 아이템을 포함했으며, 테스트는 82,445명의 user와 58개의 아이템을 포함.
  • reward function은 user의 클릭수와 추천리스트의 CPR값으로 구성.
  • 학습엔 CDQN을 사용했으며 test reward로 평가.

 

6. Experiments

  • 본 연구의 실험파트는 다음 5가지 의문에 대한 분석을 수행 하기위해 설계됨.
    Varying action space 측면에서,
    1) AGILE이 action을 독립적으로 다루거나 고정된 action set을 다루던 기존 접근보다 얼마나 효과적인가?
    2) AGILE의 relational action representation이 action set summary나 action utility score의 계산에 있어서 얼마나 효과적인가?
    3) AGILE의 attention이 action relation을 유의미하게 표현하는가?
    4) AGILE의 GNN에서 attention이 필수적인가?
    5) state-dependent action relation 은 general varying action space task문제를 푸는데 있어 중요한가? 

6.1 Effectiveness of AGILE in Varying Action Spaces

  • action을 독립적으로 다루거나 고정된 action set을 다루는 기존 알고리즘들을 baseline으로 AGILE을 평가하고자함.
  • relational action feature와 summary의 효과를 평가하고자 이에 대한 ablation test를 수행함.

6.1.1 Baselines

  • Mask-Output: 고정된 action space를 가정하고 불가능한 action과 관련된 Q-네트워크나 policy의 output은 masking.
  • Mask-Input-Output: Mask-Output에 더해, 각 action이 활용가능한지 아닌지의 binary정보를 입력에 넣어줌
  • Utility-Policy: action representation과 각 action에 대한 utility policy를사용하여 unseen action을 다루나, graph는 사용하지 않음.
  • Simple DQN: 기본 DQN을 의미. action간의 상호관계를 고려하지않고 가장 높은 Q-value를 가지는 top k를 선택.

6.1.2 Ablations

  • Summary-LSTM: relational action representation을 고려하지 않으며 summary도 GAT 대신 bi-LSTM을 사용.
  • Summary-Deep Set: relational action representation을 고려하지 않으며 summary도 GAT 대신 deep set 아키텍처를 사용. (deep set에 대한 내용은 해당 논문 참조)
  • Summary-GAT: GAT가 summary에만 사용되고 relational action representation은 고려되지 않는 경우.

6.1.3 Results

  • Figure 4.는 baseline의 학습과정(위)과 unseen action이 포함된 테스팅(아래)에서의 결과.
  • 모든 환경에서 AGILE이 명확히 더 나은 결과를 보여주었음. (Real-Data RecSys에선 눈에띄는 차이는 아님)
  • 이로부터, varing action space에서는 활용 가능한 action의 존재와 그 관계를 아는것이 optimal action을 찾는데 매우 중요함을 확인. 

  • Figure 5.는 relation은 고려하지 않고 주어진 활용가능한 action들의 정보 (action summary) 만 고려한 ablation의 test 결과.
  • action의 갯수가 적고 관계성이 단순한 Dig Lava Grid에선 큰 차이가 없음.
  • Recsys에서는 보편적인 아이템을 추천하면 CPR이 높아질 확률이 큰데, action summary만으로도 보편적인 아이템의 을 찾는게 가능하여 관계성 정보까지 사용하는 AGILE대비 큰 차이는 아닌 5-20%의 성능 상승을 보여줌.
  • 반면 CREATE는 기구와 동력원 사이의 관계성이 매우 복잡하고 범위가 넓어서, summary만으론 이 관계를 파악하기가 어려워 여러 환경 중 AGILE이 가장 큰 성능 차이를 보임.

6.2 Does the Attention in AGILE Learn Meaningful Action Relations?

  • Figure 6.은 학습된 agent의 퍼포먼스를 좀더 정성적으로 분석해본 결과.
  • (a)는 CREATE에서의 attention map으로 spring을 선택한 후엔 trampoline에 대한 attention이 매우 강해지는것을 확인할 수 있음.
  • (b)는 Grid World에서의 Suumary-GAT의 attention map으로, 오른쪽으로 가는 action과 분홍색 용암을 퍼내는 action사이에 attention이 매우 높은것을 확인 할 수 있음.
  • (c)는 RecSim에서의 user와의 interaction으로, AGILE이 추천하는 아이템들은 6개 중에 5개를 동일한 가장 보편적인 카테고리인 7을 선택하여 CPR을 최대화 하는것을 확인 가능 함.

 

6.3 Additional Analysis

6.3.1 Importance of Attention in the Graph Network

  • Figure 7.에서는 AGILE의 GAT를 또다른 GNN아키텍처인 graph convolutional network (GCN)으로 바꿨을때의 결과를 비교함.
  • 앞서 살펴본 바와 같이 action간의 관계가 단순한 환경인 Grid World와 RecSys에서 GCN은 optimal 성능을 보여주는것을 확인.
  • action간의 관계가 복잡한 CREATE에서는 GAT를 사용한것 대비 크게 성능이 떨어지는것을 확인 함.
  • RecSym에서도 추가적으로 item간의 pair를 만들어 이를 맞출때만 클릭이 되도록 환경의 action관계성의 복잡도를 올릴 경우, GAT를 쓰는것이 GCN을 쓰는것보다 더 나은 성능을 보여주는것을 확인 함.
  • 저자들은 이에 대해, GAT가 graph를 더 sparse하게 만들어 RL알고리즘의 학습을 쉽게 하는것이며 fully-connected GCN은 이러한 것이 어려울것이라 가정.

6.3.2 Importance of State-Dependent Learning of Action Relations

  • AGILE-Only Action: GAT에서 state를 action과 concat하여 새로운 action representation을 만들지않고 action만 사용한 경우.
  • Figure 7.에서는 AGILE-Only Action의 학습결과도 함께 비교함.
  • Grid World와 CREATE와 같이 state에따라 action의 관계성이 달라지는 환경에서는 state를 concat하지 않을경우 성능이 떨어지는것을 확인함. 
  • 반면 CPR이 user의 state와는 독립적으로 가장 일반적인 category를 아는것만을 필요로 하므로 RecSim에서는 state-dependence의 영향이 적은것을 확인함.

 

7. Conclusion

  • varying action space RL문제에서 action간의 관계성을 활용가능한 AGILE아키텍처를 제안.
  • AGILE은 GAT를 사용하여 action들 사이의 상호의존성을 학습하는것이 가능함을, real-data 추천시스템을 포함한 4개 환경에서 검증 및 확인함.

Author: Hiroki Furuta, Yutaka Matsuo, Shixiang Shane Gu
Paper Link: https://openreview.net/forum?id=CAjxVodl_v 

Code: https://github.com/frt03/generalized_dt

 

Author: Albert Webson, Ellie Pavlick
Paper Link(arXiv): https://arxiv.org/abs/2109.01247

Paper Link(NAACL): https://openreview.net/forum?id=BhGMkxhZrW9 

Code: https://github.com/awebson/prompt_semantics

 

[NAVER AI Lab하정우 박사님 Weekly arXiv 소개내용 참고]

  • Large LM에서 prompt-based learning이 잘되는것이 prompt에 포함된 task instruction의 효과라고 생각해왔는데 이게 정말 그러한지를 실험적으로 분석함.

  • Prompt는 위와 같은 task와 연관되거나 관련없는 등 다양한 카테고리의 템플릿을 사용.

  • 위 표는 실험 결과로서, 체크표시는 instructive한 prompt가 그렇지 못한 prompt대비 통계적으로 유의하게 차이가 나는 경우를 의미.
  • 결과적으로 intrunction과는 성능이 크게 차이 없다는 놀라운 사실과 함께, 마지막 컬럼을 통해 prompt가 있기만 하면 few-shot 성능은 좋아진다는 것을 확인. 즉, LLM은 prompt의 instruction을 이해한것이 아니라는 기존의 생각과 반하는 결과. 특히 GPT-3는 체크표시가 없음.

Author: Hyojin Bahng, Ali Jahanian, Swami Sankaranarayanan, Phillip Isola
Paper Link: https://arxiv.org/abs/2203.17274

Page: https://hjbahng.github.io/visual_prompting/

Code: yet.

 

Author : Pranav Rajpurkar, Emma Chen, Oishi Banerjee & Eric J. Topol
Paper Link : https://www.nature.com/articles/s41591-021-01614-0

 

  • 에릭토폴 교수님과 앤드류응 교수님이 만드신 Doctor Penguin(https://doctorpenguin.com/)에서 2019년 5월부터 지난 2년간의 weekly letter로 정리해온 헬스케어 & 의료분야에서의 AI 동향에 대한 리뷰 페이퍼  

Author : Shishir Rao, Mohammad Mamouei, Gholamreza Salimi-Khorshidi, Yikuan Li, Rema Ramakrishnan, Abdelaali Hassaine, Dexter Canoy, Kazem Rahimi
Paper Link : https://arxiv.org/abs/2202.03487

 

 

  • BEHRT를 unbiased causal inference를 위한 exposure group사이의 feature extraction에 사용
  • BEHRT의 feature를 활용하여 risk ratio(RR)의 초기값을 예측하기 위해 다음 두 task의 loss로 output 모델을 동시에 학습
    1) 기존 counterfactual regression(CFR)방법들의 접근과 같이 propencity와 conditional outcome을 prediction
    2) 마스크 된 환자의 static & temporal covariates를 prediction; Masked EHR modeling(MEM)
  • Cross Validated Targeted Maximum Liklihood Estimation (CV-TMLE)를 사용하여 unbias된 RR을 추론
  • 기존 CFR 방법(Dragonnet, TARNET)들 대비 더 나은 RR 예측 성능을 보여줌
  • 데이터가 많을땐 MEM이 예측성능 향상에 큰 역할을 하지만, 데이터가 작을땐 MEM의 사용보다 casual inference방법론이 성능에 더 큰 영향을 미침 

Author : Pedro A. Ortega, Markus Kunesch, Grégoire Delétang, Tim Genewein, Jordi Grau Moya, Joel Veness, Jonas Buchli, Jonas Degrave, Bilal Piot, Julien Perolat, Tom Everitt, Corentin Tallec, Emilio Parisotto, Tom Erez, Yutian Chen, Scott Reed, Marcus Hutter, Nando de Freitas, Shane Legg
Paper Link : https://arxiv.org/abs/2110.10819

 

 

  • Sequential interaction에 대한 모델을 만들 땐, 단순 prediction loss만으론 self-delusion이 생기는 문제에 대한 DeepMind의 article.
  • Delusion 문제를 다루기 위해 sequential 모델의 observation 분포와 action분포는 분리하여 학습해야하며, action의 probability에 대해선 intervention을 모델링하는 'counterfactual teaching'을 해야 delusion을 해소할 수 있다고 설명.
  • 이 sequential 모델은 $\mathrm{RL}^2$와 같은 memory-based meta learning으로 학습이 가능함.
  • 하지만 중요한 점은 위 설명은 online interaction이 가능한 경우에 대한것이고, offline learning의 경우 아직 open problem임을 설명.

 

개인적인 생각

  • 주 저자들이 Deepmind Safety Analysis이다.
  • 익히 알려진 'causal inference' 문제를 foundation model을 지향하는 관점에서 officially 정리해주었다.
  • Offline learning에선 unobserved confounder가 있을 땐, observation 또한 단순 'factual teaching'기반의 prediction 문제로 학습할 경우 selection bias에 의한 delusion이 생기므로 주의해야한다.

Author  : Takahiro Miki, Joonho Lee, Jemin Hwangbo, Lorenz Wellhausen, Vladlen Koltun, Marco Hutter
Paper Link : https://www.science.org/doi/10.1126/scirobotics.abk2822

 

Contributions

1. Context based meta-RL을 활용하여 예상치 못한 or 노이즈가 강한 환경에서의 4족보행 로봇의 real-world robustness 구현

  • Recurrent belief encoder가 센서로는 측정되지 않는 true dynamics에 대한 정보를 내포하는 latent task belief를 추론하고, 이를 RL policy가 활용
  • Attentional gate를 사용하여 prioprioception(고유수용성 감각; 로봇자체의 움직임에 대한 센서) 과 exteroception(외수용성 감각; 외부환경에 대한 센서)의 multi-modal 센서에대한 상황에 따른 선택적 활용

2. 시뮬레이션의 이점을 활용한 privileged learning기반의 zero-shot sim-to-real transfer learning

  • 현실적인 조건을 가정한 충분히 다양한 물리적 환경을 시뮬레이션상에서 미리 학습
  • 이상적인 조건에서 학습되는 teacher policy와 현실적인 조건에서 teacher policy가 학습한것을 knowledge distillation하는 student policy

 

  • 3 단계에 걸쳐 zero-shot sim-to-real transfer learning을 구성
  • Step1: Teacher policy training
    - 랜덤하게 생성된 지형에서 명령으로 준 랜덤 target velocity와의 차이를 reward로 PPO알고리즘에 주어 학습
    - Teacher policy의 입력으로는 1. 속도 command, 2. proprioception센서 정보, 3. exteroception 센서 정보 4. previleged 정보 (ex. 마찰력과 같은 환경의 true dynamics)
    - 시뮬레이션을 활용하여 이상적인 정보를 줌으로써, RL알고리즘이 충분히 optimal에 가까운 policy를 학습하도록 유도  
  • Step2: Student policy training
    - Student policy의 입력으로는 1. 속도 command, 2. proprioception센서 정보, 3. 노이즈가 들어간 exteroception 센서 정보
    - 충분하지 못한 정보에서 scratch로 좋은 RL policy를 학습하기보다, 이미 학습한 좋은 tearch policy를 supervised learning으로 distill하여 효율적으로 학습; privileged learning
    - Prorioception정보와 extroception정보로부터 unobervable state에대한 belief를 추론하기위해 recurrent belief state encoder를 제안
    - Belief encoder가 좋은 latent space를 학습하도록 하기위해, previleged 정보와 true exteoception정보에 대한 reconstruction loss를 사용
  • Step3: Deployment
    - 실제 로봇에 학습한 student policy를 decoder를 제외하고 deploy
    - Context based meta-RL인 만큼 fine tunning이나 optimization 없이 실시간으로 real-world에 adaptation가능

 

  • Exteroception정보는 경우에 따라 틀리거나 얻지못할 수 있으므로, 필요에 따라 exteroception정보에서 의미있는 정보를 선별하여 쓰기위하여 attention gate를 사용한 gated encoder 적용
  • Attention gate는 최종 belief state에 어느정도의 exteroception정보를 담을지를 조절

 

  • Autoencoder를 사용하여 representation learning을 한 만큼, decoder를 사용하여 internal belief를 시각화 가능
  • 아래 그림에서 빨간 점은 policy에 입력으로 들어가는 실제 지면높이 정보 파란 점은 decoder에 의해 복원된 지면높이에대한 agent의 belief
  • A) 스펀지 장애물을 밟기전엔 지면 높이가 높다고 생각하고 있다가, 스펀지를 밟자으면서 들어오는 시계열의 푹신한 반응정보로부터 encoder는 실시간으로 평평한 지면 인것으로 belief가 변경
  • B) 투명한 장애물을 exteroception 센서가 인식못해 평평한 지면이라고 생각하다가, 상자를 밟는 순간 지면의 높이가 있는것으로 belief가 변경
  • D) 센서가 완전히 가려진 상태에서도 지면이 경사졌다고 판단다는 belief가 형성되며, 이는 사람이 걸으며 주변환경이 어두워질 경우 시각에서 체성감각으로 주의를 옮겨 지형지물을 판단하는것과 유사
  • E) 미끄러운 지면의 장애물을 걸을경우, 미끄러지는 만큼의 연장된 너비의 지면에 대한 belief가 형성하는 동시에 마찰의 변화 역시 추정 

 

  • 실제 real-world 환경에서도 넘어짐 없이 robust하게 동작하는 영상

 

 

Author: Ioana Bica, Ahmed M Alaa, Mihaela van der Schaar
Paper Link: https://arxiv.org/abs/1902.00450

Talk in ICML2020: https://icml.cc/virtual/2020/poster/6131

Talk in van der Schaar Lab's Yutube Channel:  https://www.youtube.com/watch?v=TNPce1zd6rE 

Code: https://github.com/ioanabica/Time-Series-Deconfounder

 

 

0. Abstract

  • 의료분야에서 treatment effect를 추론하는것은 중요하지만, 지금까지의 추론 방법들은 모두 'no hidden confounder'라는 비현실적이고 결과적으로 추론에 bias를 야기하는 가정을 전제로 함.
  • 본 연구에서는 시간에 따른 다중 치료 환경에서 multi-cause hidden confounder가 존재할때의 treatment effect를 추론하기위한 Time Series Deconfounder를 제안함.
  • Time series Decounfounder는 multitask output의 RNN을 factor model로 사용하여 multi-cause unobserved confounder를 대체하는 latent variable을 추론하고 이를통해 casual inference를 수행함.  
  • 이론적 분석과 함께 시뮬레이션 및 실제 MIMIC III데이터를 사용하여 알고리즘의 효과성을 검증함.

1. Introduction

  • 연속적으로 처방된 치료에 따른 환자 개인의 치료 효과를 예측하는것은 매우 중요한 문제임.
  • 최근 이러한 정보를 담고있는 observational 데이터 역시 빠르게 증가하고 있음.
  • 하지만 기존의 방법들은 모든 confounder가 관측가능하다는 대체로 비현실적인 상황을 가정하고있어 예측에 bias가 생김.
  • 예를들어 암의 진행에대한 항암제의 효과를 예측할 때 환자의 약에대한 내성형성이나 누적되는 독성을 고려하지 않는것은 예측결과에 bias를 초래.
  • 하지만 내성이나 독성은 관측이 어렵고 관측이 되더라고 후향적인 EHR과 같은 후향적인 관측데이터에는 기록되어있지 않은 경우가 대부분.
  • 본 연구에선 Wang & Blei (2019a)의 연구에서 고안된 'static 셋팅에서의 multiple treatment를 활용하여 hidden confounder를 deconfounding하는 방법'을 개선하여 longitudinal 셋팅에서의 time-varying hidden confounder를 deconfounding하는 Time Series Deconfounder를 제안함.
  • 시계열 환경에서 unobserved confounder의 대체재로서 letent variable을 학습하는 첫번째 시도.

 

2. Related Work

  • 시간에 따라 변하는 치료에 대한 Potential outcomes
    - 지금까지 시계열 데이터에 대한 counterfacutal inference로는 G-formula, G-estimation, MSM, R-MSN, balaced representation 등이 있었지만 모두 hidden confounder가 없다고 가정.
    - 연속성 데이터에대한 treatment effect 연구들도 있어왔으나 여기선 이산 환경을 다룸.
    - Unmeasured confounder에 대한 potential impact를 평가하기위한 sensitivity anlysis방법들도 고안되어옴.
  • Hidden confounder 추론을 위한 latent variable 모델
    - Multi-cause 환경에서 hidden confounder를 추론가능한 latent variable로 대체하고 추론된 latent variable로 causal inference를 수행하는식의 deconfounder 접근은  Wang & Blei (2019a, link)에서 제안된 바 있음.
    - 해당 논문은 static treatment 문제를 다루고 있으나, 본 논문은 이와 달리 time-varying treatment문제를 다루기 위해 RNN을 factor model로 사용하는 deconfounder 구조를 제안함. 

 

3. Problem Formulation

  • $\mathbf{X}_{t}^{(i)} \in \mathcal{X}_{t}$: random variable, (환자 $i$에 대한; 이후 생략) time-dependent covariates
  • $\mathbf{A}_{t}^{(i)}=\left[A_{t1}^{(i)}{\cdots}A_{tk}^{(i)}\right]\in\mathcal{A}_{t}$: 시간 $t$에서의 가능한 $k$가지 treatments
  • $\mathbf{Y}_{t+1}^{(i)}\in\mathcal{Y}_{t}$: 관측된 outcomes
  • $\tau^{(i)}=\left\{{\mathbf{x}_t^{(i)}},{\mathbf{a}_t^{(i)}},{\mathbf{y}_{t+1}^{(i)}}\right\}_{t=1}^{T^{(i)}}$: 이산시간 $T^{(i)}$ 동안 수집된 trajectory 샘플
  • $\mathcal{D}=\left\{{\tau^{(i)}}\right\}_{i=1}^{N}$: $N$명의 환자에 대한 EHR 데이터
  • $\overline{\mathbf{A}}_t=(\mathbf{A}_1,\cdots,\mathbf{A}_t)\in\overline{\mathcal{A}}_t$: 시간 $t$까지의 treatment history
  • $\overline{\mathbf{X}}_t=(\mathbf{X}_1,\cdots,\mathbf{X}_t)\in\overline{\mathcal{X}}_t$: 시간 $t$까지의 covariates history
  • $\mathbf{Y}(\overline{a})$: 가능한 treatment course $\overline{a}$에 대한 potential outcome (factual & counterfactual)
  • 아래의 Individualized treatment effect (ITE), 즉 환자 개인의 covariate history와 treatment history가 주어졌을때의 potential outcome, 을 추론하는것이 이 연구의 목표

$\mathbb{E}[\mathbf{Y}(\overline{a}_{\geq{t}}){\vert}\overline{\mathbf{A}}_{t-1},\overline{\mathbf{X}}_t]$

  • 아래 세 가지 가정 하에서는 bias가 생기지 않아 위 potential outcome에 대한 추론은 우항의 관측된 outcome에 대한 regression과 동치

 $\mathbb{E}[\mathbf{Y}(\overline{a}_{\geq{t}}){\vert}\overline{\mathbf{A}}_{t-1},\overline{\mathbf{X}}_t]=\mathbb{E}[\mathbf{Y}{\vert}\overline{a}_{\geq{t}},\overline{\mathbf{A}}_{t-1},\overline{\mathbf{X}}_t]$

  Assumption 1. Consistency

  Assumption 2. Positivity (Overlap)

  Assumption 3. Sequencial strong ignorability (no hidden confounders)

$\mathbf{Y}(\overline{a}_{\geq{t}}){\perp\!\!\!\!\perp}\mathbf{A}_t\vert\overline{\mathbf{A}}_{t-1},\overline{\mathbf{X}}_{t}$

  • 하지만 세번째 가정은 좌변의 counterfactual로 인해 테스트가 불가능하며 현실적이지 않아 여기서는 hidden confounder가 있는 보다 현실적인 문제를 다루고자 함.
  • 따라서 위에서 언급한 동치는 성립하지 않음.

$\mathbb{E}[\mathbf{Y}(\overline{a}_{\geq{t}}){\vert}\overline{\mathbf{A}}_{t-1},\overline{\mathbf{X}}_t]\neq \mathbb{E}[\mathbf{Y}{\vert}\overline{a}_{\geq{t}},\overline{\mathbf{A}}_{t-1},\overline{\mathbf{X}}_t]$

  • 대신 Wang & Blei (2019a)의 접근을 확장하여 시간에 따른 다중 treatments를 활용한 sequencial latent variable $\overline{\mathbf{Z}}_t=(\mathbf{Z}_1,\cdots,\mathbf{Z}_t)\in\overline{\mathcal{Z}}_t$ 을 추론하고 관측되지 않은 confounders로서 대체하고자 함.

 

4. Time Serise Deconfounder

  • 본 연구에서 제안하는 Time Series Deconfounder의 본질적인 아이디어는 multi-cause confounder로 인한 treatment들 사이의 종속성이 있다는것.
  • 이 종속성을 활용하여 시간에 따라 바뀌는 treatment로 부터 hidden confounder를 추론함.

4.1. Factor Model

  • Time Series Deconfounder는 시간 $t$이전까지의 history $\overline{\mathbf{h}}_{t-1}$로부터 시간 $t$에서의 unobserved confounder를 대체할  letent variable $z_t$을 추론하는 factor model $g$을 가짐.

$\mathbf{z}_t=g(\overline{\mathbf{h}}_{t-1})$

,where  $\overline{\mathbf{h}}_{t-1}=(\overline{\mathbf{a}}_{t-1},\overline{\mathbf{x}}_{t-1},\overline{\mathbf{z}}_{t-1})$

  • 이를 그래프로 나타내면 위 그림$(a)$와 같으며, latent variable  $\mathbf{z}_t$는 시간 $t$에서의 treatment들에 대한 multi-cause unobserved confounder를 대체하는 역할로서 나타낼 수 있음. 
  • Multi-cause unobserved confounder로 인한 treatment들 사이의 종속성으로 인해 factor model을 사용하여 latent variable의 sequence $\overline{\mathbf{Z}}_t$를 추론 할 수 있음.
  • 이때, latent variable $\mathbf{Z}_t$는 모든 multi-cause confounder를 내포하고 있음을 보장 (treatment사이의 종속성을 활용한 귀류법).
  • 즉, 위 그림$(b)$에서와 같이 또다른 multi-cause confounder $V_t$는 존재하지 않음.
  • 하지만 $L_t$와 같은 single-cause confounder가 없는것은 보장할 수 없으므로 다음을 새롭게 가정하고, 위에서 언급한 기존 방법들에서 사용한 세가지 가정 중 세번째 가정을 대체.

  Assumption 3. Sequential single strong ignorability (no hidden single cause confounders)

$\mathbf{Y}(\overline{a}_{\geq{t}}){\perp\!\!\!\!\perp}{A}_{tj}\vert\mathbf{X}_{t},\overline{\mathbf{H}}_{t-1}$

  • 물론 이 가정 역시 여전히 테스트가 불가능하지만, 관측가능한 treatment의 갯수가 증가함에 따라 hidden confounder가 하나의 treatment에만 영향을 줄 가능성은 급격히 줄어듬.
  • Wang & Blei (2019a)에 따르면 $\mathbf{Z}_t$의 차원이 treatments의 갯수보다 작을경우 가정2. Positivity또한 실질적으로 만족가능.
  • Fitting된 factor model이 얼마나 정확하게 validation set 환자의 treatment분포를 예측하는지를 평가하기위해 predictive check로서 각 시간 $t$에서 $M$개의 예측 샘플과 실제treatment 사이의 $p$-value를 아래와 같이 계산.

$\frac{1}{M}\sum_{i=1}^{M}\mathbf{1}(T(a_{t,rep}^{i})<T(a_{t,val}))$

,where $T(a_t)=\mathbb{E}_z[\mathrm{log}\,p(a_t{\vert}Z_t,X_t)]$ is test statistic and $\mathbf{1(\cdot)}$ is indicator function

  • Fitting이 잘 된 경우 $p$-value는 0.5에 가까움.

4.2. Outcome Model

  • Factor model이 잘 fitting된 다음 스텝으로 Time Series Deconfounder는 아래의 좌변의 시간에 따른 individualized treatment effect를 추론하기위한 outcome model을 우변과 같이 fitting.

$\mathbb{E}[\mathbf{Y}(\overline{a}_{\geq{t}}){\vert}\overline{\mathbf{A}}_{t-1},\overline{\mathbf{X}}_t]=\mathbb{E}[\mathbf{Y}{\vert}\overline{a}_{\geq{t}},\overline{\mathbf{A}}_{t-1},\overline{\mathbf{X}}_t]$

  • Time Series Deconfounder의 uncertainty는 factor model에서 시간에 따라 샘플된 sequential latent variable $\hat{\bar{\mathbf{Z}}}_t=(\hat{\mathbf{Z}}_1,\cdots,\hat{\mathbf{Z}}_t)$ 을 다시 반복 샘플하여 구한 각각의 outcome들의 variance로 판단함.
  • 만약 treatment effect가 부정확하여 non-identifiable할 경우엔 이 variance가 커짐.
  • 또한 본 연구에서 제안하는 hidden confounder문제를 다루기 위해 latent variable를 추론하는 접근은 treatment effect의 bias를 명백히 낮추지만, Wang & Blei (2019a)에 따르면 hidden confounder가 없는 상황에선 기존 방법 대비 variance가 상대적으로 커져 free lunch는 아님. 

 

5. Factor Model over Time in Practice

  • 시계열 문제를 다루고 있으므로 기존의 PCA나 Deep Exponential Families을 사용하는 대신 아래 그림과 같이 RNN, 여기서는 특히 LSTM을 factor model로서 사용함.

  • 즉, RNN을 사용하여 환자의 시간 $t$까지의 history로 부터 시간 $t$에서의 latent variable을 추론.

$\mathbf{Z}_1=\mathrm{RNN}(\mathbf{L})$

$\mathbf{Z}_t=\mathrm{RNN}(\overline{\mathbf{Z}}_{t-1},\overline{\mathbf{X}}_{t-1},\overline{\mathbf{A}}_{t-1},\mathbf{L})$

  • RNN의 출력사이즈는 $D_Z$이며 $\mathbf{L}$은 학습가능한 initial paramter.
  • RNN에서 추론된 Latent variables $\mathbf{Z}_t$와 관측된 covariates $\mathbf{X}_t$에 조건부 독립인 treatment $\mathbf{A}_t=[A_{t1},\cdots,A_{tk}]$를 추론하기위해서 treatment 개수 $k$만큼의 single FC MLP레이어를 RNN의 출력단에 multitask output으로 붙임.

$A_{tj}=\mathrm{FC}(\mathbf{X}_t, \mathbf{Z}_t;\theta_j)$

  • Binary treatment의 경우엔 출력레이어에 sigmoid activation을 사용함.
  • Factor model의 확률적인 특징을 구현하기위해서 위 그림의 별이 그려진 부분에 $variational\,dropout$(GAL & Ghahramani, 2016a)을 사용하였고, 이에 따른 latent variable의 샘플링이 가능해짐.
  • 위와 같은 구현으로 RNN으로 하여금 $\overline{\mathbf{X}}_t$, $\overline{\mathbf{Z}}_t$ 및 $\overline{\mathbf{A}}_t$사이의 복잡환 관계를 학습도록 할 수 있지만, 이 과정에서 predictive check가 반드시 필요하다는것에 주의.

 

6. Experiments on Synthetic Data

  • 제안한 Time Series Deconfounder를 검증하고자 합성데이터를 사용함.
  • 실제 데이터를 사용한 검증은 hidden confounder를 알 수 없으므로 불가능.

6.1. Simulated Dataset

  • 5000명의 환자에 대한 20~30 스텝의 가상데이터를 treatments, covariates, hidden confounders가 서로 영향을 미치는 $p$-order autoregressive 과정으로 생성함 (자세한 수식은 논문 참조).
  • 그리고 outcome은 covariates와 hidden confounder의 함수가 되도록 생성.

6.2. Evaluating Factor Model using Predictive Checks

  • 제안한 factor model 아키텍처가 treatment의 분포를 잘 학습하는지 확인하고자 합성데이터에 대해 아래 세 가지 모델의 predictive check를 수행함.
    1. 제안한 factor model; RNN + Multitask FC output (초록)
    2. RNN대신 MLP를 사용한 factor model (파랑)
    3. Multitask FC layer대신 단일 FC layer를 사용한 factor model (보라)

  • 실험 결과 RNN대신 MLP를 사용할 경우 시간이 지남에 따라 지속적인 distribution mismatch가 생김.
  • Multitask output은 treatment distribution을 파악하는데 도움은 되나 큰 영향을 주는것은 아님을 확인.
  • 즉, factor model에 RNN아키텍처를 사용하는것이 hidden confounder의 시간 의존적인 특성을 캡쳐하는데 있어 중요하며, 현재 스텝의 covariates와 confounders가 잘 명시될 경우 treatment distribution을 학습할 수 있다고 결론.

6.3. Deconfounding the Estimation of Treatment Responses over Time

  • Time Series Deconfounder가 confounder에 의한 bias를 잘 deconfounding하는지를 다음의 두 outcome model을 사용하여 검증함. 
  1. Standard Marginal Structural Models (MSMs)
    - Logistic regression으로 구한 inverse probability of treatment weighting(IPTW)을 사용하여 confounder가 balance된 pseudo-population을 생성하는 단계와, 이렇게 생성된 pseudo-population으로부터 treatment reponse를 linear regression하는 단계의 두 가지 스텝으로 구성된 selection bias 대응방법.
    - 이름에서 'marginal'은 counfounder control의 의미이며, 'structural'은 potential outcome framework를 의미함.
    - MSMs에 대한 자세한 내용은 다음 두 강의를 참고
    (https://www.youtube.com/watch?v=7NjIQTzADgQ)
    (https://www.coursera.org/lecture/crash-course-in-causality/marginal-structural-models-EUpei)
  2. Recurrent Marginal Structural Networks (R-MSNs; Lim et al., 2018)
    - MSMs와 접근은 같지만 RNN을 사용하여 propensity score를 추론하고 treatment response 역시 RNN을 사용하여 추론하는것이 차이.
    - $\mathrm{RNN}(\overline{\mathbf{X}}_t,\overline{\mathbf{Z}}_t,\overline{\mathbf{A}}_t)$ 와 같이 구현하며, RNN을 사용하여 추정한 propensity weights에 따라 weight를 각 환자에 주어 loss함수를 계산.
  • 평가를 위해 한 스텝 다음의 treatment response를 예측하는 테스크를 사용하였으며, 두 outcome model에 대한 자세한 분석을 위해 아래 5가지 경우를 비교
    1) Confounded: hidden confounder를 고려하지 않고 관측데이터를 그대로 사용한 경우. 
    2) Deconfounded ($D_z=1$): 실제 hidden confounder의 크기인 1과 동일한 크기의 latent variable $\hat{\overline{\mathbf{Z}}}_t$를 사용한 경우.
    3) Deconfounded ($D_z=5$): 실제 hidden confounder의 크기인 1과 다른, 크기 5의 latent variable $\hat{\overline{\mathbf{Z}}}_t$를 사용한 경우.
    4) Deconfounded w/o $X_1$: Assumption 3를 위반한 경우로 single cause confounder $X_1$를 covariate에서 제거하여 hidden confounder로 가정한 경우.
    5) Oracle: 합성데이터에서의 실제 ground truth hidden confounder $\overline{\mathbf{Z}}_t$를 outcome 모델에 넣어준 경우. 

  • 위 결과 그래프를 보면 Deconfounded에서 Confounded보다 Oracle과 유사한 결과를 보여주어 Time Series Deconfounder가 treatment response에 대해 unbiased estimation을 하는것을 확인함. 
  • Deconfounded 두가지 경우, 서로 크게 차이나지 않는데서 hidden counfounder 크기에 대한 model misspecification과 관계없이 robust한 결과를 확인함.
  • Single hidden confounder가 있을 경우엔 bias를 해결하지 못하는데서 Assumption3가 중요하단것을 확인함. 
  • 5가지 경우 모두 RNN기반의 R-MSNs가 MSMs보다 뛰어난 정확성을 보여줌.

 

7. Experiments on MIMIC III

  • Time Series Deconfounder를 실제 데이터에 대한 검증을 위해 EHR 오픈데이터인 MIMIC III의 6256명의 환자 데이터에 적용함.
  • 특히 폐혈증 환자에서 항생제, 혈압상승제, 기계식 호흡장치의 총 3 가지의 treatment가 백혈구 개수, 혈압, 산소포화도의 각 3 가지 response에 어떻게 영향을 미치는지를 실험.
  • 실제 데이터인만큼 폐혈증 외의 질병에 대한 cormorbidity나 몇 lab test가 기록에 없다던지의 hidden confounder가 존재하며, Oracle 경우를 확인 불가능함.

  • 3 가지 response실험 모두에서 Confounded보다 Time Series Deconfounder를 적용할 경우 정확도가 상승하는것을 확임함.
  • 합성데이터에서와 마찬가지로 RNN기반의 R-MSNs가 MSMs보다 뛰어난 정확성을 보여줌.
  • 추후 의료진의 의견을 참고한 심층된 검증 필요.

 

8. Conclusion

  • 관측된 시계열 환자 데이터에서 individualized treatment effect를 추론하는 기존 방법들에선 모두 hidden confounder가 없다는 가정을 했으나, 시계열 데이터에선 시간에 따라 환자의 상태가 계속 바뀌는데가 treatment를 결정하는 복잡도가 올라가 특히나 더 비현실적인 가정임.
  • 이에 본 연구에선 hidden confounder를 대체가능한 latent variable을 추론하는 Time Series Deconfounder를 제안하고 RNN, multitask output, variational dropout을 사용하여 구현함.
  • 합성데이터와 실제데이터를 사용하여 multi-cause hidden confounder가 있을때의 Time Series Deconfounder의 bias 제거 효과를 보여줌.

 

9. Appendix

  • (Table 3.) Hidden confounder가 treatment와 outcome에 미치는 영향이 커질수록, 더 큰 capacity의 모델이 필요.
  • (D.2) RNN기반의 treatment effect estimation이 시간에 따라 변화하는 treatment policy에 보다 robust.
  • (Figure 6.) 실제 hidden confounder의 갯수와 같게 $D_Z$를 설정하거나 overestimate할 때 treatment response에 대한 예측도가 향상함.

Author: Ioana Bica, Daniel Jarrett, Alihan Hüyük, Mihaela van der Schaar
Paper Link: https://openreview.net/forum?id=h0de3QWtGG 

Talk in ICLR2021: https://iclr.cc/virtual_2020/poster_BJg866NFvB.html

Rating: 8, 7, 6, 5

Author  : Yaobin Ling, Pulakesh Upadhyaya, Luyao Chen, Xiaoqian Jiang, Yejin Kim
Paper Link : https://arxiv.org/abs/2109.12769

 

박지용 교수님께서 기획하신 2021년 인과추론 써머세션에서 김예진 교수님께서 강의해주신 Heterogenous Treatment Effect Estimation using ML 세션을 정말 재밌게 들었는데, 이에 대한 튜토리얼 및 벤치마크 논문이 워킹페이퍼로 공개되었다. 논문을 보면 필요한 용어에 대해 상세히 정의해두었고 특히 예시를 정말 잘 활용하고 있어서 처음 causal inference를 접하는 사람들에게 너무 좋은 내용이다.

 

아래는 줌으로 실강을 듣고나서도 여러번 반복해서 더 들은 김예진 교수님의 강의.  이런 좋은 강의를 한국어로 들을 수 있다니 두 교수님께 감사하다. 

 

 

 

Author  : Andrew Forney, Elias Bareinboim
Paper Link : https://ojs.aaai.org//index.php/AAAI/article/view/4090

 

 

Author  : Ioana Bica, Ahmed M Alaa, James Jordon, Mihaela van der Schaar
Paper Link : https://openreview.net/forum?id=BJg866NFvB 

Talk in ICLR2020: https://iclr.cc/virtual_2020/poster_BJg866NFvB.html

Rating: 8, 6, 6

 

Author  : Anonymous (Naver Labs Europe이지 않을까)
Paper Link : https://openreview.net/forum?id=1L0C5ROtFp

 

Rating: 5, 6, 10

Author : Junyuan Shang, Tengfei Ma, Cao Xiao, Jimeng Sun
Paper Link : https://arxiv.org/abs/1906.00346

Code: https://github.com/jshang123/G-Bert

Author : Laila Rasmy, Yang Xiang, Ziqian Xie, Cui Tao & Degui Zhi 
Paper Link : https://www.nature.com/articles/s41746-021-00455-y

Code: https://github.com/ZhiGroup/Med-BERT

 

Contributions

  1. npj digital medicine
  2. 미국내 600여개 병원의 2천만명 환자 EHR (Cerner)
  3. BERT기반의 MLM사용 
  4. 추가적으로 진단 코드 이상의 context를 학습하기위한 Domain-specific pretraining task로서 입원기간을 예측
  5. Fined-tunned task로는 당뇨병환자의 심부전 & 췌장암 예측
  6. 다른 task, 적은 데이터, 다른 EHR DB(Truven) 를 사용한 Pretrained EHR모델의 real-world generalizability 검증 의의

 

 

  • Med-BERT 학습 아키텍처

 

  • BERT기반의 다른 EHR데이터 활용 논문들과의 비교 및 핵심 내용 요약

Author : Yikuan Li, Shishir Rao, José Roberto Ayala Solares, Abdelaali Hassaine, Rema Ramakrishnan, Dexter Canoy, Yajie Zhu, Kazem Rahimi & Gholamreza Salimi-Khorshidi 
Paper Link : https://www.nature.com/articles/s41598-020-62922-y

Code: https://github.com/deepmedicine/BEHRT

 

 

Contributions

  1. BERT+EHR = B EHR T
  2. LM기반 EHR관련 연구 중 인용수가 100이 넘어간 대표 논문
  3. 환자의 진단코드(301개) 진행을 예측하기 위한 pretrained모델 제안

Limitations

  1. EHR은 interaction 데이터 임에도 진단코드와 나이만을 사용한 seqeunce modeling

 

 

  • 5회 이상 방문한 160만명의 데이터 사용
  • 301개 클래스의 진단코드와 나이만을 학습을 위한 시계열 데이터로 사용
  • 진단코드가 없는 방문은 데이터에서 제외

 

  • Tabular sequence data의 embedding은 아래와 같이 진단코드에 나이(병인+방문간격 역할) 와 positional encoding(방문순서)와 segment(방문 구분)를 추가로 포함하여 수행
  • Pre-train은 BERT와 같이 MLM을 사용하여 환자의 시계열 데이터의 중간에 마스크된 질병을 예측하도록 수행
  • 모델이 병의 진행에 대한 전반적인 학습을 했는지를 검증하기 위해 1) 다음 방문에서의 진단, 2) 다음 6개월 내의 진단, 3) 다음 12개월 내의 진단을 예측하는 downstream task를 수행

 

  • 학습된 embedding을 시각화 해본 결과, 남성질병과 여성질병의 거리가 먼 것을 확인
  • 또한 빈도가 낮은 질병에 대해 가장 가깝게 embedding된 질병을 실제 의료진의 의견과 비교한 경우 75.7% 일치
  • 이로부터 저자들은 BEHRT가 질병의 latent characteristics를 잘 이해했다고 판단 

 

  • Predictive downstream tasks에서도 기존 모델들 보다 높은 성능을 확인

Author : Jose Roberto Ayala Solares, Yajie Zhu, Abdelaali Hassaine, Shishir Rao, Yikuan Li, Mohammad Mamouei, Dexter Canoy, Kazem Rahimi, Gholamreza Salimi-Khorshidi
Paper Link : https://arxiv.org/abs/2107.12919

Author : Zeljko Kraljevic, Anthony Shek, Daniel Bean, Rebecca Bendayan, James Teo, Richard Dobson
Paper Link : https://arxiv.org/abs/2107.03134

 

 

 

Author  : Kyle Aitken, Vinay V Ramasesh, Yuan Cao, Niru Maheswaranathan
Paper Link : https://arxiv.org/abs/2110.15253

 

+ Recent posts