首頁人工智能技術資訊正文

使用Transformer構建語言模型

更新時間:2022-04-26 來源:黑馬程序員 瀏覽量:

什么是語言模型:

以一個符合語言規(guī)律的序列為輸入,模型將利用序列間關系等特征,輸出一個在所有詞匯上的概率分布.這樣的模型稱為語言模型。

# 語言模型的訓練語料一般來自于文章,對應的源文本和目標文本形如:
src1 = "I can do" tgt1 = "can do it"
src2 = "can do it", tgt2 = "do it <eos>"

語言模型能解決哪些問題:

1, 根據語言模型的定義,可以在它的基礎上完成機器翻譯,文本生成等任務,因為我們通過最后輸出的概率分布來預測下一個詞匯是什么.

2, 語言模型可以判斷輸入的序列是否為一句完整的話,因為我們可以根據輸出的概率分布查看最大概率是否落在句子結束符上,來判斷完整性.

3, 語言模型本身的訓練目標是預測下一個詞,因為它的特征提取部分會抽象很多語言序列之間的關系,這些關系可能同樣對其他語言類任務有效果.因此可以作為預訓練模型進行遷移學習.

整個案例的實現可分為以下五個步驟

第一步: 導入必備的工具包

第二步: 導入wikiText-2數據集并作基本處理

第三步: 構建用于模型輸入的批次化數據

第四步: 構建訓練和評估函數

第五步: 進行訓練和評估(包括驗證以及測試)

第一步: 導入必備的工具包

pytorch版本必須使用1.3.1, python版本使用3.6.x

pip install torch==1.3.1
# 數學計算工具包math
import math

# torch以及torch.nn, torch.nn.functional
import torch
import torch.nn as nn
import torch.nn.functional as F

# torch中經典文本數據集有關的工具包
# 具體詳情參考下方torchtext介紹
import torchtext

# torchtext中的數據處理工具, get_tokenizer用于英文分詞
from torchtext.data.utils import get_tokenizer

# 已經構建完成的TransformerModel
from pyitcast.transformer import TransformerModel

torchtext:它是torch工具中處理NLP問題的常用數據處理包.

torchtext的重要功能:對文本數據進行處理, 比如文本語料加載, 文本迭代器構建等.

包含很多經典文本語料的預加載方法. 其中包括的語料有:用于情感分析的SST和IMDB, 用于問題分類的TREC, 用于及其翻譯的 WMT14, IWSLT,以及用于語言模型任務wikiText-2, WikiText103, PennTreebank.

我們這里使用wikiText-2來訓練語言模型, 下面有關該數據集的相關詳情:

1650944025293_11.png

wikiText-2數據集的體量中等, 訓練集共有600篇短文, 共208萬左右的詞匯, 33278個不重復詞匯, OoV(有多少正常英文詞匯不在該數據集中的占比)為2.6%,數據集中的短文都是維基百科中對一些概念的介紹和描述.

第二步: 導入wikiText-2數據集并作基本處理

# 創(chuàng)建語料域, 語料域是存放語料的數據結構, 
# 它的四個參數代表給存放語料(或稱作文本)施加的作用. 
# 分別為 tokenize,使用get_tokenizer("basic_english")獲得一個分割器對象,
# 分割方式按照文本為基礎英文進行分割. 
# init_token為給文本施加的起始符 <sos>給文本施加的終止符<eos>, 
# 最后一個lower為True, 存放的文本字母全部小寫.
TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
                            init_token='<sos>',
                            eos_token='<eos>',
                            lower=True)

# 最終獲得一個Field對象.
# <torchtext.data.field.Field object at 0x7fc42a02e7f0>

# 然后使用torchtext的數據集方法導入WikiText2數據, 
# 并切分為對應訓練文本, 驗證文本,測試文本, 并對這些文本施加剛剛創(chuàng)建的語料域.
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)

# 我們可以通過examples[0].text取出文本對象進行查看.
# >>> test_txt.examples[0].text[:10]
# ['<eos>', '=', 'robert', '<unk>', '=', '<eos>', '<eos>', 'robert', '<unk>', 'is']

# 將訓練集文本數據構建一個vocab對象, 
# 這樣可以使用vocab對象的stoi方法統(tǒng)計文本共包含的不重復詞匯總數.
TEXT.build_vocab(train_txt)

# 然后選擇設備cuda或者cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

該案例的所有代碼都將實現在一個transformer_lm.py文件中.

