自注意力机制全解析:从原理到计算细节,一文尽览!

本文涉及的产品
NLP自然语言处理_高级版,每接口累计50万次
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
NLP自然语言处理_基础版,每接口每天50万次
简介: 自注意力机制(Self-Attention)最早可追溯至20世纪70年代的神经网络研究,但直到2017年Google Brain团队提出Transformer架构后才广泛应用于深度学习。它通过计算序列内部元素间的相关性,捕捉复杂依赖关系,并支持并行化训练,显著提升了处理长文本和序列数据的能力。相比传统的RNN、LSTM和GRU,自注意力机制在自然语言处理(NLP)、计算机视觉、语音识别及推荐系统等领域展现出卓越性能。其核心步骤包括生成查询(Q)、键(K)和值(V)向量,计算缩放点积注意力得分,应用Softmax归一化,以及加权求和生成输出。自注意力机制提高了模型的表达能力,带来了更精准的服务。

引言

自注意力机制的历史背景与发展

自注意力机制(Self-Attention)的概念最早可以追溯到20世纪70年代的神经网络研究,但直到近年来才在深度学习领域得到广泛关注和发展。现代意义上的自注意力机制首次出现在2017年的论文《Attention is All You Need》中,该论文由Google Brain团队提出,并引入了Transformer架构。这一创新迅速改变了自然语言处理(NLP)领域的格局。

在此之前,循环神经网络(RNN)及其变体长短期记忆网络(LSTM)和门控循环单元(GRU)是处理序列数据的主要方法。然而,这些模型存在一些固有的局限性,比如难以并行化训练、捕捉长距离依赖关系的能力有限等。此外,随着序列长度增加,RNN类模型的表现往往会下降。

为了解决这些问题,研究人员开始探索基于注意力机制的方法,它最初是为了改善编码器-解码器框架下的机器翻译任务而设计的。传统的注意力机制允许模型在生成输出时集中于输入序列中的某些特定部分,从而提高了性能。但是,这种外部注意力机制仍然依赖于编码器提供的上下文信息。

相比之下,自注意力机制不依赖于任何外部信息源,而是直接关注输入序列内部元素之间的相互作用。这不仅使得模型能够更有效地捕捉序列内部复杂的依赖关系,还极大地促进了模型的并行化训练,因为每个位置上的计算都可以独立进行。因此,自注意力机制成为构建高效且强大的序列建模工具的关键组件之一。

自注意力机制的概念

自注意力机制(Self-Attention),也称为内部注意力机制,是一种将单个序列的不同位置关联起来以计算同一序列的表示的注意力机制。这种机制允许模型在处理序列数据时,动态地调整对每个元素的关注程度,从而捕捉序列内部的复杂依赖关系。

自注意力机制的核心在于,它不依赖于外部信息,而是在序列内部元素之间进行信息的交互和整合。这意味着,对于序列中的每个元素,自注意力机制会计算该元素与序列中所有其他元素的相关性,生成一个加权的表示,其中权重反映了元素间的相互关系。

image.png

自注意力机制的计算过程可以被分解为几个关键步骤。

  1. 输入序列被映射到查询(Query)、键(Key)和值(Value)三个向量。

  2. 通过计算查询向量与所有键向量之间的点积来获得注意力得分。这些得分随后被缩放并经过Softmax函数进行归一化,以获得每个元素的注意力权重。

  3. 这些权重被用来对值向量进行加权求和,生成最终的输出序列。

自注意力机制在现代深度学习模型中的重要性

自注意力机制的重要性在于其灵活性和强大表达能力,特别是在处理长文本或其他类型的序列数据方面表现尤为突出。以下是几个关键点:

  • Transformer架构的核心:自从Transformer被提出以来,它已经在多个NLP基准测试中取得了顶尖的成绩,并成为了当前最先进的预训练语言模型的基础,如BERT、GPT系列等。这些模型都依赖于多层堆叠的自注意力机制来实现卓越的效果。

  • 并行化优势:与传统RNN不同的是,自注意力机制允许对整个序列进行并行处理,而不是按顺序逐个处理时间步。这一特性大大加快了训练速度,尤其是在大规模语料库上训练时显得尤为重要。

  • 捕捉全局依赖性:通过让每个元素都能“看到”整个序列中的所有其他元素,自注意力机制能够在单一层内建立起非常广泛且深入的上下文联系。这对于理解复杂句子结构或文档级别的语义关系至关重要。

  • 跨领域应用:除了NLP之外,自注意力机制也被成功应用于计算机视觉、语音识别等多个领域。例如,在图像分类任务中,它可以用来捕捉图片内的空间依赖关系;而在视频分析中,则有助于理解时间维度上的动态变化。

