[NSDI 2021] ATP: In-network Aggregation for Multi-tenant Learning

Abstract

  • DDL興起,DDL需要將gradient分開做compute並在計算完之後做aggregate
  • 目前最新的技術已經有人將瓶頸從“分開計算”變成“worker跟parameter server (PS)之間的溝通(網路傳輸)”
  • ATP試圖減少“worker跟parameter server (PS)之間的溝通”時間

1 Introduction

  • 本篇論文的目標就是要試著將從in-network multi-switch aggregation中獲得的benefit最大化並平等分配給複數個不同的DT jobs

2 Background and Motivation

2.1 Preliminaries

  • PS Architecture
    • 由PS做aggregation再broadcast回所有worker
  • In-network Aggregation
    • 由switch做aggregation再broadcast回所有worker

2.2 In-Network Aggregation Service

  • Rethinking reliability
    • 透過switch做aggregation的方法有一缺點:switch會drop原封包然後forward做完所有gradient的aggregation後的封包給worker
    • 這在worker眼中有可能會認為原本的封包掉包了
      • 但處理這件事很簡單,在worker上做些設定讓它不要認為這狀況是掉包即可
    • 但如果今天真的發生掉包了呢?
  • Rethinking congestion-control
    • 由於改用網路傳輸做aggregate,那勢必會受限於網路資源
    • 我們要怎麼盡量平等分配這些網路資源?
    • 在一個沒有”end-to-end”概念的網路下發生壅塞時又要怎麼實作congestion control?

3 Design

3.1 ATP Overview

  • ATP是一個可以同時達成dynamic aggregation和best-effort aggregation的protocol
    • Dynamic aggregation在這邊指的是switch的resource要能夠動態地被重複利用
      • 意思就是當worker拿回aggregate完的gradient packet後要確保他前面在switch上所占用的資源要能被釋放並給其他DT job使用
    • Best-effort aggregation則是指當switch上因競爭過於激烈而資源不足時要能夠優雅地回退到由end-host做aggregation而不造成額外支出

3.2 ATP Infrastructure Setup

3.3 Data Structure

    • Bitmap0是該worker位在所連到的switch的哪個port(one-hot encoding)
    • Bitmap1則是連到worker的switch又位在它上層switch的哪個port
    • Fanindegree0則是第一層switch所連接的worker數量
    • Fanindegree1則是第二層switch所連接的第一層switch及worker數量
    • Edgeswitchidentifier則是看packet目前在哪一層switch,在第一層就是0,在第二層就是1
    • Bitmap紀錄這個aggregator目前已經由哪些port收到gradient packet
    • counter是紀錄目前已經收到幾個gradient packet
    • ECN是拿來做congestion control的

3.4 Inter-rack Aggregation

3.5 Switch Logic

  • Aggregator Allocation
    • 收到來自worker傳的gradient packet之後,switch會根據packet所夾帶的job id、sequence number、index、bitmap0跟1、Edgeswitchidentifier等資訊決定它要forward這個packet還是將相應的aggregator分配給這個packet
  • Gradient Aggregation
    • 如果aggregator可分配給這個gradient packet,那就將aggregator的counter++,更新一下bitmap欄位
    • 如果counter的數字比fanindegree小表示還有worker還沒送packet,那就將目前送進來的這個packet drop掉,直到counter=fanindegree
    • 當counter=fanindegree,switch就會將此封包夾帶的value改成aggregate後的gradient然後更新一下header的資訊,最後forward出去給PS
  • Aggregator Deallocation Using Parameter Packets
    • 收到來自PS的parameter packet(可以從header上”isACK”這個bit判斷)時,switch會將其視作PS有成功接收並記錄到該packet,而且也知道這包封包不是要來做aggregate的封包,這時就會釋放掉它剛剛在做aggregate時占用的aggregator,並將其forward回worker

3.6 End Hsot Logic

  • Worker Pushing Gradients
    • 就是worker傳gradient packet出去
    • 只是在傳出去之前要先將gradient乘以一個大數字使其從floating point變成32-bit的integer (之所以這樣做是因為switch不支援浮點數運算)
  • PS Updating Parameters
    • 當switch沒辦法aggregate gradient時,由PS處理switch forward過來的packet
    • 收到packet後會將gradient value除以剛剛worker乘的大數字使gradient變回floating point
    • PS會從”bitmap”這個field檢查這個packet是否已被aggregate過,若否則aggregate,若是則忽略該packet
    • 但是也有可能發生一種狀況是當某個job的第一封gradient packet送到switch時,switch的aggregator正在busy所以將packet forward給PS,但過沒多久aggregator的resource清出後同個job的其他gradient packet才送到switch
      • switch仍然會把這些晚到的gradient packet收下,就變成switch跟PS互相等存在對方那邊的那封packet
      • 這狀況會因為PS太久沒把aggregate完的packet回傳而使worker以為掉包觸發重傳機制
  • Worker Pulling Parameters
    • worker乖乖等做完aggregation的packet從PS傳回來
    • 如果某個sequence過號太久都沒等到就視為掉包觸發重傳機制

3.7 Reliability and Congestion Control

  • Dealing with Packet Retransmission
    • 如果某個sequence過號太久都沒等到就會被視為掉包而觸發重傳機制
    • 重傳之前worker會將封包的header中“resend”這個field mark起來
    • switch會透過index、job id檢查aggregator裡面有沒有存屬於這個job的gradient
      • 如果有且目前這包gradient還沒被aggregate過 (有沒有被aggregate過可以從bitmap看出來)
        • 先aggregate value然後更新bitmap,接著forward這個有可能還未完成的result到PS (掉包的aggregation最終會在PS完成),最後釋放它所佔用的aggregator
      • 如果aggregator裡面有存屬於這個job的gradient且目前重傳的這份packet是曾經被aggregate過的
        • 直接捨棄目前這份packet然後forward result給PS,最後釋放它所佔用的aggregator
      • 其他
        • 直接forward給PS
    • 最終的aggregation都會在PS中完成
  • Memory Leak
    • worker送出所有gradient給PS之前job不正常地終止了就可能造成aggregator一直被占用著
    • 解法是設定一個timeout的value,當新進的封包(不一定要是同job的封包)所夾帶的timestamp減掉占用aggregator的封包的timestamp已經超過的timeout value的話,整個aggregator就會被清掉

3.8 Dealing with Floating Points

  • 剛剛前面提過,worker送packet前先將gradient乘以一大數字,PS再除以此數字將其還原