第三步: 構建用于模型輸入的批次化數據

批次化過程的第一個函數batchify代碼分析:

def batchify(data, bsz):
    """batchify函數用于將文本數據映射成連續(xù)數字, 并轉換成指定的樣式, 指定的樣式可參考下圖.
       它有兩個輸入參數, data就是我們之前得到的文本數據(train_txt, val_txt, test_txt),
       bsz是就是batch_size, 每次模型更新參數的數據量"""
    # 使用TEXT的numericalize方法將單詞映射成對應的連續(xù)數字.
    data = TEXT.numericalize([data.examples[0].text])
    # >>> data
    # tensor([[   3],
    #    [  12],
    #    [3852],
    #    ...,
    #    [   6],
    #    [   3],
    #    [   3]])

    # 接著用數據詞匯總數除以bsz,
    # 取整數得到一個nbatch代表需要多少次batch后能夠遍歷完所有數據
    nbatch = data.size(0) // bsz

    # 之后使用narrow方法對不規(guī)整的剩余數據進行刪除,
    # 第一個參數是代表橫軸刪除還是縱軸刪除, 0為橫軸,1為縱軸
    # 第二個和第三個參數代表保留開始軸到結束軸的數值.類似于切片
    # 可參考下方演示示例進行更深理解.
    data = data.narrow(0, 0, nbatch * bsz)
    # >>> data
    # tensor([[   3],
    #    [  12],
    #    [3852],
    #    ...,
    #    [  78],
    #    [ 299],
    #    [  36]])
    # 后面不能形成bsz個的一組數據被刪除

    # 接著我們使用view方法對data進行矩陣變換, 使其成為如下樣式:
    # tensor([[    3,    25,  1849,  ...,     5,    65,    30],
    #    [   12,    66,    13,  ...,    35,  2438,  4064],
    #    [ 3852, 13667,  2962,  ...,   902,    33,    20],
    #    ...,
    #    [  154,     7,    10,  ...,     5,  1076,    78],
    #    [   25,     4,  4135,  ...,     4,    56,   299],
    #    [    6,    57,   385,  ...,  3168,   737,    36]])
    # 因為會做轉置操作, 因此這個矩陣的形狀是[None, bsz],
    # 如果輸入是訓練數據的話,形狀為[104335, 20], 可以通過打印data.shape獲得.
    # 也就是data的列數是等于bsz的值的.
    data = data.view(bsz, -1).t().contiguous()
    # 最后將數據分配在指定的設備上.
    return data.to(device)

batchify的樣式轉化圖:

1650944305127_12.png

注:大寫字母A,B,C ... 代表句子中的每個單詞.

torch.narrow演示:

>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> x.narrow(0, 0, 2)
tensor([[ 1,  2,  3],
        [ 4,  5,  6]])
>>> x.narrow(1, 1, 2)
tensor([[ 2,  3],
        [ 5,  6],
        [ 8,  9]])

接下來我們將使用batchify來處理訓練數據,驗證數據以及測試數據.

# 訓練數據的batch size
batch_size = 20

# 驗證和測試數據(統(tǒng)稱為評估數據)的batch size
eval_batch_size = 10

# 獲得train_data, val_data, test_data
train_data = batchify(train_txt, batch_size)
val_data = batchify(val_txt, eval_batch_size)
test_data = batchify(test_txt, eval_batch_size)

上面的分割批次并沒有進行源數據與目標數據的處理, 接下來我們將根據語言模型訓練的語料規(guī)定來構建源數據與目標數據.

語言模型訓練的語料規(guī)定:

如果源數據為句子ABCD, ABCD代表句子中的詞匯或符號, 則它的目標數據為BCDE, BCDE分別代表ABCD的下一個詞匯.

語言模型的訓練語料規(guī)定

如圖所示,我們這里的句子序列是豎著的, 而且我們發(fā)現如果用一個批次處理完所有數據, 以訓練數據為例, 每個句子長度高達104335, 這明顯是不科學的, 因此我們在這里要限定每個批次中的句子長度允許的最大值bptt.

批次化過程的第二個函數get_batch代碼分析:

# 令子長度允許的最大值bptt為35
bptt = 35

