0602 - RoPE scaling
사과나무심기 · 2024. 6. 2. 21:41
Hello hello.
I keep saying I'm still doing company work... but well, honestly,
if the weekdays are spent frantically plowing through work tasks,
the weekends are for patching up my knowledge, bit by bit.
RoPE, the position embedding used in LLaMA:
a scheme that encodes relative position by rotating the Q and K vectors with cos and sin values.
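Why does rotating by cos/sin give a *relative* embedding? A quick sketch of the 2D case from the RoFormer paper (my summary, not from the code below): rotating q at position m and k at position n means their dot product depends only on the offset n - m:

\[
R(m\theta) = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix},
\qquad
\langle R(m\theta)\,q,\; R(n\theta)\,k \rangle = \langle q,\; R\big((n-m)\theta\big)\,k \rangle .
\]

Each pair of head dimensions gets its own frequency \(\theta_i = \text{base}^{-2i/\text{dim}}\), which is exactly the inv_freq buffer in the code below.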
The relevant code:
import torch
import torch.nn as nn

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = base^(-2i/dim): one rotation frequency per pair of dims
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
Source: transformers/src/transformers/models/llama/modeling_llama.py at main · huggingface/transformers (github.com)
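A minimal shape check for the module above (the sizes here are my own toy assumptions, not from the post):

import torch

head_dim = 64
rope = LlamaRotaryEmbedding(dim=head_dim)

batch, num_heads, seq_len = 2, 8, 16
x = torch.randn(batch, num_heads, seq_len, head_dim)  # [bs, heads, seq, head_dim]
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

cos, sin = rope(x, position_ids)
print(cos.shape, sin.shape)  # each: [batch, seq_len, head_dim] -> [2, 16, 64]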
The code that applies the rotary embedding to Q and K:
def rotate_half(x):
    # From the same file: rotates half the hidden dims, (x1, x2) -> (-x2, x1)
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
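Continuing the toy example from above, the two snippets plug together like this (unsqueeze_dim=1 lets cos/sin broadcast over the head axis):

q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
print(q_embed.shape)  # [2, 8, 16, 64], same shape as q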
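As for the "scaling" in the title: the scaling_factor argument above is what linear RoPE scaling hooks into. In the same modeling_llama.py (as of mid-2024), the linear-scaling variant is a small subclass that just divides the positions before computing the angles; roughly:

class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    # Sketch of the linear-scaling variant; check the file for the exact version.
    def forward(self, x, position_ids):
        # Position m behaves like m / scaling_factor, so a model trained on
        # max_position_embeddings tokens can address scaling_factor times more.
        position_ids = position_ids.float() / self.scaling_factor
        cos, sin = super().forward(x, position_ids)
        return cos, sin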