Q、K、V的生成

在自注意力(Self-Attention)机制中,查询(Query,简称Q)、键(Key,简称K)和值(Value,简称V)是三个核心的概念,它们共同参与计算以生成序列的加权表示。

查询(Query,Q)

查询向量Q代表了当前元素在序列中的作用,它用于“询问”序列中的其他元素以获取相关信息。在自注意力机制中,每个元素都会生成一个对应的查询向量,该向量用于与序列中的所有键向量进行比较,以确定每个元素的重要性或相关性。

键(Key,K)

键向量K包含了序列中每个元素的特征信息,这些信息将用于与查询向量进行匹配。键向量的主要作用是提供一种机制,使得模型能够识别和比较序列中不同元素之间的关系。在自注意力中,每个元素都会有一个对应的键向量,它与查询向量一起决定了注意力分数。

值(Value,V)

值向量V包含了序列中每个元素的实际信息或特征,这些信息将根据注意力分数被加权求和,以生成最终的输出。值向量代表了序列中每个元素的具体内容,它们是模型最终用于生成输出的原始数据。

image.png

在自注意力机制中,输入序列的每个元素首先被映射到三个向量:查询(Q)、键(K)和值(V)。这一过程通常通过与三个权重矩阵的线性变换实现。具体来说,输入序列X与权重矩阵W^Q、W^K和W^V相乘,得到Q、K和V:

image.png

其中,X是输入序列,W^Q、W^K和W^V是可学习的权重矩阵。这些矩阵的维度通常是(序列长度,特征维度)乘以(特征维度,Q/K/V维度)。Q、K和V的维度是(序列长度,Q/K/V维度)。

这三个变换可以看作是对原始输入的一种重新编码,目的是从不同角度提取信息,以便后续计算注意力分数时能够更有效地捕捉到元素间的相关性。

缩放点积计算注意力得分

自注意力机制中,查询向量Q与所有键向量K之间的点积被用来计算注意力得分。为了避免点积结果过大导致梯度问题,引入了一个缩放因子1/√dk,其中dk是键向量的维度。缩放后的注意力得分计算如下:

image.png
image.png

这个操作生成了一个注意力得分矩阵,其中每个元素代表对应元素对之间的相似度。

Softmax 归一化

为了将注意力得分转换为权重,应用Softmax函数进行归一化。Softmax确保所有输出权重的和为1,从而使得模型可以学习到每个元素对的重要性:

image.png

Softmax函数定义为:

image.png
image.png

其中,xi是注意力得分矩阵中的元素。这意味着对于每个位置 i,在计算完 softmax 后,我们会获得该位置与其他所有位置的相关性的“概率”。

权求和生成输出

最后,归一化的注意力权重被用来对值向量V进行加权求和,生成最终的输出序列。输出序列的每个元素是所有值向量的一个加权和,权重由对应的注意力权重决定:

image.png

这一步骤有效地整合了序列内部的信息,使得每个元素的输出表示包含了整个序列的上下文信息。

完整的计算流程

f3a43e21-ac19-4ce8-9085-9b03de505c51.gif

接下来,会一步步拆解,让你更清楚地理解这个过程。

首先,我们看输入部分:

image.png

然后,我们会用初始化的权重来计算key、value和query:

image.png

接下来,我们会为第一个输入计算attention score:

image.png

然后,对attention score执行softmax操作:

image.png

这一步之后,我们会将每个输入的softmaxed score和对应的value相乘,得到3个weighted value:

image.png

最后,把上一步的weighted value加起来(对应元素相加),就得到了输出。同样的步骤,我们也会对input#2和input#3执行,得到另外两个输出。

image.png

最后,这段代码展示了如何定义一个SelfAttention类,并在其中实现自注意力机制的核心步骤。

import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # Ensure the embedding size is divisible by the number of heads
        assert (self.head_dim * heads == embed_size), "Embedding size must be divisible by heads"

        # Define linear transformations for queries, keys, and values
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        N = x.shape[0]  # Batch size
        length = x.shape[1]  # Sequence length

        # Pass inputs through linear layers for queries, keys, and values
        values = self.values(x)
        keys = self.keys(x)
        queries = self.queries(x)

        # Split into multiple heads
        values = values.view(N, length, self.heads, self.head_dim)
        keys = keys.view(N, length, self.heads, self.head_dim)
        queries = queries.view(N, length, self.heads, self.head_dim)

        # Permute dimensions to put heads second
        values = values.permute(0, 2, 1, 3)  # (N, heads, length, head_dim)
        keys = keys.permute(0, 2, 1, 3)
        queries = queries.permute(0, 2, 1, 3)

        # Compute scaled dot-product attention
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])  # (N, heads, query_len, key_len)
        attention = F.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)  # Softmax on the last dimension

        # Weighted sum
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, length, self.heads * self.head_dim
        )  # (N, length, embed_size)

        return self.fc_out(out)