def get_batch(source, i):
    """用于獲得每個批次合理大小的源數據和目標數據.
       參數source是通過batchify得到的train_data/val_data/test_data.
       i是具體的批次次數.
    """

    # 首先我們確定句子長度, 它將是在bptt和len(source) - 1 - i中最小值
    # 實質上, 前面的批次中都會是bptt的值, 只不過最后一個批次中, 句子長度
    # 可能不夠bptt的35個, 因此會變?yōu)閘en(source) - 1 - i的值.
    seq_len = min(bptt, len(source) - 1 - i)

    # 語言模型訓練的源數據的第i批數據將是batchify的結果的切片[i:i+seq_len]
    data = source[i:i+seq_len]

    # 根據語言模型訓練的語料規(guī)定, 它的目標數據是源數據向后移動一位
    # 因為最后目標數據的切片會越界, 因此使用view(-1)來保證形狀正常.
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

輸入實例:

# 以測試集數據為例
source = test_data
i = 1

輸出效果:

data = tensor([[   12,  1053,   355,   134,    37,     7,     4,     0,   835,  9834],
        [  635,     8,     5,     5,   421,     4,    88,     8,   573,  2511],
        [    0,    58,     8,     8,     6,   692,   544,     0,   212,     5],
        [   12,     0,   105,    26,     3,     5,     6,     0,     4,    56],
        [    3, 16074, 21254,   320,     3,   262,    16,     6,  1087,    89],
        [    3,   751,  3866,    10,    12,    31,   246,   238,    79,    49],
        [  635,   943,    78,    36,    12,   475,    66,    10,     4,   924],
        [    0,  2358,    52,     4,    12,     4,     5,     0, 19831,    21],
        [   26,    38,    54,    40,  1589,  3729,  1014,     5,     8,     4],
        [   33, 17597,    33,  1661,    15,     7,     5,     0,     4,   170],
        [  335,   268,   117,     0,     0,     4,  3144,  1557,     0,   160],
        [  106,     4,  4706,  2245,    12,  1074,    13,  2105,     5,    29],
        [    5, 16074,    10,  1087,    12,   137,   251, 13238,     8,     4],
        [  394,   746,     4,     9,    12,  6032,     4,  2190,   303, 12651],
        [    8,   616,  2107,     4,     3,     4,   425,     0,    10,   510],
        [ 1339,   112,    23,   335,     3, 22251,  1162,     9,    11,     9],
        [ 1212,   468,     6,   820,     9,     7,  1231,  4202,  2866,   382],
        [    6,    24,   104,     6,     4,     4,     7,    10,     9,   588],
        [   31,   190,     0,     0,   230,   267,     4,   273,   278,     6],
        [   34,    25,    47,    26,  1864,     6,   694,     0,  2112,     3],
        [   11,     6,    52,   798,     8,    69,    20,    31,    63,     9],
        [ 1800,    25,  2141,  2442,   117,    31,   196,  7290,     4,   298],
        [   15,   171,    15,    17,  1712,    13,   217,    59,   736,     5],
        [ 4210,   191,   142,    14,  5251,   939,    59,    38, 10055, 25132],
        [  302,    23, 11718,    11,    11,   599,   382,   317,     8,    13],
        [   16,  1564,     9,  4808,     6,     0,     6,     6,     4,     4],
        [    4,     7,    39,     7,  3934,     5,     9,     3,  8047,   557],
        [  394,     0, 10715,  3580,  8682,    31,   242,     0, 10055,   170],
        [   96,     6,   144,  3403,     4,    13,  1014,    14,     6,  2395],
        [    4,     3, 13729,    14,    40,     0,     5,    18,   676,  3267],
        [ 1031,     3,     0,   628,  1589,    22, 10916, 10969,     5, 22548],
        [    9,    12,     6,    84,    15,    49,  3144,     7,   102,    15],
        [  916,    12,     4,   203,     0,   273,   303,   333,  4318,     0],
        [    6,    12,     0,  4842,     5,    17,     4,    47,  4138,  2072],
        [   38,   237,     5,    50,    35,    27, 18530,   244,    20,     6]])

