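Contents
Skip-gram architecture
Open-source code notice
Reproducing Skip-gram in PyTorch
Imports and random seed
Reading the Wikipedia data
Building a (word, frequency) tuple list sorted by frequency
Building the word-frequency, word_id and id_word dictionaries
Subsampling
Positive and negative sampling
The SkipGram model class
Model training
Outputting word vectors
Finding similar words
Training Skip-gram with fastText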
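Skip-gram architecture
In the theoretical formulation of the original paper, two parameter matrices are trained. In Word2vec they can be viewed as a dimension-reducing matrix and a dimension-raising matrix for word vectors. Each token is first represented by a one-hot vector. As the figure shows, left-multiplying the 5×1 one-hot vector by a 3×5 parameter matrix yields a 3×1 reduced word vector; a 5×3 parameter matrix then projects this vector back up to 5 dimensions, and the loss is computed against the token to be predicted.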
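To make the shapes concrete, here is a minimal NumPy sketch of this theoretical forward pass. It assumes a full softmax over the vocabulary as the output layer (the basic formulation of the original paper); the matrix names W_in and W_out, the random initialization and the choice of context token are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 5, 3

W_in = rng.normal(size=(embed_dim, vocab_size))    # 3x5 dimension-reducing matrix
W_out = rng.normal(size=(vocab_size, embed_dim))   # 5x3 dimension-raising matrix

x = np.zeros((vocab_size, 1))
x[0, 0] = 1.0                    # 5x1 one-hot vector of the center token (id 0)
h = W_in @ x                     # 3x1 reduced word vector (= column 0 of W_in)
scores = W_out @ h               # 5x1 raised vector: one score per vocabulary token

# Softmax over the vocabulary, then cross-entropy against one context token
exp_scores = np.exp(scores - scores.max())
probs = exp_scores / exp_scores.sum()
context_id = 2                   # hypothetical context token id
loss = -np.log(probs[context_id, 0])
print(h.shape, scores.shape, round(float(loss), 4))
```

The softmax touches every vocabulary entry, which is exactly the cost that the practical implementation described below avoids.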
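In practice the one-hot encoding is implicit: no one-hot vectors are actually built to index the word vectors. For example, with a corpus of 5 distinct tokens, each token's one-hot vector has 5 dimensions, and the word-vector matrix maps it to a 3-dimensional word vector. The following demonstrates how a one-hot vector indexes the corresponding token's word vector in the word-vector matrix: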
```python
import numpy as np

np.random.seed(0)
y = np.array([1, 0, 0, 0, 0])   # one-hot vector of the first token
x = np.random.randn(5, 3)       # word-vector matrix: 5 tokens x 3 dimensions
print(y)
print(x)
print(np.dot(y, x))             # selects row 0 of x
# [1 0 0 0 0]
# [[ 1.76405235 0.40015721 0.97873798]
# [ 2.2408932 1.86755799 -0.97727788]
# [ 0.95008842 -0.15135721 -0.10321885]
# [ 0.4105985 0.14404357 1.45427351]
# [ 0.76103773 0.12167502 0.44386323]]
# [1.76405235 0.40015721 0.97873798]
```
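The one-hot vector is [1 0 0 0 0]; left-multiplying the word-vector matrix by it selects the first row of the matrix, i.e. the word vector of the first token.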
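Word2vec comes in two variants, CBOW and Skip-gram. Skip-gram predicts the tokens on both sides from the center token, while CBOW predicts the center token from the tokens on both sides. In the theoretical formulation, Skip-gram takes the reduced word vector of the center token, projects it back up with the dimension-raising matrix, and computes the loss against the original one-hot vectors of the tokens on both sides.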
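A short sketch of how the two variants turn the same token sequence into training examples, assuming a fixed window of `window` tokens on each side. The functions `skipgram_pairs` and `cbow_examples` are hypothetical helpers, not part of the original code.

```python
def skipgram_pairs(tokens, window=2):
    """(center, context) pairs: the center token predicts each neighbour."""
    pairs = []
    for i, center in enumerate(tokens):
        for j in range(max(0, i - window), min(len(tokens), i + window + 1)):
            if j != i:
                pairs.append((center, tokens[j]))
    return pairs


def cbow_examples(tokens, window=2):
    """(context list, center) examples: the neighbours jointly predict the center."""
    examples = []
    for i, center in enumerate(tokens):
        context = [tokens[j]
                   for j in range(max(0, i - window), min(len(tokens), i + window + 1))
                   if j != i]
        examples.append((context, center))
    return examples


tokens = ["anarchism", "originated", "as", "a", "term"]
print(skipgram_pairs(tokens, window=1)[:3])
# [('anarchism', 'originated'), ('originated', 'anarchism'), ('originated', 'as')]
print(cbow_examples(tokens, window=1)[1])
# (['anarchism', 'as'], 'originated')
```

Skip-gram yields one example per (center, neighbour) pair, so it produces more, smaller training examples than CBOW from the same text.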
This post reproduces Skip-gram in PyTorch.
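The practical implementation differs from the theoretical one in several ways. First, it does not explicitly build one-hot vectors to index word vectors; instead, an embedding layer initializes and trains the word-vector matrix directly, looking up rows by token id.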
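A quick check that an embedding lookup is exactly the one-hot multiplication shown earlier: `F.one_hot` builds explicit one-hot rows, and multiplying them by the weight matrix of `nn.Embedding` gives the same rows as indexing the layer by id. The toy sizes 5 and 3 mirror the NumPy example above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, embed_dim = 5, 3
emb = nn.Embedding(vocab_size, embed_dim)   # weight shape (5, 3): one row per token id

ids = torch.tensor([0, 3])                                # token ids, no one-hot needed
one_hot = F.one_hot(ids, num_classes=vocab_size).float()  # (2, 5) explicit one-hot rows
via_matmul = one_hot @ emb.weight                         # (2, 3) one-hot x matrix
via_lookup = emb(ids)                                     # (2, 3) direct row lookup

print(torch.allclose(via_matmul, via_lookup))  # True
```

The lookup skips the multiplication by zeros entirely, which is why real implementations never materialize the one-hot vectors.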
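In the theoretical formulation, the loss is computed between the output of the dimension-raising matrix and the one-hot vectors of the context tokens.
The practical implementation differs: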
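1. Skip-gram predicts the surrounding tokens from the center token. In practice, the center token's vector is looked up in one embedding matrix, each surrounding token's vector is looked up in a second embedding matrix, and the dot product of the two vectors is used as their similarity score.
2. To decide which tokens count as related, the implementation uses a sliding window: we walk through the corpus taking each token in turn as the center token (center_token), and the tokens inside the window on either side are taken as its context.
3. The implementation also uses negative sampling. A center token should be highly similar to its neighbors (a large dot product) and dissimilar to unrelated tokens (a small dot product). The sliding window from point 2 supplies the neighboring tokens as positive samples, while randomly drawn tokens from outside the window serve as negative samples.
4. Optionally, subsampling can be applied to randomly delete high-frequency words, because very frequent words can interfere with learning the vectors of low-frequency words.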
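Points 1 to 3 together give the negative-sampling objective -log σ(u_o·v_c) - Σ_k log σ(-u_k·v_c), where v_c is the center vector, u_o a context vector and u_k the negative samples. A minimal PyTorch sketch (random toy vectors, sizes chosen arbitrarily) showing that this is the same quantity as a summed binary cross-entropy on the dot products with labels 1 and 0, which is how the model class later in this post computes its loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, k = 8, 4
v_center = torch.randn(dim)     # center-word vector (first embedding matrix)
u_pos = torch.randn(dim)        # one context word (second embedding matrix)
u_neg = torch.randn(k, dim)     # k negative samples (second embedding matrix)

pos_score = u_pos @ v_center    # scalar dot product
neg_scores = u_neg @ v_center   # (k,) dot products

# Negative-sampling objective written out term by term
manual = -F.logsigmoid(pos_score) - F.logsigmoid(-neg_scores).sum()

# Same loss as binary cross-entropy on logits: label 1 for the positive, 0 for negatives
logits = torch.cat([pos_score.unsqueeze(0), neg_scores])
labels = torch.cat([torch.ones(1), torch.zeros(k)])
bce = F.binary_cross_entropy_with_logits(logits, labels, reduction="sum")

print(manual.item(), bce.item())
```

For label 0 the BCE term is -log(1 - σ(x)) = -log σ(-x), which is why the two values coincide.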
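Open-source code notice
All code, the dataset and the trained model weights in this post are available at the GitHub link below. If you need the trained weights or the complete code, download them from: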
GitHub - Foxbabe1q/Pytorch_skipgram: Use pytorch to define skipgram model to train with wikipedia corpus. And I also use fasttexts skipgram to train the corpus
Reproducing Skip-gram in PyTorch
Imports and random seed
```python
import io
import os
import sys
import time
import requests
import math
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

np.random.seed(42)
torch.manual_seed(42)
# Prefer Apple MPS, then CUDA, then CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
```
Reading the Wikipedia data
```python
def load_data():
    with open("fil9", "r") as f:
        data = f.read()
    print(data[:100])
    corpus = data.split()   # whitespace tokenization
    print(corpus[:100])
    return corpus

if __name__ == "__main__":
    corpus = load_data()
# anarchism originated as a term of abuse first used against early working class radicals including t
# ['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against', 'early', 'working', 'class', 'radicals', 'including', 'the', 'diggers', 'of', 'the', 'english', 'revolution', 'and', 'the', 'sans', 'culottes', 'of', 'the', 'french', 'revolution', 'whilst', 'the', 'term', 'is', 'still', 'used', 'in', 'a', 'pejorative', 'way', 'to', 'describe', 'any', 'act', 'that', 'used', 'violent', 'means', 'to', 'destroy', 'the', 'organization', 'of', 'society', 'it', 'has', 'also', 'been', 'taken', 'up', 'as', 'a', 'positive', 'label', 'by', 'self', 'defined', 'anarchists', 'the', 'word', 'anarchism', 'is', 'derived', 'from', 'the', 'greek', 'without', 'archons', 'ruler', 'chief', 'king', 'anarchism', 'as', 'a', 'political', 'philosophy', 'is', 'the', 'belief', 'that', 'rulers', 'are', 'unnecessary', 'and', 'should', 'be', 'abolished', 'although', 'there', 'are', 'differing']
```
Building a (word, frequency) tuple list sorted by frequency
```python
def build_word_freq_tuple(corpus):
    word_freq_dict = {}
    for word in corpus:
        if word in word_freq_dict:
            word_freq_dict[word] += 1
        else:
            word_freq_dict[word] = 1
    # Sort (word, count) pairs by count, most frequent first
    word_freq_tuple = sorted(word_freq_dict.items(), key=lambda x: x[1], reverse=True)
    print(word_freq_tuple[:10])
    return word_freq_tuple

if __name__ == "__main__":
    corpus = load_data()
    word_freq_tuple = build_word_freq_tuple(corpus)
# [('the', 7446708), ('of', 4453926), ('one', 3776770), ('zero', 3085174), ('and', 2916968), ('in', 2480552), ('two', 2339802), ('a', 2241744), ('nine', 2063649), ('to', 2028129)]
```
Building the word-frequency, word_id and id_word dictionaries
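The build_word_id_dict function that the `__main__` blocks call is not shown in this text. Below is a hypothetical reconstruction, assuming ids are assigned in descending order of frequency; that assumption agrees with the ids printed below ('the' → 0, 'of' → 1, 'a' → 7, matching the frequency ranking above), but the author's actual code may differ, e.g. in how ties are ordered.

```python
from collections import Counter


def build_word_id_dict(corpus):
    """Hypothetical reconstruction: the most frequent word gets id 0, the next id 1, ..."""
    word_freq_dict = dict(Counter(corpus))   # word -> occurrence count
    sorted_words = sorted(word_freq_dict, key=lambda w: word_freq_dict[w], reverse=True)
    word_id_dict = {word: idx for idx, word in enumerate(sorted_words)}
    id_word_dict = {idx: word for word, idx in word_id_dict.items()}
    print("vocabulary size: ", len(word_id_dict))
    return word_freq_dict, word_id_dict, id_word_dict


toy_corpus = "the cat sat on the mat the end".split()
freq, w2i, i2w = build_word_id_dict(toy_corpus)
print(w2i["the"], i2w[0], freq["the"])   # 0 the 3
```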
```python
def convert_corpus_id(corpus, word_id_dict):
    # Replace every word in the corpus with its integer id
    id_corpus = []
    for word in corpus:
        id_corpus.append(word_id_dict[word])
    print("corpus_size: ", len(id_corpus))
    print(id_corpus[:20])
    return id_corpus

if __name__ == "__main__":
    corpus = load_data()
    word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)
    id_corpus = convert_corpus_id(corpus, word_id_dict)
# vocabulary size: 833184
# corpus_size: 124301826
# [9558, 3423, 19, 7, 277, 1, 3451, 56, 82, 208, 174, 781, 500, 9838, 187, 0, 28373, 1, 0, 179]
```
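The corpus contains more than 100 million words (124,301,826 tokens). Even so, a corpus of this size is still relatively small, so whether to apply the subsampling introduced next is a judgment call: when the corpus is limited, subsampling can have the opposite of the intended effect.
Subsampling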
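Subsampling discards a portion of high-frequency words so that the vectors of low-frequency words are learned better. Each occurrence of a word $w_i$ is discarded with probability

$$P(w_i) = 1 - \sqrt{\frac{t}{f(w_i)}}$$

where $P(w_i)$ is the probability that $w_i$ is discarded, $f(w_i)$ is the word's count divided by the total number of words, and $t$ is a threshold, typically $10^{-5}$. The larger $t$ is, the smaller the discard probability.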
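As a worked example of the formula, a small helper that computes the discard probability. Clipping at 0 is an added assumption that mirrors what the subsampling code does implicitly (a negative probability never triggers a discard). The count for 'the' and the corpus size are the ones printed earlier; the word seen 100 times is hypothetical.

```python
import math


def discard_prob(count, total_words, t=1e-5):
    """P(w) = 1 - sqrt(t / f(w)), with f(w) = count / total_words, clipped at 0."""
    f = count / total_words
    return max(0.0, 1.0 - math.sqrt(t / f))   # rare words give a negative value: never discarded


total_words = 124_301_826                                 # corpus size printed above
print(round(discard_prob(7_446_708, total_words), 3))     # 'the' -> 0.987
print(discard_prob(100, total_words))                     # hypothetical rare word -> 0.0
```

For 'the', f ≈ 0.0599, so P ≈ 1 - √(1e-5 / 0.0599) ≈ 0.987: roughly 987 of every 1000 occurrences are dropped, while words rarer than t are always kept.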
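```python
def subsampling(corpus, word_freq_dict, t=1e-5):
    total_words = len(corpus)  # compute once instead of once per word
    # Keep a word unless a uniform draw falls below its discard probability 1 - sqrt(t / f(w)),
    # where f(w) = word_freq_dict[word] / total_words
    corpus = [word for word in corpus
              if not np.random.rand() < (1 - np.sqrt(t * total_words / word_freq_dict[word]))]
    print("corpus_size after subsampling: ", len(corpus))
    return corpus

if __name__ == "__main__":
    corpus = load_data()
    word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)
    corpus = subsampling(corpus, word_freq_dict)
# corpus_size: 124301826
# vocabulary size: 833184
# corpus_size after subsampling: 83240619
```
Positive and negative sampling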
```python
def build_negative_sampling_dataset(corpus, word_id_dict, id_word_dict, negative_sample_size=10, max_window_size=3):
    dataset = []
    for center_word_idx, center_word in enumerate(corpus):
        # Random window size in [1, max_window_size] for each center word
        window_size = np.random.randint(1, max_window_size + 1)
        positive_range = (max(0, center_word_idx - window_size),
                          min(len(corpus) - 1, center_word_idx + window_size))
        # Positive samples: every word in the window except the center word itself
        positive_samples = [corpus[word_idx]
                            for word_idx in range(positive_range[0], positive_range[1] + 1)
                            if word_idx != center_word_idx]
        for positive_sample in positive_samples:
            dataset.append((center_word, positive_sample, 1))
        # Negative samples: uniform, without replacement, from ids outside the window
        # (rebuilding this array for every center word is slow on a large vocabulary)
        sample_idx_list = np.arange(len(word_id_dict))
        window_ids = corpus[positive_range[0]: positive_range[1] + 1]
        sample_idx_list = np.delete(sample_idx_list, window_ids)
        negative_samples = np.random.choice(sample_idx_list, size=negative_sample_size, replace=False)
        for negative_sample in negative_samples:
            dataset.append((center_word, negative_sample, 0))
    print("20 samples of the dataset")
    for i in range(20):
        print("center_word:", id_word_dict[dataset[i][0]], "target_word:", id_word_dict[dataset[i][1]], "label", dataset[i][2])
    return dataset

if __name__ == "__main__":
    corpus = load_data()
    word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)
    corpus = subsampling(corpus, word_freq_dict)
    corpus = convert_corpus_id(corpus, word_id_dict)
    dataset = build_negative_sampling_dataset(corpus, word_id_dict, id_word_dict, negative_sample_size=10)
# 20 samples of the dataset
# center_word: originated target_word: working label 1
# center_word: originated target_word: class label 1
# center_word: originated target_word: gulfs label 0
# center_word: originated target_word: propenents label 0
# center_word: originated target_word: pelletier label 0
# center_word: originated target_word: exclaiming label 0
# center_word: originated target_word: bod label 0
# center_word: originated target_word: liturgical label 0
# center_word: originated target_word: quattro label 0
# center_word: originated target_word: anatolius label 0
# center_word: originated target_word: interstratified label 0
# center_word: originated target_word: das label 0
# center_word: working target_word: originated label 1
# center_word: working target_word: class label 1
# center_word: working target_word: radicals label 1
# center_word: working target_word: clip label 0
# center_word: working target_word: moulting label 0
# center_word: working target_word: gnomon label 0
# center_word: working target_word: neural label 0
# center_word: working target_word: marsupial label 0
```
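Here, positive sampling takes up to 6 words around the center word (a random window of 1 to 3 words on each side) as words strongly related to the center word, and negative sampling randomly picks 10 words from the rest of the vocabulary for each center word. Strongly related pairs get label 1 and negative pairs get label 0.
The SkipGram model class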
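The code above draws negatives uniformly from the vocabulary. The original word2vec instead samples negatives from the unigram distribution raised to the 3/4 power, which still favors frequent words but flattens the distribution. A sketch of that alternative, assuming a list of counts indexed by word id (which could be built as `[word_freq_dict[id_word_dict[i]] for i in range(len(id_word_dict))]`); the helper names are hypothetical and the toy counts reuse the top frequencies printed earlier plus two made-up rare words.

```python
import numpy as np


def build_noise_distribution(word_counts, power=0.75):
    """Noise distribution P_n(w) proportional to count(w) ** power."""
    weights = np.asarray(word_counts, dtype=np.float64) ** power
    return weights / weights.sum()


def draw_negatives(noise_dist, k, exclude, rng):
    """Draw k negative ids from noise_dist (with replacement), skipping ids in `exclude`."""
    negatives = []
    while len(negatives) < k:
        candidate = int(rng.choice(len(noise_dist), p=noise_dist))
        if candidate not in exclude:
            negatives.append(candidate)
    return negatives


counts = [7446708, 4453926, 3776770, 100, 5]   # toy vocabulary, ids 0..4
noise = build_noise_distribution(counts)
rng = np.random.default_rng(42)
print(noise.round(4))
print(draw_negatives(noise, k=3, exclude={0}, rng=rng))   # e.g. ids from the current window
```

Raising counts to 0.75 lowers the share of 'the'-like words and raises that of rare words relative to plain frequency sampling, while uniform sampling (as in the code above) gives rare words the same weight as frequent ones.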
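```python
class SkipGram(nn.Module):
    def __init__(self, vocab_size, embedding_size):
        super(SkipGram, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        # First matrix: center-word vectors, used as the final word vectors
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        # Second matrix: context-word vectors, used only during training
        self.out_embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        init_range = (1 / embedding_size) ** 0.5
        nn.init.uniform_(self.embedding.weight, -init_range, init_range)
        nn.init.uniform_(self.out_embedding.weight, -init_range, init_range)

    def forward(self, center_idx, target_idx, label):
        center_embedding = self.embedding(center_idx)        # (batch, embedding_size)
        target_embedding = self.out_embedding(target_idx)    # (batch, embedding_size)
        # Row-wise dot product: one similarity logit per (center, target) pair
        sim = torch.mul(center_embedding, target_embedding)
        sim = torch.sum(sim, dim=1, keepdim=False)
        # label must be a float tensor: 1 = positive pair, 0 = negative sample
        loss = F.binary_cross_entropy_with_logits(sim, label, reduction="sum")
        return loss
```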
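The first embedding matrix (self.embedding) is used as the final word-vector matrix, and during training the dot product of the two word vectors serves as the similarity score.
Model training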
```python
def train(vocab_size, dataset):
    my_skipgram = SkipGram(vocab_size=vocab_size, embedding_size=300)
    my_skipgram.to(device)
    my_dataset = create_dataset(dataset)
    my_dataloader = DataLoader(my_dataset, batch_size=64, shuffle=True)
    optimizer = optim.Adam(my_skipgram.parameters(), lr=0.001)
    epochs = 10
    loss_list = []
    start_time = time.time()
    for epoch in range(epochs):
        total_loss = 0
        total_sample = 0
        for center_idx, target_idx, label in my_dataloader:
            # Move the batch to the same device as the model
            center_idx, target_idx, label = center_idx.to(device), target_idx.to(device), label.to(device)
            loss = my_skipgram(center_idx, target_idx, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            total_sample += len(center_idx)
        print(f"epoch: {epoch + 1}, loss: {total_loss / total_sample}, time: {time.time() - start_time:.2f}")
        loss_list.append(total_loss / total_sample)
    plt.plot(np.arange(1, epochs + 1), loss_list)
    plt.title("Loss_curve")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.xticks(np.arange(1, epochs + 1))
    plt.savefig("loss_curve.png")
    plt.show()
    torch.save(my_skipgram.state_dict(), "skip_gram.pt")

if __name__ == "__main__":
    corpus = load_data()
    word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)
    corpus = subsampling(corpus, word_freq_dict)
    corpus = convert_corpus_id(corpus, word_id_dict)
    dataset = build_negative_sampling_dataset(corpus, word_id_dict, id_word_dict, negative_sample_size=10)
    train(len(word_id_dict), dataset)
```
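Training ran for only 10 epochs. To save compute, and because the full corpus is longer than 100 million words, only a 2-million-word portion of the corpus was used for training.
Outputting word vectors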
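train calls create_dataset, which is not shown in this text. A hypothetical sketch, assuming it simply wraps the (center_id, target_id, label) tuples produced by build_negative_sampling_dataset in a `TensorDataset`; labels are stored as floats because `F.binary_cross_entropy_with_logits` expects floating-point targets.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def create_dataset(dataset):
    """Hypothetical helper: [(center_id, target_id, label), ...] -> TensorDataset."""
    centers, targets, labels = zip(*dataset)
    return TensorDataset(
        torch.tensor(centers, dtype=torch.int64),
        torch.tensor(targets, dtype=torch.int64),
        torch.tensor(labels, dtype=torch.float32),   # float targets for BCE-with-logits
    )


toy = [(1, 2, 1), (1, 7, 0), (2, 1, 1), (2, 9, 0)]
loader = DataLoader(create_dataset(toy), batch_size=2, shuffle=False)
center_idx, target_idx, label = next(iter(loader))
print(center_idx, target_idx, label)
# tensor([1, 1]) tensor([2, 7]) tensor([1., 0.])
```

Storing everything as three flat tensors keeps batching cheap, but for the full corpus the list of tuples itself is the memory bottleneck, which is one reason to train on a slice of the corpus.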
```python
def predict(word, vocab_size, word_id_dict):
    if word not in word_id_dict:
        print(f"Word {word} not found in the vocabulary.")
        return None
    my_skipgram = SkipGram(vocab_size=vocab_size, embedding_size=300)
    my_skipgram.load_state_dict(torch.load("skip_gram.pt", map_location=device, weights_only=True))
    my_skipgram.to(device)
    my_skipgram.eval()
    word_id = torch.tensor(word_id_dict[word], device=device, dtype=torch.int64)
    with torch.no_grad():
        print(f"Predicting the embedding vector for word {word}:\n{my_skipgram.embedding(word_id)}")

if __name__ == "__main__":
    corpus = load_data()
    word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)
    corpus = subsampling(corpus, word_freq_dict)
    corpus = convert_corpus_id(corpus, word_id_dict)
    dataset = build_negative_sampling_dataset(corpus, word_id_dict, id_word_dict, negative_sample_size=10)
    train(len(word_id_dict), dataset)
    predict("sport", len(word_id_dict), word_id_dict)
```
Finding similar words
```python
def similarity(word, vocab_size, word_id_dict, id_word_dict, neighbors=5):
    if word not in word_id_dict:
        print(f"Word {word} not found in the vocabulary.")
        return None
    my_skipgram = SkipGram(vocab_size=vocab_size, embedding_size=300)
    my_skipgram.load_state_dict(torch.load("skip_gram.pt", map_location=device, weights_only=True))
    my_skipgram.to(device)
    my_skipgram.eval()
    word_id = torch.tensor(word_id_dict[word], device=device, dtype=torch.int64)
    similarity_score = {}
    with torch.no_grad():
        word_embedding = my_skipgram.embedding(word_id)
        for idx in word_id_dict.values():
            other_word_embedding = my_skipgram.embedding(torch.tensor(idx, device=device, dtype=torch.int64))
            # Cosine similarity between the query vector and this word's vector
            sim = torch.matmul(word_embedding, other_word_embedding) / (
                torch.norm(word_embedding, dim=0) * torch.norm(other_word_embedding, dim=0))
            similarity_score[id_word_dict[idx]] = sim.item()
    # Top `neighbors` words; the query word itself scores 1.0 and comes first
    nearest_neighbors = sorted(similarity_score.items(), key=lambda x: x[1], reverse=True)[:neighbors]
    print(nearest_neighbors)
    return nearest_neighbors

if __name__ == "__main__":
    corpus = load_data()
    word_freq_dict, word_id_dict, id_word_dict = build_word_id_dict(corpus)
    corpus = subsampling(corpus, word_freq_dict)
    corpus = convert_corpus_id(corpus, word_id_dict)
    dataset = build_negative_sampling_dataset(corpus, word_id_dict, id_word_dict, negative_sample_size=10)
    train(len(word_id_dict), dataset)
    predict("sport", len(word_id_dict), word_id_dict)
    similarity("sport", len(word_id_dict), word_id_dict, id_word_dict, neighbors=5)
```
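To find similar words, the function scans the whole vocabulary and returns the words with the highest cosine similarity (the dot product of the normalized vectors) to the query word. It returns 5 words by default; change neighbors to get more or fewer.
Training Skip-gram with fastText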
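The loop above calls the embedding layer once per vocabulary word, which is slow for a vocabulary of roughly 830k words. A vectorized alternative, sketched under the assumption that the full weight matrix (for example my_skipgram.embedding.weight) fits in memory: normalize all rows once, take a single matrix-vector product and use `torch.topk`. Unlike the original function, this version excludes the query word itself; the function name and toy vectors are made up for illustration.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def nearest_neighbors(word, embedding_weight, word_id_dict, id_word_dict, neighbors=5):
    """Top-`neighbors` words by cosine similarity, computed for the whole vocabulary at once."""
    if word not in word_id_dict:
        return None
    normed = F.normalize(embedding_weight, dim=1)   # unit-length rows
    query = normed[word_id_dict[word]]              # (embedding_size,)
    scores = normed @ query                         # (vocab_size,) cosine similarities
    scores[word_id_dict[word]] = float("-inf")      # exclude the query word itself
    top = torch.topk(scores, k=neighbors)
    return [(id_word_dict[i], s) for i, s in zip(top.indices.tolist(), top.values.tolist())]


# Toy 4-word vocabulary with hand-picked 2-d vectors
weights = torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
w2i = {"sport": 0, "sports": 1, "music": 2, "peace": 3}
i2w = {i: w for w, i in w2i.items()}
print(nearest_neighbors("sport", weights, w2i, i2w, neighbors=2))
```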
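Training with fastText is simpler: Skip-gram, as well as CBOW, is built into the fasttext library, and the variant is chosen with the model argument of fasttext.train_unsupervised.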
```python
import fasttext

def train():
    # model="skipgram" (or "cbow") selects the Word2vec variant
    skipgram = fasttext.train_unsupervised("fil9", model="skipgram")
    skipgram.save_model("skipgram.bin")

def skg_test1():
    skipgram = fasttext.load_model("skipgram.bin")
    print(skipgram.get_word_vector("sport"))
    print(skipgram.get_nearest_neighbors("sport"))

if __name__ == "__main__":
    train()
    skg_test1()
# Read 124M words
# Number of words: 218316
# Number of labels: 0
# Progress: 100.0% words/sec/thread: 38918 lr: 0.000000 avg.loss: 1.071778 ETA: 0h 0m 0s
# Warning : load_model does not return WordVectorModel or SupervisedModel any more, but a FastText object which is very similar.
# [-1.1217905e-01 -2.1082790e-01 -5.0111616e-05 -7.6881155e-02
# -2.0150667e-01 -1.8065287e-01 1.3297442e-01 1.3444095e-02
# -1.5131533e-01 -2.5561339e-01 1.5086566e-01 -8.5557923e-02
# -2.1246003e-01 -8.0699474e-02 -1.5511900e-01 -2.4630783e-01
# 4.1686368e-01 8.0300289e-01 2.5104052e-01 -7.7809072e-01
# 2.2462079e-01 8.2177565e-02 1.7808667e-01 -3.3937061e-01
# 1.2025767e-01 9.7873092e-02 -3.8934144e-01 1.2671056e-01
# -2.7373591e-01 4.1039872e-01 -2.9629371e-01 4.4961619e-01
# 5.0581735e-02 -1.9909970e-01 1.0461334e-01 -4.9297757e-02
# -9.5666438e-02 1.6832566e-01 7.4807540e-02 6.5610033e-01
# -2.6710102e-01 2.5174522e-01 2.0871958e-01 -2.3539853e-01
# -1.0441781e-01 -3.5934374e-01 -2.0167212e-01 -6.7970419e-01
# -4.6956554e-02 9.3441598e-02 3.8153380e-01 2.0482899e-01
# 6.1529225e-01 -9.8463172e-01 -5.7401802e-02 -1.5414989e-01
# 6.7769766e-02 2.2661546e-01 -3.1193841e-02 3.8101819e-01
# -3.1099179e-01 -2.9264178e-02 2.0313324e-01 -3.6542088e-01
# -1.2520532e-01 1.8720575e-01 -2.6330149e-01 1.9312735e-01
# -5.1107663e-01 -2.5122452e-01 2.2448047e-01 -4.7734442e-01
# 2.5731093e-01 -1.4026532e-01 4.3919176e-02 -2.0015708e-01
# -2.8174376e-01 3.3095101e-01 1.0486527e-01 2.8560793e-01
# -2.4086323e-01 -9.3831137e-02 -1.9629408e-01 2.4319877e-01
# -1.8636097e-01 -3.9179447e-01 7.6361425e-02 1.6013722e-01
# -9.0249017e-02 -5.6596959e-01 4.8584041e-01 3.4663376e-01
# 2.6066643e-01 -7.1866415e-03 1.7896013e-01 -1.2109153e00
# -7.9120353e-02 7.6195911e-02 4.5524022e-01 -1.4492531e-01]
# [(0.849130392074585, 'sports'), (0.8167348504066467, 'sporting'), (0.8091928362846375, 'competitions'), (0.7699509859085083, 'racing'), (0.7655908465385437, 'sportsman'), (0.7654882073402405, 'bobsledding'), (0.7621665000915527, 'bobsleigh'), (0.7620510458946228, 'motorsport'), (0.7576955556869507, 'korfball'), (0.7561532258987427, 'competiting')]
```