A simple and effective fix for unbalanced multi-GPU load during PyTorch training (GPU 0 using far more memory than the other cards)

This post addresses the problem where, during multi-GPU training with PyTorch, GPU 0 ends up using much more memory than the other cards.

On the machine used here (TITAN RTX cards, 24220 MiB each), training with batch_size = 9 spread over three cards shows GPU 0 already at 24207 MiB right after startup, when only a small amount of data has been moved to the GPUs; once more data arrives, GPU 0 is guaranteed to run out of memory. The reason GPU 0 is heavier is that, with DataParallel, the loss and its gradient are computed on GPU 0 by default during the backward pass, so it needs extra memory on top of its share of the batch; exactly how much extra depends on the network structure.

To keep training from dying with an out-of-memory error, the clumsy workaround is to drop batch_size to 6, i.e. 2 samples per card. With batch_size = 6 and everything else unchanged, GPUs 1 and 2 each use less than 16 GB: we sacrificed batch size just because GPU 0 might overflow by a small margin.

Is there a more elegant solution? Yes: borrow the BalancedDataParallel class used in transformer-xl. For contrast, the next sketch shows the plain DataParallel baseline that causes the imbalance; the BalancedDataParallel code itself (taken from the transformer-xl repository) follows right after it.
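A minimal baseline sketch, with a stand-in nn.Linear model and random tensors in place of the real network and data loader (these names are illustrative, not from the original post):

import torch
import torch.nn as nn

device = torch.device('cuda')
model = nn.Linear(512, 10)                  # stand-in for the real network
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)          # default: even split, 3 / 3 / 3 with batch_size = 9 on three cards
model = model.to(device)

x = torch.randn(9, 512, device=device)      # batch_size = 9, as in the example above
y = torch.randint(0, 10, (9,), device=device)
out = model(x)                              # replicas run on each card; outputs are gathered back onto GPU 0
loss = nn.functional.cross_entropy(out, y)  # the loss (and the start of backward) lives on GPU 0
loss.backward()

With this setup every card gets an equal share of the batch, and GPU 0 additionally pays for the gathered outputs and the loss. The BalancedDataParallel class replaces the even split with a configurable one: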

import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel._functions import Scatter


def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except Exception:
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                quit()
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    """Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs


class BalancedDataParallel(DataParallel):

    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids)
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)

As the code shows, BalancedDataParallel inherits from torch.nn.DataParallel and adds a single extra argument, gpu0_bsz, which lets you set the batch size placed on GPU 0. By giving GPU 0 a little less data, its memory usage is balanced against the other cards. It is called like this:

from data_parallel import BalancedDataParallel  # assuming the class above is saved as data_parallel.py

if n_gpu > 1:
    model = BalancedDataParallel(2, model, dim=0).to(device)
    # model = torch.nn.DataParallel(model)

gpu0_bsz: the batch size assigned to GPU 0; model: the model to wrap; dim: the dimension along which the batch is split (0 here).
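To see how a given gpu0_bsz translates into per-card batch sizes, the chunk-size arithmetic from the scatter method above can be checked in isolation. This is a standalone re-statement of that logic for illustration, covering only the gpu0_bsz < bsz_unit branch:

def balanced_chunk_sizes(bsz, gpu0_bsz, num_dev):
    # Mirrors the branch of BalancedDataParallel.scatter where gpu0_bsz < bsz_unit.
    bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
    sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
    for i in range(bsz - sum(sizes)):   # hand any remainder to the non-0 cards
        sizes[i + 1] += 1
    return sizes

print(balanced_chunk_sizes(8, 2, 3))   # [2, 3, 3] -> GPU 0 gets 2 samples, the others get 3
print(balanced_chunk_sizes(9, 2, 3))   # [2, 4, 3] -> batch_size = 9 also works if the other cards have room

If gpu0_bsz is not smaller than the even share bsz_unit, the class simply falls back to DataParallel's default even scatter (the else branch in scatter above).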

So let's try raising batch_size back to 8 with gpu0_bsz = 2. It works: batch_size goes from 6 to 8, and because GPU 0 holds one sample fewer per step than the other cards, it uses less memory than they do. Trading a little of one card's memory for headroom on the others, and thereby increasing the total batch size, is well worth it, and the advantage of this approach grows as the number of GPUs increases.
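Putting it together, here is a minimal end-to-end sketch of one training step with the balanced wrapper. The linear model, optimizer, and random batch are stand-ins, and data_parallel.py is the assumed filename for the class listed above:

import torch
import torch.nn as nn
from data_parallel import BalancedDataParallel    # assumed filename for the class above

device = torch.device('cuda')
model = nn.Linear(512, 10)                         # stand-in for the real network
if torch.cuda.device_count() > 1:
    model = BalancedDataParallel(2, model, dim=0)  # GPU 0 gets 2 samples; the rest is split over the other cards
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

x = torch.randn(8, 512, device=device)             # total batch_size = 8 -> 2 / 3 / 3 on three cards
y = torch.randint(0, 10, (8,), device=device)

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()

Apart from the import and the wrapping line, nothing else in the training loop has to change.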
