This blog provides an overview of the different types of Device Copy operations within PyTorch, including Host to Device (H2D), Device to Host (D2H), and Device to Device (D2D) transfers.
0. Introduction
There are primarily two types of copy operations in PyTorch's CUDA backend: copies on (or between) GPUs, and copies between the CPU and a GPU. Both are dispatched through copy_kernel_cuda:
// aten/src/ATen/native/cuda/Copy.cu
static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
  AT_ASSERT(iter.ntensors() == 2);

  Device dst_device = iter.device(0);
  Device src_device = iter.device(1);

  // Enable p2p access between devices.
  bool p2p_enabled = maybe_enable_p2p_access(dst_device, src_device);

  if (copy_requires_temporaries(iter, p2p_enabled)) {
    // ...
    return;
  }

  // Copy on GPU (or between GPUs)
  if (dst_device.is_cuda() && src_device.is_cuda()) {
    copy_device_to_device(iter, non_blocking, p2p_enabled);
    return;
  }

  // Copy between CPU and GPU
  // ...
}
This process can generally be segmented into three distinct parts (a user-level sketch follows the list):
Copy utilizing temporaries
Copy on the GPU, or between GPUs if P2P (Peer-to-Peer, referring to direct memory access between one GPU and another) is enabled
Copy between the CPU and GPU, which does not require the use of temporaries
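To make this concrete, here is a rough libtorch sketch (not part of the PyTorch source, and assuming at least two visible GPUs) of user-level copies that exercise each of the three paths:

// sketch.cpp (hypothetical example, not PyTorch source)
#include <torch/torch.h>

int main() {
  auto cuda0 = torch::TensorOptions().device(torch::Device(torch::kCUDA, 0));
  auto cuda1 = torch::TensorOptions().device(torch::Device(torch::kCUDA, 1));

  auto host = torch::randn({1024, 1024});        // contiguous float32 on the CPU
  auto gpu0 = torch::empty({1024, 1024}, cuda0);
  auto gpu1 = torch::empty({1024, 1024}, cuda1);

  // CPU <-> GPU path: contiguous, same dtype, no temporaries needed.
  gpu0.copy_(host);

  // GPU copy path: handled by copy_device_to_device
  // (P2P is used when the driver allows it).
  gpu1.copy_(gpu0);

  // Temporaries path: dst is a non-contiguous view and the dtypes differ,
  // so copy_requires_temporaries() is true and the copy is staged through
  // contiguous temporaries.
  auto dst_view = torch::empty({1024, 1024}, cuda0).t();
  dst_view.copy_(host.to(torch::kDouble));
  return 0;
}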
2.1 Copy with Temporaries
We don’t need to consider temporaries if:
Same Device Copy: No temporaries are needed.
Contiguous and Same Dtype Copy: No temporaries are needed.
Device-to-Device Copy with P2P Enabled: No temporaries are needed.
In all other cases, copy_requires_temporaries returns true and temporary contiguous tensors are used to facilitate the copy.
// aten/src/ATen/native/cuda/Copy.cu
static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
  // ...
  if (copy_requires_temporaries(iter, p2p_enabled)) {
    auto& dst = iter.tensor(0);
    Tensor dst_contig;
    Tensor src_contig;

    if (iter.device_type(0) == kCUDA || non_blocking) {
      // if branch: on CUDA, or non_blocking is set
      // uses dst if dst is contiguous, otherwise uses an empty contiguous tensor
      dst_contig = dst.is_contiguous()
          ? dst
          : at::empty_like(dst, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      // src is the same dtype and shape as dst, contiguous
      src_contig = iter.tensor(1).to(iter.dtype(0)).expand_as(dst).contiguous();
    } else {
      // else branch: not on CUDA and non_blocking is false
      bool same_type = iter.dtype(0) == iter.dtype(1);
      // uses dst if dst is contiguous and has the same dtype as src
      dst_contig = (dst.is_contiguous() && same_type)
          ? dst
          : at::empty_like(dst, iter.dtype(1), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
      // src has the same shape as dst, contiguous
      src_contig = iter.tensor(1).expand_as(dst).contiguous();
    }
    // ...

    // perform a same-dtype copy on contiguous tensors
    dst_contig.copy_(src_contig, non_blocking);

    // if necessary, copy back into dst
    if (!dst_contig.is_same(dst)) {
      TORCH_INTERNAL_ASSERT(dst_contig.device() == dst.device());
      dst.copy_(dst_contig, non_blocking);
    }
    return;
  }
}

static bool copy_requires_temporaries(TensorIterator& iter, bool p2p_enabled) {
  // ...
  if (dst_device == src_device) {
    // same device, no temporaries needed
    return false;
  }

  bool same_dtype = iter.dtype(0) == iter.dtype(1);
  if (same_dtype && iter.is_contiguous()) {
    // Contiguous same-dtype copies can always use `cudaMemcpyAsync`
    return false;
  } else if (dst_device.is_cuda() && src_device.is_cuda()) {
    // Copies between GPUs can use the copy kernel if P2P is supported
    return !p2p_enabled;
  } else {
    // The remaining cases require temporaries.
    return true;
  }
}
Here, temporary tensors dst_contig and src_contig are created, and copy_ is invoked again on them. Since both tensors are now contiguous (and share a dtype), that call falls through to one of the other branches below and completes the copy.
Finally, if necessary, the data is copied back into the dst tensor as outlined in the code.
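The same staging pattern can be written against the public libtorch API. The following is a minimal sketch (copy_via_temporaries is a hypothetical helper, not something PyTorch exposes):

// copy_via_temporaries.cpp (hypothetical example, not PyTorch source)
#include <torch/torch.h>

void copy_via_temporaries(torch::Tensor dst, const torch::Tensor& src) {
  // Stage dst: reuse it if it is already contiguous, otherwise allocate a
  // contiguous buffer with the same shape/dtype/device.
  torch::Tensor dst_contig =
      dst.is_contiguous() ? dst : torch::empty(dst.sizes(), dst.options());

  // Stage src: cast to dst's dtype, broadcast to dst's shape, make contiguous.
  torch::Tensor src_contig =
      src.to(dst.scalar_type()).expand_as(dst).contiguous();

  // Same-dtype copy between contiguous tensors; for cross-device copies this
  // can be serviced by a single cudaMemcpyAsync.
  dst_contig.copy_(src_contig);

  // If we staged into a temporary, write the result back into the original dst.
  if (!dst_contig.is_same(dst)) {
    dst.copy_(dst_contig);
  }
}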
2.2 Copy on GPU
When both tensors reside on the GPU, a D2D copy occurs.
// aten/src/ATen/native/cuda/Copy.cu
static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
  // ...
  if (dst_device.is_cuda() && src_device.is_cuda()) {
    copy_device_to_device(iter, non_blocking, p2p_enabled);
    return;
  }
  // ...
}

void copy_device_to_device(TensorIterator& iter, bool non_blocking, bool p2p_enabled) {
  int64_t numel = iter.numel();

  // We can directly use memcpy if memcpy_eligible
  bool same_type = iter.dtype(0) == iter.dtype(1);
  bool same_conj = iter.tensor(0).is_conj() == iter.tensor(1).is_conj();
  bool same_neg = iter.tensor(0).is_neg() == iter.tensor(1).is_neg();
  bool memcpy_eligible = same_type && same_conj && same_neg && iter.is_contiguous();

  Device dst_device = iter.device(0);
  Device src_device = iter.device(1);

  // device guard is used to set/restore the current device context
  CUDAGuard device_guard(src_device);
  CUDAStream copy_stream = getCurrentCUDAStream(src_device.index());
  if (src_device != dst_device) {
    // sync ...
  }

  if (memcpy_eligible) {
    // same dtype, contiguous, same conjugation and negation
    void* dst = iter.data_ptr(0);
    void* src = iter.data_ptr(1);
    size_t size = numel * iter.element_size(0);
    if (src != dst || src_device != dst_device) {
      // Due to bizarre cuda driver intricacies, copies of
      // cudaMallocAsynced memory between devices that aren't
      // peer-to-peer-capable need "cudaMemcpyPeerAsync".
      bool needs_pool_specific_peer_access =
          CUDACachingAllocator::get()->needsPoolSpecificPeerAccess();
      bool needs_MemcpyPeer = (src_device != dst_device &&
                               needs_pool_specific_peer_access &&
                               !p2p_enabled);
      if (needs_MemcpyPeer) {
        AT_CUDA_CHECK(cudaMemcpyPeerAsync(
            dst, dst_device.index(), src, src_device.index(), size, copy_stream));
      } else {
        AT_CUDA_CHECK(cudaMemcpyAsync(
            dst, src, size, cudaMemcpyDeviceToDevice, copy_stream));
      }
    }
  } else {
    if (same_neg) {
      if (!same_conj) {
        conj_kernel_cuda(iter);
      } else {
        direct_copy_kernel_cuda(iter);
      }
    } else {
      if (!same_conj) {
        neg_conj_kernel_cuda(iter);
      } else {
        neg_kernel_cuda(iter);
      }
    }
  }

  if (src_device != dst_device) {
    // sync
  }

  AT_CUDA_CHECK(cudaGetLastError());
}
This process is divided into three main stages:
Block until pending work on the dst tensor's stream has finished (synchronization 1).
Perform the copy asynchronously on the src (copy) stream.
Block the dst stream until the copy on the src stream has finished (synchronization 2).
The logic for the asynchronous copy is straightforward: if memcpy_eligible, we directly use cudaMemcpyPeerAsync or cudaMemcpyAsync.
If not, an element-wise copy kernel is launched instead. For example, direct_copy_kernel_cuda handles tensors whose conjugation and negation flags already match.
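A simplified sketch of that kernel is shown below; the real source in aten/src/ATen/native/cuda/Copy.cu additionally dispatches over quantized, complex, and other extended dtypes:

// aten/src/ATen/native/cuda/Copy.cu (simplified sketch; dtype dispatch trimmed)
void direct_copy_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND3(
      kHalf, kBool, kBFloat16, iter.dtype(0), "copy_", [&] {
        // launch an element-wise CUDA kernel over the iterator's operands;
        // the lambda forwards each element unchanged
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
      });
}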
Here we employ gpu_kernel to launch a CUDA kernel over the data pointers computed by TensorIterator, using a simple lambda that just returns x. This will not be expanded upon here, but for those interested, more information can be found in my document on TensorIterator.
Regarding synchronization, there are two blocking points in the code, one at the src stream and one at the dst stream:
// aten/src/ATen/native/cuda/Copy.cu
void copy_device_to_device(TensorIterator& iter, bool non_blocking, bool p2p_enabled) {
  // ...
  // device guard is used to set/restore the current device context
  CUDAGuard device_guard(src_device);
  CUDAStream copy_stream = getCurrentCUDAStream(src_device.index());
  if (src_device != dst_device) {
    CUDAEvent dst_ready;
    device_guard.set_device(dst_device);
    // record this event in dst's stream
    dst_ready.record(getCurrentCUDAStream(dst_device.index()));

    device_guard.set_device(src_device);
    // block until all of the operations in dst before the dst_ready event are done
    // Note: won't block code on the CPU here, only blocks the cuda stream
    dst_ready.block(copy_stream);
  }

  // ... do copy async

  if (src_device != dst_device) {
    CUDAEvent src_ready;
    // record this event in src's stream
    src_ready.record(copy_stream);

    // block until all of the operations in src are done
    device_guard.set_device(dst_device);
    src_ready.block(getCurrentCUDAStream(dst_device.index()));
  }

  AT_CUDA_CHECK(cudaGetLastError());
}
The first synchronization (on the src stream, waiting for the dst to be ready) ensures that all operations enqueued on the dst stream before the dst_ready event have completed, setting the stage for the copy operation.
Then, the copy is performed asynchronously, with a task scheduled in the source stream.
Finally, synchronization occurs at the dst stream to ensure the completion of the copy operation.
With these two synchronizations, the copy process is safe with respect to both streams.
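For reference, the same handshake can be expressed directly with the CUDA runtime API. This is a standalone sketch (not PyTorch code; the helper name d2d_copy_with_handshake is made up and error checking is omitted):

// d2d_handshake.cu (hypothetical example, not PyTorch source)
#include <cuda_runtime.h>

void d2d_copy_with_handshake(void* dst, int dst_dev, cudaStream_t dst_stream,
                             const void* src, int src_dev, cudaStream_t copy_stream,
                             size_t nbytes) {
  // 1) Make the copy stream wait until prior work on dst's stream is done.
  cudaEvent_t dst_ready;
  cudaSetDevice(dst_dev);
  cudaEventCreateWithFlags(&dst_ready, cudaEventDisableTiming);
  cudaEventRecord(dst_ready, dst_stream);
  cudaSetDevice(src_dev);
  cudaStreamWaitEvent(copy_stream, dst_ready, 0);

  // 2) Enqueue the copy asynchronously on the src-side stream.
  cudaMemcpyPeerAsync(dst, dst_dev, src, src_dev, nbytes, copy_stream);

  // 3) Make dst's stream wait until the copy has finished.
  cudaEvent_t src_ready;
  cudaEventCreateWithFlags(&src_ready, cudaEventDisableTiming);
  cudaEventRecord(src_ready, copy_stream);
  cudaSetDevice(dst_dev);
  cudaStreamWaitEvent(dst_stream, src_ready, 0);

  // Events can be destroyed right away; the driver releases them once completed.
  cudaEventDestroy(dst_ready);
  cudaEventDestroy(src_ready);
}

Like the ATen version, neither synchronization blocks the CPU; the ordering is enforced entirely between the two CUDA streams.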
2.3 Copy between CPU and GPU (no temporaries)
This section addresses copies of contiguous tensors between the host (CPU) and the GPU.
// aten/src/ATen/native/cuda/Copy.cu
static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
  // ...
  // Copy between CPU and GPU
  cuda::OptionalCUDAGuard device_guard;
  cudaMemcpyKind kind;
  if (dst_device.is_cuda() && src_device.is_cpu()) {
    device_guard.set_device(dst_device);
    kind = cudaMemcpyHostToDevice;
  } else if (dst_device.is_cpu() && src_device.is_cuda()) {
    device_guard.set_device(src_device);
    kind = cudaMemcpyDeviceToHost;
  } else {
    TORCH_INTERNAL_ASSERT(false, "unsupported devices in GPU copy_()");
  }

  void* dst = iter.data_ptr(0);
  void* src = iter.data_ptr(1);
  int64_t nbytes = iter.numel() * iter.element_size(0);
  CUDAStream stream = getCurrentCUDAStream();

  if (non_blocking) {
    AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
    const auto& dst_tensor = iter.tensor(0);
    const auto& src_tensor = iter.tensor(1);
    const auto& host_tensor = (dst_device == kCPU ? dst_tensor : src_tensor);
    auto* ptr = (dst_device == kCPU ? dst : src);
    auto* ctx = host_tensor.storage().data_ptr().get_context();
    // record an event in the current cuda stream based on the context and
    // data ptr of the host tensor
    CachingHostAllocator_recordEvent(ptr, ctx, stream);
  } else {
    at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream);
  }

  // ... neg and conj operations
}

// torch/include/c10/cuda/CUDAFunctions.h
C10_CUDA_API void __inline__ memcpy_and_sync(
    void* dst, void* src, int64_t nbytes, cudaMemcpyKind kind, cudaStream_t stream) {
  // ... gpu trace
  C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}
Here, using TensorIterator, we obtain the data pointers of the src and dst tensors. Depending on non_blocking, we either call cudaMemcpyAsync directly and record an event, or use memcpy_and_sync, which synchronizes the stream after issuing the copy.
Note: The recorded event pertains to CUDAHostAllocator (managing the memory of host tensors). Typically, a tensor’s memory block is not reused until the event is marked ready. For those interested in Memory Cache, further details can be found in aten/src/ATen/cuda/CachingHostAllocator.cpp.
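As a user-level illustration (a rough libtorch sketch, not PyTorch source): a non_blocking host-to-device copy generally only overlaps with other work when the host tensor lives in pinned memory, which is the memory the caching host allocator manages:

// pinned_h2d.cpp (hypothetical example, not PyTorch source)
#include <torch/torch.h>

int main() {
  // Pinned (page-locked) host tensor: cudaMemcpyAsync can be truly asynchronous.
  auto host = torch::randn({1 << 20}, torch::TensorOptions().pinned_memory(true));
  auto dev  = torch::empty({1 << 20}, torch::TensorOptions().device(torch::kCUDA));

  // Enqueued on the current CUDA stream; returns without waiting for the copy.
  dev.copy_(host, /*non_blocking=*/true);

  // ... other host-side work can run here while the copy is in flight ...

  // Wait for the copy (and other pending device work) before relying on dev.
  torch::cuda::synchronize();
  return 0;
}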