吴忠躺衫网络科技有限公司

電子發(fā)燒友App

硬聲App

0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫(xiě)文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示
創(chuàng)作
電子發(fā)燒友網(wǎng)>電子資料下載>電子資料>PyTorch教程15.4之預(yù)訓(xùn)練word2vec

PyTorch教程15.4之預(yù)訓(xùn)練word2vec

2023-06-05 | pdf | 0.14 MB | 次下載 | 免費(fèi)

資料介紹

我們繼續(xù)實(shí)現(xiàn) 15.1 節(jié)中定義的 skip-gram 模型。然后我們將在 PTB 數(shù)據(jù)集上使用負(fù)采樣來(lái)預(yù)訓(xùn)練 word2vec。首先,讓我們通過(guò)調(diào)用函數(shù)來(lái)獲取數(shù)據(jù)迭代器和這個(gè)數(shù)據(jù)集的詞匯表 ,這在第 15.3 節(jié)d2l.load_data_ptb中有描述

import math
import torch
from torch import nn
from d2l import torch as d2l

batch_size, max_window_size, num_noise_words = 512, 5, 5
data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
                   num_noise_words)
Downloading ../data/ptb.zip from http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip...
import math
from mxnet import autograd, gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

batch_size, max_window_size, num_noise_words = 512, 5, 5
data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
                   num_noise_words)

15.4.1。Skip-Gram 模型

我們通過(guò)使用嵌入層和批量矩陣乘法來(lái)實(shí)現(xiàn) skip-gram 模型。首先,讓我們回顧一下嵌入層是如何工作的。

15.4.1.1。嵌入層

如第 10.7 節(jié)所述,嵌入層將標(biāo)記的索引映射到其特征向量。該層的權(quán)重是一個(gè)矩陣,其行數(shù)等于字典大小 ( input_dim),列數(shù)等于每個(gè)標(biāo)記的向量維數(shù) ( output_dim)。一個(gè)詞嵌入模型訓(xùn)練好之后,這個(gè)權(quán)重就是我們所需要的。

embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
print(f'Parameter embedding_weight ({embed.weight.shape}, '
   f'dtype={embed.weight.dtype})')
Parameter embedding_weight (torch.Size([20, 4]), dtype=torch.float32)
embed = nn.Embedding(input_dim=20, output_dim=4)
embed.initialize()
embed.weight
Parameter embedding0_weight (shape=(20, 4), dtype=float32)

嵌入層的輸入是標(biāo)記(單詞)的索引。對(duì)于任何令牌索引i,它的向量表示可以從ith嵌入層中權(quán)重矩陣的行。由于向量維度 ( output_dim) 設(shè)置為 4,因此嵌入層返回形狀為 (2, 3, 4) 的向量,用于形狀為 (2, 3) 的標(biāo)記索引的小批量。

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
embed(x)
tensor([[[-0.6501, 1.3547, 0.7968, 0.3916],
     [ 0.4739, -0.0944, 1.2308, 0.6457],
     [ 0.4539, 1.5194, 0.4377, -1.5122]],

    [[-0.7032, -0.1213, 0.2657, -0.6797],
     [ 0.2930, -0.6564, 0.8960, -0.5637],
     [-0.1815, 0.9487, 0.8482, 0.5486]]], grad_fn=<EmbeddingBackward0>)
x = np.array([[1, 2, 3], [4, 5, 6]])
embed(x)
array([[[ 0.01438687, 0.05011239, 0.00628365, 0.04861524],
    [-0.01068833, 0.01729892, 0.02042518, -0.01618656],
    [-0.00873779, -0.02834515, 0.05484822, -0.06206018]],

    [[ 0.06491279, -0.03182812, -0.01631819, -0.00312688],
    [ 0.0408415 , 0.04370362, 0.00404529, -0.0028032 ],
    [ 0.00952624, -0.01501013, 0.05958354, 0.04705103]]])