# Example usage
embed_size = 256  # Embedding dimension
heads = 8  # Number of attention heads
x = torch.rand(64, 10, embed_size)  # Simulate input data

self_attention = SelfAttention(embed_size, heads)
output = self_attention(x)

print(output.shape)  # Should print: torch.Size([64, 10, 256])

自注意力机制的扩展应用场景

自然语言处理(NLP)中的机器翻译

案例描述:使用Transformer架构进行英德双语翻译。Transformer通过多层堆叠的自注意力机制,能够捕捉源语言句子中各个单词之间的复杂关系,从而生成更加准确的目标语言句子。

import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel

# 初始化BERT模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
model = BertModel.from_pretrained('bert-base-multilingual-cased')

def translate_sentence(sentence, model, tokenizer):
    # 对输入句子进行编码
    inputs = tokenizer(sentence, return_tensors="pt")

    # 获取BERT模型的输出
    outputs = model(**inputs)
    last_hidden_states = outputs.last_hidden_state

    # 这里可以加入额外的解码逻辑,例如使用一个预训练好的翻译模型
    # 简单起见,我们直接返回最后一层隐藏状态作为示例
    return last_hidden_states

# 示例翻译
sentence = "Hello, how are you?"
translated = translate_sentence(sentence, model, tokenizer)
print(translated.shape)  # 输出形状应为 (batch_size, sequence_length, hidden_dim)

推荐系统中的个性化推荐

在推荐系统中,个性化推荐的目标是为每个用户提供最符合其兴趣和需求的内容或商品。为了实现这一点,模型需要能够捕捉用户行为模式,并根据这些模式进行预测。自注意力机制(Self-Attention)因其强大的建模能力,在这一领域得到了广泛应用。下面我们将进一步细化基于自注意力机制的个性化推荐系统的构建过程,并提供更加详细的代码示例。

模型选择:SASRec (Self-Attentive Sequential Recommendation)

SASRec 是一种专门为序列化推荐设计的深度学习模型,它利用了自注意力机制来捕捉用户历史行为之间的复杂依赖关系。该模型可以有效地处理长序列数据,并且能够在不牺牲计算效率的情况下提高推荐质量。

数据准备

在构建个性化推荐系统之前,我们需要准备好用户的历史行为数据。通常情况下,这些数据包括用户ID、物品ID以及交互时间戳等信息。对于SASRec模型来说,还需要对物品进行编码,以便输入到模型中。

import pandas as pd
from sklearn.preprocessing import LabelEncoder

# 假设我们有一个包含用户行为的数据框 df
df = pd.read_csv('user_behavior.csv')

# 对用户和物品ID进行编码
user_encoder = LabelEncoder()
item_encoder = LabelEncoder()

df['user_id'] = user_encoder.fit_transform(df['original_user_id'])
df['item_id'] = item_encoder.fit_transform(df['original_item_id'])

# 将数据按用户分组,并按照时间排序
user_sequences = df.groupby('user_id')['item_id'].apply(list).reset_index(name='sequence')

构建SASRec模型

接下来,我们将定义SASRec模型结构。这里使用 TensorFlow 和 Keras API 来实现。需要注意的是,实际应用中可能需要安装额外的库,如 tfrs 或者其他特定于推荐系统的框架。

import tensorflow as tf
from tensorflow.keras.layers import Embedding, LayerNormalization, Dense
from tensorflow.keras.models import Model
import numpy as np

