LavenderDataLoader - Context Parallelism
For context parallelism, set the replication_pg parameter.
It is a list of lists of integers that partitions the ranks into groups: each rank appears in exactly one inner list.
For example, [[0, 1], [2, 3]] places ranks 0 and 1 in the first group and ranks 2 and 3 in the second group.
All ranks within the same group receive the same samples. This is what context parallelism requires, because the ranks in a group each process a different slice of the same sequence.
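To make these semantics concrete, here is a small simulation. It is a sketch only, not LavenderDataLoader's actual sharding logic. The helper names validate_replication_pg and assign_samples are hypothetical, and so is the round-robin split of samples across groups. The sketch checks that replication_pg really is a partition of the ranks. It then shows that every rank in a group ends up with an identical sample list.

```python
def validate_replication_pg(replication_pg, world_size):
    """Hypothetical check: every rank 0..world_size-1 appears in exactly one group."""
    seen = [rank for group in replication_pg for rank in group]
    if sorted(seen) != list(range(world_size)):
        raise ValueError("replication_pg must partition ranks 0..world_size-1")


def assign_samples(replication_pg, samples):
    """Hypothetical illustration: deal samples round-robin across groups
    (an assumption), then give every rank in a group the same samples."""
    num_groups = len(replication_pg)
    per_rank = {}
    for group_idx, group in enumerate(replication_pg):
        # Every num_groups-th sample, starting at group_idx, goes to this group.
        group_samples = samples[group_idx::num_groups]
        for rank in group:
            # Ranks in the same group share an identical copy of the group's samples.
            per_rank[rank] = list(group_samples)
    return per_rank


if __name__ == "__main__":
    pg = [[0, 1], [2, 3]]
    validate_replication_pg(pg, world_size=4)
    for rank, rank_samples in sorted(assign_samples(pg, list(range(8))).items()):
        print(rank, rank_samples)
```

With [[0, 1], [2, 3]] and eight samples, ranks 0 and 1 both see [0, 2, 4, 6], while ranks 2 and 3 both see [1, 3, 5, 7].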
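import os

# Assumes LavenderDataLoader, dataset, and shardset are already imported/defined.
dataloader = LavenderDataLoader(
    dataset_id=dataset.id,
    shardsets=[shardset.id],
    rank=int(os.environ["RANK"]),  # environment variables are strings
    replication_pg=[[0, 1], [2, 3]],  # ranks 0,1 share samples; so do ranks 2,3
)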
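Writing replication_pg by hand gets tedious on larger jobs. The sketch below builds it from the world size and the context-parallel group size. The helper make_replication_pg is hypothetical and not part of LavenderDataLoader. It also assumes that consecutive ranks form a context-parallel group, as in the [[0, 1], [2, 3]] example. If your training framework lays out its process groups differently, build the lists to match that layout instead.

```python
import os


def make_replication_pg(world_size, group_size):
    """Hypothetical helper: group consecutive ranks into context-parallel groups."""
    if group_size <= 0 or world_size % group_size != 0:
        raise ValueError("world_size must be a multiple of a positive group_size")
    # e.g. world_size=4, group_size=2 -> [[0, 1], [2, 3]]
    return [
        list(range(start, start + group_size))
        for start in range(0, world_size, group_size)
    ]


if __name__ == "__main__":
    # WORLD_SIZE is read from the environment; it defaults to 4 here for illustration.
    world_size = int(os.environ.get("WORLD_SIZE", "4"))
    print(make_replication_pg(world_size, group_size=2))
```

The resulting list can then be passed as replication_pg.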