target =  tensor([  635,     8,     5,     5,   421,     4,    88,     8,   573,  2511,
            0,    58,     8,     8,     6,   692,   544,     0,   212,     5,
           12,     0,   105,    26,     3,     5,     6,     0,     4,    56,
            3, 16074, 21254,   320,     3,   262,    16,     6,  1087,    89,
            3,   751,  3866,    10,    12,    31,   246,   238,    79,    49,
          635,   943,    78,    36,    12,   475,    66,    10,     4,   924,
            0,  2358,    52,     4,    12,     4,     5,     0, 19831,    21,
           26,    38,    54,    40,  1589,  3729,  1014,     5,     8,     4,
           33, 17597,    33,  1661,    15,     7,     5,     0,     4,   170,
          335,   268,   117,     0,     0,     4,  3144,  1557,     0,   160,
          106,     4,  4706,  2245,    12,  1074,    13,  2105,     5,    29,
            5, 16074,    10,  1087,    12,   137,   251, 13238,     8,     4,
          394,   746,     4,     9,    12,  6032,     4,  2190,   303, 12651,
            8,   616,  2107,     4,     3,     4,   425,     0,    10,   510,
         1339,   112,    23,   335,     3, 22251,  1162,     9,    11,     9,
         1212,   468,     6,   820,     9,     7,  1231,  4202,  2866,   382,
            6,    24,   104,     6,     4,     4,     7,    10,     9,   588,
           31,   190,     0,     0,   230,   267,     4,   273,   278,     6,
           34,    25,    47,    26,  1864,     6,   694,     0,  2112,     3,
           11,     6,    52,   798,     8,    69,    20,    31,    63,     9,
         1800,    25,  2141,  2442,   117,    31,   196,  7290,     4,   298,
           15,   171,    15,    17,  1712,    13,   217,    59,   736,     5,
         4210,   191,   142,    14,  5251,   939,    59,    38, 10055, 25132,
          302,    23, 11718,    11,    11,   599,   382,   317,     8,    13,
           16,  1564,     9,  4808,     6,     0,     6,     6,     4,     4,
            4,     7,    39,     7,  3934,     5,     9,     3,  8047,   557,
          394,     0, 10715,  3580,  8682,    31,   242,     0, 10055,   170,
           96,     6,   144,  3403,     4,    13,  1014,    14,     6,  2395,
            4,     3, 13729,    14,    40,     0,     5,    18,   676,  3267,
         1031,     3,     0,   628,  1589,    22, 10916, 10969,     5, 22548,
            9,    12,     6,    84,    15,    49,  3144,     7,   102,    15,
          916,    12,     4,   203,     0,   273,   303,   333,  4318,     0,
            6,    12,     0,  4842,     5,    17,     4,    47,  4138,  2072,
           38,   237,     5,    50,    35,    27, 18530,   244,    20,     6,
           13,  1083,    35,  1990,   653,    13,    10,    11,  1538,    56])

第四步: 構建訓練和評估函數

設置模型超參數和初始化模型

# 通過TEXT.vocab.stoi方法獲得不重復詞匯總數
ntokens = len(TEXT.vocab.stoi)

# 詞嵌入大小為200
emsize = 200

# 前饋全連接層的節(jié)點數
nhid = 200

# 編碼器層的數量
nlayers = 2

# 多頭注意力機制的頭數
nhead = 2

# 置0比率
dropout = 0.2

# 將參數輸入到TransformerModel中
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)

# 模型初始化后, 接下來進行損失函數和優(yōu)化方法的選擇.

# 關于損失函數, 我們使用nn自帶的交叉熵損失
criterion = nn.CrossEntropyLoss()

# 學習率初始值定為5.0
lr = 5.0

# 優(yōu)化器選擇torch自帶的SGD隨機梯度下降方法, 并把lr傳入其中
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# 定義學習率調整方法, 使用torch自帶的lr_scheduler, 將優(yōu)化器傳入其中.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

模型訓練代碼分析:

# 導入時間工具包
import time

def train():
    """訓練函數"""
    # 模型開啟訓練模式
    model.train()
    # 定義初始損失為0
    total_loss = 0.
    # 獲得當前時間
    start_time = time.time()
    # 開始遍歷批次數據
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        # 通過get_batch獲得源數據和目標數據
        data, targets = get_batch(train_data, i)
        # 設置優(yōu)化器初始梯度為0梯度
        optimizer.zero_grad()
        # 將數據裝入model得到輸出
        output = model(data)
        # 將輸出和目標數據傳入損失函數對象
        loss = criterion(output.view(-1, ntokens), targets)
        # 損失進行反向傳播以獲得總的損失
        loss.backward()
        # 使用nn自帶的clip_grad_norm_方法進行梯度規(guī)范化, 防止出現梯度消失或爆炸
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        # 模型參數進行更新
        optimizer.step()
        # 將每層的損失相加獲得總的損失
        total_loss += loss.item()
        # 日志打印間隔定為200
        log_interval = 200
        # 如果batch是200的倍數且大于0,則打印相關日志
        if batch % log_interval == 0 and batch > 0:
            # 平均損失為總損失除以log_interval
            cur_loss = total_loss / log_interval
            # 需要的時間為當前時間減去開始時間
            elapsed = time.time() - start_time
            # 打印輪數, 當前批次和總批次, 當前學習率, 訓練速度(每豪秒處理多少批次),
            # 平均損失, 以及困惑度, 困惑度是衡量語言模型的重要指標, 它的計算方法就是
            # 對交叉熵平均損失取自然對數的底數.
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))

            # 每個批次結束后, 總損失歸0
            total_loss = 0
            # 開始時間取當前時間
            start_time = time.time()