class SASRec(Model):
    def __init__(self, num_users, num_items, embedding_dim, max_len, num_heads=1, num_blocks=2):
        super(SASRec, self).__init__()
        self.user_embedding = Embedding(input_dim=num_users + 1, output_dim=embedding_dim)
        self.item_embedding = Embedding(input_dim=num_items + 1, output_dim=embedding_dim)
        self.positional_encoding = self.get_positional_encoding(max_len, embedding_dim)
        self.attention_layers = [tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embedding_dim) for _ in range(num_blocks)]
        self.layer_norms = [LayerNormalization() for _ in range(num_blocks)]
        self.dense = Dense(num_items, activation='softmax')

    def get_positional_encoding(self, seq_len, d_model):
        # 创建位置编码矩阵
        position = np.arange(seq_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        pe = np.zeros((seq_len, d_model))
        pe[:, 0::2] = np.sin(position * div_term)
        pe[:, 1::2] = np.cos(position * div_term)
        return pe[np.newaxis, ...]

    def call(self, inputs):
        user_ids, sequences = inputs
        seq_emb = self.item_embedding(sequences) + self.positional_encoding[:, :tf.shape(sequences)[1], :]
        mask = tf.cast(tf.not_equal(sequences, 0), dtype=tf.float32)

        for i in range(len(self.attention_layers)):
            seq_emb = self.layer_norms[i](seq_emb + self.attention_layers[i](seq_emb, seq_emb, attention_mask=mask))

        logits = self.dense(seq_emb)
        return logits

# 初始化SASRec模型
num_users = len(user_encoder.classes_)
num_items = len(item_encoder.classes_)
embedding_dim = 128
max_len = 50  # 用户行为序列的最大长度

model = SASRec(num_users, num_items, embedding_dim, max_len)

训练模型

为了训练模型,我们需要定义损失函数和优化器,并编写一个适合推荐任务的数据生成器。这里简化了训练过程,实际应用中应该根据具体情况进行调整。

from tensorflow.keras.optimizers import Adam

# 编写简单的负采样函数以生成负样本
def negative_sampling(pos_seq, num_items, num_neg_samples=5):
    neg_samples = []
    for pos_item in pos_seq:
        neg_items = np.random.choice(num_items, size=num_neg_samples, replace=False)
        neg_samples.append(neg_items)
    return np.array(neg_samples)

# 定义损失函数(交叉熵)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# 选择优化器
optimizer = Adam(learning_rate=0.001)

# 训练模型
for epoch in range(epochs):
    for batch in data_generator():
        with tf.GradientTape() as tape:
            user_ids, sequences, labels = batch
            predictions = model([user_ids, sequences])
            loss = loss_fn(labels, predictions)

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    print(f'Epoch {epoch + 1}/{epochs}, Loss: {loss.numpy()}')

推荐预测

一旦模型训练完成,就可以用来为新用户提供推荐服务。通过给定用户的最新行为序列,模型可以预测下一个最有可能点击的商品。

def recommend_items(model, user_id, user_history, top_k=10):
    # 对用户行为序列进行填充以匹配最大长度
    sequence = np.pad(user_history[-max_len:], (max_len - len(user_history[-max_len:]), 0), 'constant')

    # 获取预测结果
    predictions = model.predict([[user_id], [sequence]])
    sorted_indices = np.argsort(predictions[0])[-top_k:]

    # 返回解码后的商品ID
    recommended_items = item_encoder.inverse_transform(sorted_indices)
    return recommended_items.tolist()

# 示例推荐
user_id = 1234
user_history = [234, 567, 890]  # 用户的历史行为记录
recommendations = recommend_items(model, user_id, user_history)
print(recommendations)  # 输出推荐商品ID

这种类型的模型特别适用于那些需要理解用户长期偏好变化的应用场景,例如电子商务平台上的商品推荐、社交媒体中的内容推荐等。

总结

通过上述详细探讨,我们不仅深入了解了自注意力机制的基本原理及其在自然语言处理、计算机视觉、语音处理以及推荐系统等多个领域的广泛应用,还结合具体案例和代码示例展示了如何将这一强大工具付诸实践。自注意力机制凭借其捕捉复杂依赖关系的能力、并行化处理的优势以及对长序列数据的有效建模,已经成为现代深度学习不可或缺的一部分。


参考文献:

Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. In Advances in neural information processing systems (pp. 5998-6008).

Devlin, J., Chang, M.-W., Lee, K., & Toutanova, K. (2018). BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805.

Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., ... & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929.

Wu, Y., Al-Shedivat, M., Ghasemipour, S., Wang, Y., Abbeel, P., Schwing, A. G., & Finn, C. (2020). Pay less attention with lightweight and dynamic convolutions. International Conference on Learning Representations.

相关文章
|
1月前
|
存储 缓存 算法
HashMap深度解析:从原理到实战
HashMap,作为Java集合框架中的一个核心组件,以其高效的键值对存储和检索机制,在软件开发中扮演着举足轻重的角色。作为一名资深的AI工程师,深入理解HashMap的原理、历史、业务场景以及实战应用,对于提升数据处理和算法实现的效率至关重要。本文将通过手绘结构图、流程图,结合Java代码示例,全方位解析HashMap,帮助读者从理论到实践全面掌握这一关键技术。
99 14
|
2月前
|
运维 持续交付 云计算
深入解析云计算中的微服务架构:原理、优势与实践
深入解析云计算中的微服务架构:原理、优势与实践
106 3
|
24天前
|
存储 物联网 大数据
探索便宜云服务器 Flink 物化表:原理、优势与应用场景全解析
便宜云服务器Flink的物化表是流批一体化平台中的关键特性,支持低延迟实时更新、灵活查询性能、无缝流批处理和高容错性。它广泛应用于电商、物联网和金融等领域,助力企业高效处理实时数据,提升业务决策能力。实践案例表明,物化表显著提高了交易欺诈损失率的控制和信贷审批效率,推动企业在数字化转型中取得竞争优势。
97 16
|
1月前
|
网络协议 安全 网络安全
探索网络模型与协议:从OSI到HTTPs的原理解析
OSI七层网络模型和TCP/IP四层模型是理解和设计计算机网络的框架。OSI模型包括物理层、数据链路层、网络层、传输层、会话层、表示层和应用层,而TCP/IP模型则简化为链路层、网络层、传输层和 HTTPS协议基于HTTP并通过TLS/SSL加密数据,确保安全传输。其连接过程涉及TCP三次握手、SSL证书验证、对称密钥交换等步骤,以保障通信的安全性和完整性。数字信封技术使用非对称加密和数字证书确保数据的机密性和身份认证。 浏览器通过Https访问网站的过程包括输入网址、DNS解析、建立TCP连接、发送HTTPS请求、接收响应、验证证书和解析网页内容等步骤,确保用户与服务器之间的安全通信。
117 3
|
2月前
|
存储 缓存 监控
后端开发中的缓存机制:深度解析与最佳实践####
本文深入探讨了后端开发中不可或缺的一环——缓存机制,旨在为读者提供一份详尽的指南,涵盖缓存的基本原理、常见类型(如内存缓存、磁盘缓存、分布式缓存等)、主流技术选型(Redis、Memcached、Ehcache等),以及在实际项目中如何根据业务需求设计并实施高效的缓存策略。不同于常规摘要的概述性质,本摘要直接点明文章将围绕“深度解析”与“最佳实践”两大核心展开,既适合初学者构建基础认知框架,也为有经验的开发者提供优化建议与实战技巧。 ####
|
1月前
|
PHP 开发者 UED
PHP中的异常处理机制解析####
本文深入探讨了PHP中的异常处理机制,通过实例解析try-catch语句的用法,并对比传统错误处理方式,揭示其在提升代码健壮性与可维护性方面的优势。文章还简要介绍了自定义异常类的创建及其应用场景,为开发者提供实用的技术参考。 ####
|
2月前
|
缓存 NoSQL Java
千万级电商线上无阻塞双buffer缓冲优化ID生成机制深度解析
【11月更文挑战第30天】在千万级电商系统中,ID生成机制是核心基础设施之一。一个高效、可靠的ID生成系统对于保障系统的稳定性和性能至关重要。本文将深入探讨一种在千万级电商线上广泛应用的ID生成机制——无阻塞双buffer缓冲优化方案。本文从概述、功能点、背景、业务点、底层原理等多个维度进行解析,并通过Java语言实现多个示例,指出各自实践的优缺点。希望给需要的同学提供一些参考。
61 8
|
1月前
|
Java 数据库连接 开发者
Java中的异常处理机制:深入解析与最佳实践####
本文旨在为Java开发者提供一份关于异常处理机制的全面指南,从基础概念到高级技巧,涵盖try-catch结构、自定义异常、异常链分析以及最佳实践策略。不同于传统的摘要概述,本文将以一个实际项目案例为线索,逐步揭示如何高效地管理运行时错误,提升代码的健壮性和可维护性。通过对比常见误区与优化方案,读者将获得编写更加健壮Java应用程序的实用知识。 --- ####
|
2月前
|
存储 供应链 算法
深入解析区块链技术的核心原理与应用前景
深入解析区块链技术的核心原理与应用前景
84 0
|
2月前
|
监控 Java 应用服务中间件
高级java面试---spring.factories文件的解析源码API机制
【11月更文挑战第20天】Spring Boot是一个用于快速构建基于Spring框架的应用程序的开源框架。它通过自动配置、起步依赖和内嵌服务器等特性,极大地简化了Spring应用的开发和部署过程。本文将深入探讨Spring Boot的背景历史、业务场景、功能点以及底层原理,并通过Java代码手写模拟Spring Boot的启动过程,特别是spring.factories文件的解析源码API机制。
117 2

热门文章

最新文章

推荐镜像

更多
http://www.vxiaotou.com