
LavenderDataLoader - Context Parallelism

For context parallelism, set the replication_pg parameter. It is a list of lists of integers that partitions the ranks into groups: each inner list holds the ranks of one group. For example, [[0, 1], [2, 3]] puts ranks 0 and 1 in the first group and ranks 2 and 3 in the second group. All ranks within a group receive the same samples. Context parallelism needs this because the ranks in a group each process a different slice of the same sequence.
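Writing the groups out by hand becomes tedious for larger jobs. The sketch below builds replication_pg from the world size and the context-parallel size. make_replication_pg is a hypothetical helper and is not part of LavenderDataLoader. It also assumes that ranks sharing a context-parallel group are numbered contiguously, as in the example above. If your launcher lays ranks out differently, build the groups to match that layout.

```python
def make_replication_pg(world_size: int, cp_size: int) -> list[list[int]]:
    """Group consecutive ranks into context-parallel groups of size cp_size.

    Hypothetical helper (not a LavenderDataLoader API). Assumes ranks that
    share a context-parallel group are contiguous:
    [0 .. cp_size-1], [cp_size .. 2*cp_size-1], ...
    """
    if world_size <= 0 or cp_size <= 0 or world_size % cp_size != 0:
        raise ValueError("world_size must be a positive multiple of cp_size")
    # One inner list per group; each group becomes one data-parallel replica.
    return [
        list(range(start, start + cp_size))
        for start in range(0, world_size, cp_size)
    ]


if __name__ == "__main__":
    print(make_replication_pg(4, 2))  # [[0, 1], [2, 3]]
    print(make_replication_pg(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With 4 ranks and a context-parallel size of 2, this reproduces the [[0, 1], [2, 3]] value used in the example.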

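import os

# `dataset` and `shardset` are assumed to have been created earlier.
dataloader = LavenderDataLoader(
    dataset_id=dataset.id,
    shardsets=[shardset.id],
    rank=int(os.environ["RANK"]),  # environment variables are strings; convert to int
    replication_pg=[[0, 1], [2, 3]],  # ranks 0,1 share samples; ranks 2,3 share samples
)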
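The following minimal sketch makes the grouping semantics concrete by mapping a rank to its group and listing the sample indices that rank would see. The round-robin (strided) assignment and the helper names find_group and sample_indices are assumptions for illustration only. LavenderDataLoader's actual sample distribution may differ. The sketch does reflect the behavior described above: the number of groups acts as the effective data-parallel world size, and every rank in a group gets an identical sample list.

```python
def find_group(replication_pg: list[list[int]], rank: int) -> tuple[int, int]:
    """Return (group_index, num_groups) for `rank` (hypothetical helper).

    Also checks that `replication_pg` partitions ranks 0..N-1:
    no rank appears twice and none is missing.
    """
    seen: set[int] = set()
    group_index = None
    for i, group in enumerate(replication_pg):
        for r in group:
            if r in seen:
                raise ValueError(f"rank {r} appears in more than one group")
            seen.add(r)
        if rank in group:
            group_index = i
    if seen != set(range(len(seen))):
        raise ValueError("groups must cover ranks 0..N-1 exactly once")
    if group_index is None:
        raise ValueError(f"rank {rank} is not in any group")
    return group_index, len(replication_pg)


def sample_indices(num_samples: int, replication_pg: list[list[int]], rank: int) -> list[int]:
    """Assumed round-robin split: group g takes samples g, g + G, g + 2G, ..."""
    group_index, num_groups = find_group(replication_pg, rank)
    return list(range(group_index, num_samples, num_groups))


if __name__ == "__main__":
    pg = [[0, 1], [2, 3]]
    for rank in range(4):
        print(rank, sample_indices(8, pg, rank))
```

With [[0, 1], [2, 3]] and 8 samples, ranks 0 and 1 both get indices 0, 2, 4, 6, while ranks 2 and 3 both get 1, 3, 5, 7. Ranks in the same group see identical data, and the two groups split the dataset between them.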