模型評估代碼分析:

def evaluate(eval_model, data_source):
    """評估函數, 評估階段包括驗證和測試,
       它的兩個參數eval_model為每輪訓練產生的模型
       data_source代表驗證或測試數據集"""
    # 模型開啟評估模式
    eval_model.eval()
    # 總損失歸0
    total_loss = 0
    # 因為評估模式模型參數不變, 因此反向傳播不需要求導, 以加快計算
    with torch.no_grad():
        # 與訓練過程相同, 但是因為過程不需要打印信息, 因此不需要batch數
        for i in range(0, data_source.size(0) - 1, bptt):
            # 首先還是通過通過get_batch獲得驗證數據集的源數據和目標數據
            data, targets = get_batch(data_source, i)
            # 通過eval_model獲得輸出
            output = eval_model(data)
            # 對輸出形狀扁平化, 變?yōu)槿吭~匯的概率分布
            output_flat = output.view(-1, ntokens)
            # 獲得評估過程的總損失
            total_loss += criterion(output_flat, targets).item()
            # 計算平均損失
            cur_loss = total_loss / ((data_source.size(0) - 1) / bptt)            

    # 返回平均損失
    return cur_loss

第五步: 進行訓練和評估(包括驗證以及測試)

模型的訓練與驗證代碼分析:

# 首先初始化最佳驗證損失,初始值為無窮大
import copy
best_val_loss = float("inf")

# 定義訓練輪數
epochs = 3

# 定義最佳模型變量, 初始值為None
best_model = None

# 使用for循環(huán)遍歷輪數
for epoch in range(1, epochs + 1):
    # 首先獲得輪數開始時間
    epoch_start_time = time.time()
    # 調用訓練函數
    train()
    # 該輪訓練后我們的模型參數已經發(fā)生了變化
    # 將模型和評估數據傳入到評估函數中
    val_loss = evaluate(model, val_data)
    # 之后打印每輪的評估日志,分別有輪數,耗時,驗證損失以及驗證困惑度
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)
    # 我們將比較哪一輪損失最小,賦值給best_val_loss,
    # 并取該損失下的模型為best_model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # 使用深拷貝,拷貝最優(yōu)模型
        best_model = copy.deepcopy(model)
    # 每輪都會對優(yōu)化方法的學習率做調整
    scheduler.step()

輸出效果:

