Skip to content

vllm.v1.worker.gpu.penalties

_penalties_kernel

_penalties_kernel(
    logits_ptr,
    logits_stride,
    repetition_penalty_ptr,
    frequency_penalty_ptr,
    presence_penalty_ptr,
    idx_mapping_ptr,
    prompt_bin_counts_ptr,
    prompt_bin_counts_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    vocab_size,
    BLOCK_SIZE: constexpr,
)
Source code in vllm/v1/worker/gpu/penalties.py
@triton.jit
def _penalties_kernel(
    logits_ptr,
    logits_stride,
    repetition_penalty_ptr,
    frequency_penalty_ptr,
    presence_penalty_ptr,
    idx_mapping_ptr,
    prompt_bin_counts_ptr,
    prompt_bin_counts_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
    freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
    pres_penalty = tl.load(presence_penalty_ptr + batch_idx)

    use_rep_penalty = rep_penalty != 1.0
    use_freq_penalty = freq_penalty != 0.0
    use_pres_penalty = pres_penalty != 0.0
    if not (use_rep_penalty or use_freq_penalty or use_pres_penalty):
        # No penalties to apply. Early return.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
    logits = logits.to(tl.float32)

    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
    output_bin_counts = tl.load(
        output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
        mask=mask,
    )

    # Apply repetition penalties.
    if use_rep_penalty:
        prompt_bin_counts = tl.load(
            prompt_bin_counts_ptr + req_state_idx * prompt_bin_counts_stride + block,
            mask=mask,
        )
        # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
        scale = tl.where((prompt_bin_counts + output_bin_counts) > 0, rep_penalty, 1.0)
        # If logits are positive, divide by penalty, otherwise multiply by penalty.
        scale = tl.where(logits > 0, 1.0 / scale, scale)
        logits *= scale

    # Apply frequency penalties.
    logits -= freq_penalty * output_bin_counts
    # Apply presence penalties.
    logits -= pres_penalty * (output_bin_counts > 0)
    # Store back to logits.
    tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)

apply_penalties

apply_penalties(
    logits: Tensor, sampling_metadata: SamplingMetadata
) -> None
Source code in vllm/v1/worker/gpu/penalties.py
def apply_penalties(logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> None:
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 8192
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    _penalties_kernel[(num_reqs, num_blocks)](
        logits,
        logits.stride(0),
        sampling_metadata.repetition_penalty,
        sampling_metadata.frequency_penalty,
        sampling_metadata.presence_penalty,
        sampling_metadata.idx_mapping,
        sampling_metadata.prompt_bin_counts,
        sampling_metadata.prompt_bin_counts.stride(0),
        sampling_metadata.output_bin_counts,
        sampling_metadata.output_bin_counts.stride(0),
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )