동시성 모델 추론 입문
Concurrency in model inference
Feb 04, 2024
서론
최근 ML engineer로서의 직무를 커리어를 시작하였고, 자연스럽게 concurrency를 커버하는 추론 방법론에 대해 궁금해졌다. 나중에 커리어와 관련된 글도 작성하면서 이야기를 하겠지만, ML 엔지니어로서 내가 가져가야 하는 커리어는 추론 엔진을 개발하는 것이라는 생각이 들었다. 좀 더 풀어서 이야기하자면, torchserve bentoml 같은 tool도 있겠지만, 모델 architecture에 맞춰서 더 최적화를 진행할 수 있는 엔진을 개발해야 한다는 것이다.
우선 그런 방대한 이야기에 진입하기 전에, 이번 글에서는 concurrent inference에 관련해서 이야기해보겠다. 본 내용은 참고한 제일 하단 자료를 번역한 내용들을 주로 다뤘다.이번 글에서 나오는 모델은 image detection 으로 보이고, 여러 개의 detector를 동시에 사용하는 방법을 기술할 예정이다.
Concurrent inference - 2가지 방법
- Batched Concurrency
- 하나의 모델 instance에 여러 개의 request를 batch 형태로 만들어서 추론하는 방법이다.
- 이 방법은 batch로 만드는 과정에서 trade-off가 발생한다.
- batch가 만들어지기 전까지 기다리면 latency는 길어지기 때문이다.
Dataloader에서 item들을 꺼내서 실제로 batch 형태로 만들어진 input을 Tensor로 감싸서 모델에서 쓰인다.
- Multiprocess Concurrency
- 해당 방법은 여러 개의 모델을 띄우고, 각각의 input은 하나의 모델에서 추론하는 방식이다.
- 만약 이 방법이 가능하다면, 1번 방식과 다르게 batch로 만드는 동안 생기는 overhead없이 추론할 수 있다.
실제로는 이 2가지 방법을 혼합해서 사용한다!
어떻게 2개를 혼용할 수 있을까?
이상적으로는 2가지 방법을 혼합해서 사용한다고 한다.
즉, 여러 개의 모델들이 올라가는데, 1개보다 많은 request를 묶어 batch형태로 만들어준 셈이다.
이 방법은 pytorch가 1차원의 tensor를 처리하기 위해서 최소 495MiB의 VRAM을 사용한다는 점을 주목해서 최대한의 효율을 끌어올리려는 목적이 담겨져 있다.
이 목적을 달성하기 위해 저자는 Multiprocess Concurrency 를 주목했다고 한다.
간단한 Use Case
위 그림을 설명하자면, Folder에 존재하는 이미지들을 Read(R)하여 이미지들 n 개의 Detector(D#)로 이미지 내의 object를 탐지하고 탐지된 object의 정보를 File(Fi)에 저장한다.
이 패턴을 우리는 Single Producer & Multiple Consumers 형태로 표현한다.
1. Message Pass 구조
Message를 전달하는 방법으로 우리는 Queue와 Event, 그리고 pipe를 사용할 수 있다.
- Queue
- SimpleQueue
- Queue
- JoinableQueue
python multiprocessing에서 지원하는 Queue는 SimpleQueue, Queue, JoinableQueue가 있다.
SimpleQueue는 non-blocking get() 기능이 없고, polling을 구현하기 위해서 .empty(), .full()을 사용할 수 있다.
simpleQueue와 다르게 put, get 등등 blocking에서 timeout을 걸고 사용할 수 있다.
Queue에서 제공하는 모든 기능을 다 갖고 있고, .join()을 사용해서 queue안에 있는 item들이 모두 처리될 때까지 blocking할 수 있다. 또한 task_done을 사용해서 queue에 대기 중인 작업이 완료되었음을 알리는 데 사용된다. (이 객체는 아직 안 사용해봐서 정확한 풀이보단 개념으로만 알고 있다는 점 참고)
Queue
과 JoinableQueue
는 "feeder" thread를 사용해서 pickling 과 queue 구현체 내부에 있는 pipe에 item을 넣기 때문에 만약 제대로 처리가 안된다면 에러가 발생할 것이다.- Event
Event란 boolean 변수를 setting, unsetting 시켜서 signal을 보내는 방법이다. Event object에 넘겨진 process는 .set(), .clear() 메소드를 사용해서 bool을 set, unset할 수 있다. 또한 .wait()과 .is_set() 을 사용해서 각각 blocking 과 polling 기능도 사용할 수 있다.
- Pipe
pipe는 두 개의 프로세스들 사이에서 message를 전송하는데 사용한다. n개의 producer, n개의 multiple consumer들이 하나의 pipe를 사용하면 안된다. pipe는 duplex, simplex 방법이 존재하는데 아래 코드 블럭에 소개된 것처럼 특징이 다르다.
recv, send = mp.Pipe(duplex=False) # recv only receives message, send only sends message. recv, send = mp.Pipe(duplex=True) # both recv, send can receive and send to each other.
queue는 pipe을 활용하고 있고 pipe와 queue모두 serializable 객체들을 전송한다. 즉, serialize가 안되는 객체를 transmit 할 수 없다. 만약 serializable 하지 않은 객체를 pipe나 queue에 넣어준다면
TypeError
를 만날 것이다.multiprocessing 방식이
spawn 방식이냐 fork 방식이냐
에 따라 TypeError 가 발생할 수 있다는 것을 현재 사내 프로젝트에서 겪고 있다.
spawn의 경우 독립적인 메모리 공간을 사용하기 때문에 _queue.simpleQueue object가 객체 공유가 안되어 TypeError가 발생하는 것으로 판단되어 해결 중이다.
해당 에러를 해결하게 되면 이와 관련된 내용을 간단하게라도 작성해 볼 예정이다.이제 produce와 consume의 기능을 살펴보자.
Producer function
해당 함수는 다음과 같은 과정을 맡는다.
- folder 내부에서 이미지들 읽기
- queue에 image들 넣기
제일 위 simple use case 그림에 있는 구조 중, D#의 detector들에게 들어가기 전에 queue까지의 과정을 맡는다고 보면 된다.
def read_images_into_q(images_path, queue, event, transform): image_list = list(Path(images_path).rglob(f"*.jpg")) while len(image_list) > 0: if queue.full(): time.sleep(0.05) continue else: image_path = image_list.pop() image = Image.open(image_path) image = transform(image) queue.put((image, image_path)) event.set() queue.join()
함수를 살펴보면 아래와 같이 정리할 수 있겠다.
- image_list는 image_path인 folder안에 존재하는 jpg 파일경로들을 리스트로 저장
- image_list안에 1개라도 이미지 파일 경로가 있다면 진행
- queue 안에 공간이 없다면 time.sleep(0.05)
- 만약 a가 아니라면 image_list안에 index(0)에 있는 파일 경로추출
- image open
- image 전처리
- queue에 전처리된 image와 image_path를 같이 저장
- event의 bool을 set()으로 지정 → 이미지 관련 처리가 다 끝났음을 명시
- 모든 item들에 대한 작업이 끝나기 전까지 queue.join()을 통해서 block시킨다.
Consumer function
소비자의 역할을 하는 부분은 4가지의 과정을 담당한다.
- Queue에서 이미지를 읽는다.
- Detector를 사용해서 이미지 내의 모델 추론을 진행한다.
- output을 만들기 위해 후처리를 진행
- file에 output 결과를 작성한다.
→ 먼저 1,2번 과정을 살펴보자
def detect_objects(queue, event, detector, device, lock, output_path): file = open(output_path.as_posix(), "a") detector.eval().to(device) while not (event.is_set() and queue.empty()): try: image, image_path = queue.get(block=True, timeout=0.1) except Empty: continue with torch.no_grad(): image = [image.to(device)] output = detector(image)[0] queue.task_done() handle_output(image_path, output, lock, file) # Points 3. & 4. file.close()
- file 경로를 받아서 열어놓는다.
- 만약 queue안에 task가 존재한다면 아래의 과정을 반복한다.
- queue에서 image와 image 경로를 get한다.
- image를 input을 넣어줄 device를 호출하고, detector에 넣어서 추론한다.
- output을 추출하여 handle_output함수에 넘겨준다.
- 모두 다 종료되면 file을 닫는다.
→그 다음 3,4번 과정을 확인해보자.
def handle_output(path, output, lock, file): filter_output(output) output_string = get_output_string(path, output) lock.acquire() file.write(output_string) file.flush() lock.release()
- image에 존재하는 class의 label을 confidence를 기반으로 filter시킨다. 이때 filter_output을 사용한다.
- 이후 file에 쓰기 위해 output_string을 통해 저장한다.
- lock으로
다른 프로세스에서 file에 접근하지 못하도록 막는다.
- 이후 write 후 flush를 사용해서 file 입출력 처리에 에러를 관리한다.
- 이후 file에 대한 lock을 release시킨다.
마지막으로 reader, detector process를 선언하고 실제로 호출하는 함수를 살펴보자.
caller function
def caller(device, images_path, output_path, detector_count, qsize): # Initialize sync structures queue = mp.JoinableQueue(qsize) event = mp.Event() lock = mp.Lock() # Initialize processes reader_process = mp.Process( target=read_images_into_q, args=(images_path, queue, event, transform) ) detector_processes = [mp.Process(target=detect_objects,\ args=(queue, event, get_detector(),\ device, lock, output_path))\ for i in range(detector_count)] # Starting processes reader_process.start() [dp.start() for dp in detector_processes] # Waiting for processes to complete [dp.join() for dp in detector_processes] reader_process.join() # Closing the queue queue.close()
- queue, lock을 multiprocessing에서 제공하는 객체들을 사용해서 선언한다.
- target과 args를 정리해서 reader와 detector 프로세스를 초기화해준다.
- reader쪽 process를 실행(단일 producer)
- detector쪽 process들을 start()를 진행
- join을 사용해서 해당 process들이 끝날때까지 기다린다.
- 마지막으로 queue에 대한 작업이 끝났다면 close()를 호출해서 정리한다.
추가 내용
- Message Passing & Synchronization :
중요
spawn 방식에서 mp.Process 생성자를 통해서 다른 프로세스로 전달되는 argument들은 serialize가 가능한 object여야 한다.
SimpleQueue를 사용하면 serialize하는데 추가 thread를 사용하지 않아서 좋지만, 제한된 기능만 제공한다는 것을 감안해야 한다.
→ 특히 queue.qsize()도 SimpleQueue에서는 불가능한데 모든 OS에서 제공을 안한다고 한다.
또한 ConnectionResetError, FileNotFoundError, BrokenPipe 등과 같은 에러는 feeder thread에서 발생하는데 이런 문제들이 발생한다면 joinalbeQueue를 사용하는 것도 권장된다.
결론
이번 글을 준비하면서 multiprocessing에 대한 내용들을 정리할 수 있었고, 이 multiprocessing을 잘 쓰는 것도 익숙해져야 겠다고 느꼈다. 다음 글에서는 vllm의 코드 구현부를 요약해보면서 vllm의 특징에 대해 정리해볼 예정이다.
참고 자료
- concurrent inference
- pytorch multiprocessing in window
Share article