| epoch   1 |   200/ 2981 batches | lr 5.00 | ms/batch 30.03 | loss  7.68 | ppl  2158.52
| epoch   1 |   400/ 2981 batches | lr 5.00 | ms/batch 28.90 | loss  5.26 | ppl   193.39
| epoch   1 |   600/ 2981 batches | lr 5.00 | ms/batch 28.90 | loss  4.07 | ppl    58.44
| epoch   1 |   800/ 2981 batches | lr 5.00 | ms/batch 28.88 | loss  3.41 | ppl    30.26
| epoch   1 |  1000/ 2981 batches | lr 5.00 | ms/batch 28.89 | loss  2.98 | ppl    19.72
| epoch   1 |  1200/ 2981 batches | lr 5.00 | ms/batch 28.90 | loss  2.79 | ppl    16.30
| epoch   1 |  1400/ 2981 batches | lr 5.00 | ms/batch 28.91 | loss  2.67 | ppl    14.38
| epoch   1 |  1600/ 2981 batches | lr 5.00 | ms/batch 28.92 | loss  2.58 | ppl    13.19
| epoch   1 |  1800/ 2981 batches | lr 5.00 | ms/batch 28.91 | loss  2.43 | ppl    11.32
| epoch   1 |  2000/ 2981 batches | lr 5.00 | ms/batch 28.92 | loss  2.39 | ppl    10.93
| epoch   1 |  2200/ 2981 batches | lr 5.00 | ms/batch 28.91 | loss  2.33 | ppl    10.24
| epoch   1 |  2400/ 2981 batches | lr 5.00 | ms/batch 28.91 | loss  2.36 | ppl    10.59
| epoch   1 |  2600/ 2981 batches | lr 5.00 | ms/batch 28.90 | loss  2.33 | ppl    10.31
| epoch   1 |  2800/ 2981 batches | lr 5.00 | ms/batch 28.92 | loss  2.26 | ppl     9.54
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 90.01s | valid loss  1.32 | valid ppl     3.73
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 2981 batches | lr 4.75 | ms/batch 29.08 | loss  2.18 | ppl     8.83
| epoch   2 |   400/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss  2.11 | ppl     8.24
| epoch   2 |   600/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss  1.98 | ppl     7.23
| epoch   2 |   800/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss  2.00 | ppl     7.39
| epoch   2 |  1000/ 2981 batches | lr 4.75 | ms/batch 28.94 | loss  1.94 | ppl     6.96
| epoch   2 |  1200/ 2981 batches | lr 4.75 | ms/batch 28.92 | loss  1.97 | ppl     7.15
| epoch   2 |  1400/ 2981 batches | lr 4.75 | ms/batch 28.94 | loss  1.98 | ppl     7.28
| epoch   2 |  1600/ 2981 batches | lr 4.75 | ms/batch 28.92 | loss  1.97 | ppl     7.16
| epoch   2 |  1800/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss  1.92 | ppl     6.84
| epoch   2 |  2000/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss  1.96 | ppl     7.11
| epoch   2 |  2200/ 2981 batches | lr 4.75 | ms/batch 28.93 | loss  1.92 | ppl     6.80
| epoch   2 |  2400/ 2981 batches | lr 4.75 | ms/batch 28.94 | loss  1.94 | ppl     6.93
| epoch   2 |  2600/ 2981 batches | lr 4.75 | ms/batch 28.76 | loss  1.91 | ppl     6.76
| epoch   2 |  2800/ 2981 batches | lr 4.75 | ms/batch 28.75 | loss  1.89 | ppl     6.64
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 89.71s | valid loss  1.01 | valid ppl     2.74
-----------------------------------------------------------------------------------------
| epoch   3 |   200/ 2981 batches | lr 4.51 | ms/batch 28.88 | loss  1.78 | ppl     5.96
| epoch   3 |   400/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss  1.89 | ppl     6.59
| epoch   3 |   600/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss  1.72 | ppl     5.58
| epoch   3 |   800/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss  1.73 | ppl     5.63
| epoch   3 |  1000/ 2981 batches | lr 4.51 | ms/batch 28.73 | loss  1.65 | ppl     5.22
| epoch   3 |  1200/ 2981 batches | lr 4.51 | ms/batch 28.74 | loss  1.69 | ppl     5.40
| epoch   3 |  1400/ 2981 batches | lr 4.51 | ms/batch 28.74 | loss  1.73 | ppl     5.66
| epoch   3 |  1600/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss  1.75 | ppl     5.73
| epoch   3 |  1800/ 2981 batches | lr 4.51 | ms/batch 28.74 | loss  1.67 | ppl     5.33
| epoch   3 |  2000/ 2981 batches | lr 4.51 | ms/batch 28.74 | loss  1.69 | ppl     5.41
| epoch   3 |  2200/ 2981 batches | lr 4.51 | ms/batch 28.74 | loss  1.66 | ppl     5.26
| epoch   3 |  2400/ 2981 batches | lr 4.51 | ms/batch 28.76 | loss  1.69 | ppl     5.43
| epoch   3 |  2600/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss  1.71 | ppl     5.55
| epoch   3 |  2800/ 2981 batches | lr 4.51 | ms/batch 28.75 | loss  1.72 | ppl     5.58
-----------------------------------------------------------------------------------------
| end of epoch   3 | time: 89.26s | valid loss  0.85 | valid ppl     2.33

模型測試代碼分析:

# 我們仍然使用evaluate函數,這次它的參數是best_model以及測試數據
test_loss = evaluate(best_model, test_data)

# 打印測試日志,包括測試損失和測試困惑度
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

輸出效果:

=========================================================================================
| End of training | test loss  0.83 | test ppl     2.30
=========================================================================================


分享到:
在線咨詢 我要報名
和我們在線交談!