15.4.1.2。定義前向傳播

在正向傳播中,skip-gram 模型的輸入包括形狀為(批大小,1)的中心詞索引和 形狀為(批大小,center的連接上下文和噪聲詞索引,其中定義在 第 15.3.5 節(jié). 這兩個(gè)變量首先通過(guò)嵌入層從標(biāo)記索引轉(zhuǎn)換為向量,然后它們的批量矩陣乘法(在第 11.3.2.2 節(jié)中描述)返回形狀為(批量大小,1, )的輸出 。輸出中的每個(gè)元素都是中心詞向量與上下文或噪聲詞向量的點(diǎn)積。contexts_and_negativesmax_lenmax_lenmax_len

def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
  v = embed_v(center)
  u = embed_u(contexts_and_negatives)
  pred = torch.bmm(v, u.permute(0, 2, 1))
  return pred
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
  v = embed_v(center)
  u = embed_u(contexts_and_negatives)
  pred = npx.batch_dot(v, u.swapaxes(1, 2))
  return pred

skip_gram讓我們?yōu)橐恍┦纠斎?/font>打印此函數(shù)的輸出形狀。

skip_gram(torch.ones((2, 1), dtype=torch.long),
     torch.ones((2, 4), dtype=torch.long), embed, embed).shape
torch.Size([2, 1, 4])
skip_gram(np.ones((2, 1)), np.ones((2, 4)), embed, embed).shape
(2, 1, 4)

15.4.2。訓(xùn)練

在用負(fù)采樣訓(xùn)練skip-gram模型之前,我們先定義它的損失函數(shù)。

15.4.2.1。二元交叉熵?fù)p失

根據(jù)15.2.1節(jié)負(fù)采樣損失函數(shù)的定義,我們將使用二元交叉熵?fù)p失。

class SigmoidBCELoss(nn.Module):
  # Binary cross-entropy loss with masking
  def __init__(self):
    super().__init__()

  def forward(self, inputs, target, mask=None):
    out = nn.functional.binary_cross_entropy_with_logits(
      inputs, target, weight=mask, reduction="none")
    return out.mean(dim=1)

loss = SigmoidBCELoss()
loss = gluon.loss.SigmoidBCELoss()

回想我們?cè)诘?15.3.5 節(jié)中對(duì)掩碼變量和標(biāo)簽變量的描述 下面計(jì)算給定變量的二元交叉熵?fù)p失。

pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)
tensor([0.9352, 1.8462])
pred = np.array([[1.1, -2.2, 3.3, -4.4]] * 2)
label = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = np.array([[1,
下載該資料的人也在下載 下載該資料的人還在閱讀
更多 >

評(píng)論

查看更多

下載排行

本周

  1. 1A7159和A7139射頻芯片的資料免費(fèi)下載
  2. 0.20 MB   |  55次下載  |  5 積分
  3. 2PIC12F629/675 數(shù)據(jù)手冊(cè)免費(fèi)下載
  4. 2.38 MB   |  36次下載  |  5 積分
  5. 3PIC16F716 數(shù)據(jù)手冊(cè)免費(fèi)下載
  6. 2.35 MB   |  18次下載  |  5 積分
  7. 4dsPIC33EDV64MC205電機(jī)控制開(kāi)發(fā)板用戶(hù)指南
  8. 5.78MB   |  8次下載  |  免費(fèi)
  9. 5STC15系列常用寄存器匯總免費(fèi)下載
  10. 1.60 MB   |  7次下載  |  5 積分
  11. 6模擬電路仿真實(shí)現(xiàn)
  12. 2.94MB   |  4次下載  |  免費(fèi)
  13. 7PCB圖繪制實(shí)例操作
  14. 2.92MB   |  2次下載  |  免費(fèi)
  15. 8零死角玩轉(zhuǎn)STM32F103—指南者
  16. 26.78 MB   |  1次下載  |  1 積分

本月

  1. 1ADI高性能電源管理解決方案
  2. 2.43 MB   |  452次下載  |  免費(fèi)
  3. 2免費(fèi)開(kāi)源CC3D飛控資料(電路圖&PCB源文件、BOM、
  4. 5.67 MB   |  141次下載  |  1 積分
  5. 3基于STM32單片機(jī)智能手環(huán)心率計(jì)步器體溫顯示設(shè)計(jì)
  6. 0.10 MB   |  137次下載  |  免費(fèi)
  7. 4A7159和A7139射頻芯片的資料免費(fèi)下載
  8. 0.20 MB   |  55次下載  |  5 積分
  9. 5PIC12F629/675 數(shù)據(jù)手冊(cè)免費(fèi)下載
  10. 2.38 MB   |  36次下載  |  5 積分
  11. 6如何正確測(cè)試電源的紋波
  12. 0.36 MB   |  19次下載  |  免費(fèi)
  13. 7PIC16F716 數(shù)據(jù)手冊(cè)免費(fèi)下載
  14. 2.35 MB   |  18次下載  |  5 積分
  15. 8Q/SQR E8-4-2024乘用車(chē)電子電器零部件及子系統(tǒng)EMC試驗(yàn)方法及要求
  16. 1.97 MB   |  8次下載  |  10 積分

總榜

  1. 1matlab軟件下載入口
  2. 未知  |  935121次下載  |  10 積分
  3. 2開(kāi)源硬件-PMP21529.1-4 開(kāi)關(guān)降壓/升壓雙向直流/直流轉(zhuǎn)換器 PCB layout 設(shè)計(jì)
  4. 1.48MB  |  420062次下載  |  10 積分
  5. 3Altium DXP2002下載入口
  6. 未知  |  233088次下載  |  10 積分
  7. 4電路仿真軟件multisim 10.0免費(fèi)下載
  8. 340992  |  191367次下載  |  10 積分
  9. 5十天學(xué)會(huì)AVR單片機(jī)與C語(yǔ)言視頻教程 下載
  10. 158M  |  183335次下載  |  10 積分
  11. 6labview8.5下載
  12. 未知  |  81581次下載  |  10 積分
  13. 7Keil工具M(jìn)DK-Arm免費(fèi)下載
  14. 0.02 MB  |  73810次下載  |  10 積分
  15. 8LabVIEW 8.6下載
  16. 未知  |  65988次下載  |  10 積分
博九网百家乐官网游戏| 免费百家乐分析工具| 澳门百家乐官网赢钱秘| 大发888网| 百家乐策略| 百家乐官网视频聊天软件| 大发888体育在线| 百家乐斗地主在哪玩| 太阳城百家乐官网作弊| 九乐棋牌官网| 真人百家乐蓝盾娱乐网| 戒掉百家乐官网的玩法技巧和规则| 戰神国际娱乐城| 百家乐平注法到656| 开店做生意的风水| 天博百家乐官网娱乐城| 澳门美高梅娱乐| 乐天堂百家乐赌场娱乐网规则| 豪门百家乐官网的玩法技巧和规则| 百家乐官网路子分析| 大发888注册送58下载| 金木棉百家乐官网的玩法技巧和规则 | 真人百家乐官网蓝盾娱乐平台 | 百家乐官网冯氏坐庄法| 大发888娱乐城出纳柜台| 百家乐平技巧| 鼎尚百家乐官网的玩法技巧和规则 | 威廉希尔| 新时代百家乐的玩法技巧和规则 | 百家乐博彩金| 广州百家乐官网赌场娱乐网规则| 太阳城百家乐官网投注| 17pk棋牌官方下载| 威尼斯人娱乐城投注| 百家乐六合彩| 百家乐官网筹码14克粘土| 海威百家乐官网赌博机| 普安县| 钱柜娱乐城现金网| 大发888怎么刷钱| 百家乐23珠路打法|