基于LSTM的新冠预测,pytorch

    技术2026-04-16  5

    import torch from torch import nn import numpy as np import pandas as pd import matplotlib.pyplot as plt df = pd.read_csv('./COVID-19_USA.csv') df.head() countryEnglishNameconfirmedCountnowconfirmedCountcuredCountdeadCountdeadRatecuredRate0United States of America00000.00.01United States of America00000.00.02United States of America00000.00.03United States of America00000.00.04United States of America00000.00.0 value = df['confirmedCount'].values[20:140] value array([ 13, 13, 14, 15, 15, 15, 15, 15, 15, 15, 15, 35, 35, 35, 53, 57, 60, 60, 63, 63, 77, 100, 122, 153, 232, 324, 445, 572, 704, 1004, 1004, 1635, 2084, 2885, 3700, 4661, 6420, 10259, 14250, 19624, 26997, 35360, 46450, 55243, 69194, 85840, 105470, 124686, 143101, 164603, 189633, 216722, 245658, 278537, 312245, 337971, 368449, 399979, 432579, 467184, 501615, 530830, 558526, 583220, 609696, 640014, 672246, 706832, 735366, 760570, 788920, 826184, 843937, 870468, 907096, 940797, 967585, 989357, 1014568, 1040608, 1070032, 1107815, 1133069, 1158341, 1181885, 1206323, 1231943, 1256972, 1286833, 1312099, 1332411, 1351200, 1371395, 1395265, 1419998, 1446875, 1470199, 1490195, 1510988, 1528568, 1577758, 1604189, 1626258, 1646495, 1665882, 1684173, 1702911, 1724873, 1750203, 1773020, 1792512, 1812125, 1832412, 1854476, 1872660, 1901391, 1920552, 1941748], dtype=int64) print(len(value)) x = [] y = [] seq = 3 for i in range(len(value)-seq-1): x.append(value[i:i+seq]) y.append(value[i+seq]) 118

    LSTM 的输入:input,(h_0,c_0)

    input:输入数据,shape 为(句子长度seq_len, 句子数量batch, 每个单词向量的长度input_size); h_0:默认为0,shape 为(num_layers * num_directions单向为1双向为2, batch, 隐藏层节点数hidden_size); c_0:默认为0,shape 为(num_layers * num_directions, batch, hidden_size);

    LSTM 的输出:output,(h_n,c_n)

    output:输出的 shape 为(seq_len, batch, num_directions * hidden_size); h_n:shape 为(num_layers * num_directions, batch, hidden_size); c_n:shape 为(num_layers * num_directions, batch, hidden_size);

    import torch import torch.nn as nn

    rnn = nn.LSTM(10, 20, 3) # 一个单词向量长度为10,隐藏层节点数为20,LSTM数量为3 input = torch.randn(8, 3, 10) # batch_size为3(输入数据有3个句子),每个句子有8个单词,每个单词向量长度为10 h_0, c_0 = torch.randn(3, 3, 20), torch.randn(3, 3, 20) output, (h_n, c_n) = rnn(input, (h_0, c_0))

    print(“input.shape:”, input.shape) print(“h_0.shape:”, h_0.shape) print(“c_0.shape:”, c_0.shape) print("*" * 50) print(“output.shape:”, output.shape) print(“h_n.shape:”, h_n.shape) print(“c_n.shape:”, c_n.shape)

    #print(x, '\n', y) train_x = (torch.tensor(x[0:90]).float()/100000.).reshape(-1, seq, 1) train_y = (torch.tensor(y[0:90]).float()/100000.).reshape(-1, 1) test_x = (torch.tensor(x[90:110]).float()/100000.).reshape(-1, seq, 1) test_y = (torch.tensor(y[90:110]).float()/100000.).reshape(-1, 1) print(test_y) tensor([[13.9527], [14.2000], [14.4688], [14.7020], [14.9019], [15.1099], [15.2857], [15.7776], [16.0419], [16.2626], [16.4650], [16.6588], [16.8417], [17.0291], [17.2487], [17.5020], [17.7302], [17.9251], [18.1213], [18.3241]]) # 模型训练 class LSTM(nn.Module): def __init__(self): super(LSTM, self).__init__() self.lstm = nn.LSTM(input_size=1, hidden_size=16, num_layers=1, batch_first=True) self.linear = nn.Linear(16 * seq, 1) def forward(self, x): x, (h, c) = self.lstm(x) x = x.reshape(-1, 16 * seq) x = self.linear(x) return x # 模型训练 model = LSTM() optimzer = torch.optim.Adam(model.parameters(), lr=0.005) loss_func = nn.MSELoss() model.train() for epoch in range(2000): output = model(train_x) loss = loss_func(output, train_y) optimzer.zero_grad() loss.backward() optimzer.step() if epoch % 20 == 0: tess_loss = loss_func(model(test_x), test_y) print("epoch:{}, train_loss:{}, test_loss:{}".format(epoch, loss, tess_loss)) epoch:0, train_loss:37.56373977661133, test_loss:263.2969055175781 epoch:20, train_loss:18.516111373901367, test_loss:166.8538360595703 epoch:40, train_loss:3.2988765239715576, test_loss:58.48737716674805 epoch:60, train_loss:1.0975019931793213, test_loss:28.006675720214844 epoch:80, train_loss:0.2535986304283142, test_loss:16.773616790771484 epoch:100, train_loss:0.10097146779298782, test_loss:12.012392044067383 epoch:120, train_loss:0.05068375915288925, test_loss:9.413983345031738 epoch:140, train_loss:0.031387731432914734, test_loss:7.986786842346191 epoch:160, train_loss:0.021674150601029396, test_loss:7.1016082763671875 epoch:180, train_loss:0.015867270529270172, test_loss:6.489916801452637 epoch:200, train_loss:0.01209135353565216, 
test_loss:6.03498649597168 epoch:220, train_loss:0.009504670277237892, test_loss:5.680188179016113 epoch:240, train_loss:0.00768320681527257, test_loss:5.396552562713623 epoch:260, train_loss:0.0064044492319226265, test_loss:5.167068004608154 epoch:280, train_loss:0.005527254659682512, test_loss:4.980656623840332 epoch:300, train_loss:0.004928725305944681, test_loss:4.828995227813721 epoch:320, train_loss:0.004502041265368462, test_loss:4.704446315765381 epoch:340, train_loss:0.004170786123722792, test_loss:4.600121974945068 epoch:360, train_loss:0.0038912463933229446, test_loss:4.5099616050720215 epoch:380, train_loss:0.0036443774588406086, test_loss:4.428786277770996 epoch:400, train_loss:0.0034239296801388264, test_loss:4.352680683135986 epoch:420, train_loss:0.003226183820515871, test_loss:4.279394626617432 epoch:440, train_loss:0.0030467547476291656, test_loss:4.208122730255127 epoch:460, train_loss:0.002882403088733554, test_loss:4.1388139724731445 epoch:480, train_loss:0.0027316079940646887, test_loss:4.071575164794922 epoch:500, train_loss:0.002594325691461563, test_loss:4.006669044494629 epoch:520, train_loss:0.0024710919242352247, test_loss:3.9445641040802 epoch:540, train_loss:0.0023619714193046093, test_loss:3.8857200145721436 epoch:560, train_loss:0.0022660386748611927, test_loss:3.8303375244140625 epoch:580, train_loss:0.0021817032247781754, test_loss:3.77828311920166 epoch:600, train_loss:0.002107286360114813, test_loss:3.729255199432373 epoch:620, train_loss:0.0020413610618561506, test_loss:3.682926654815674 epoch:640, train_loss:0.001982747344300151, test_loss:3.639024496078491 epoch:660, train_loss:0.0019304585875943303, test_loss:3.597336769104004 epoch:680, train_loss:0.001883711782284081, test_loss:3.5576891899108887 epoch:700, train_loss:0.0018417979590594769, test_loss:3.5199549198150635 epoch:720, train_loss:0.0018041135044768453, test_loss:3.4840168952941895 epoch:740, train_loss:0.001770111615769565, test_loss:3.4497766494750977 epoch:760, 
train_loss:0.0017393192974850535, test_loss:3.41713809967041 epoch:780, train_loss:0.0017113107023760676, test_loss:3.386018753051758 epoch:800, train_loss:0.0016857198206707835, test_loss:3.3563263416290283 epoch:820, train_loss:0.0016621954273432493, test_loss:3.3279783725738525 epoch:840, train_loss:0.0016404724447056651, test_loss:3.3008854389190674 epoch:860, train_loss:0.0016202878905460238, test_loss:3.2749710083007812 epoch:880, train_loss:0.001601424883119762, test_loss:3.2501537799835205 epoch:900, train_loss:0.0015837030950933695, test_loss:3.226363182067871 epoch:920, train_loss:0.0015669644344598055, test_loss:3.2035205364227295 epoch:940, train_loss:0.0015510818921029568, test_loss:3.181558132171631 epoch:960, train_loss:0.001535946037620306, test_loss:3.1604182720184326 epoch:980, train_loss:0.001521461526863277, test_loss:3.140035629272461 epoch:1000, train_loss:0.0015075404662638903, test_loss:3.1203596591949463 epoch:1020, train_loss:0.0014941382687538862, test_loss:3.101344108581543 epoch:1040, train_loss:0.001481189508922398, test_loss:3.082942008972168 epoch:1060, train_loss:0.0014686313224956393, test_loss:3.0651092529296875 epoch:1080, train_loss:0.0014564390294253826, test_loss:3.0478177070617676 epoch:1100, train_loss:0.001444563502445817, test_loss:3.0310306549072266 epoch:1120, train_loss:0.0014329585246741772, test_loss:3.014721632003784 epoch:1140, train_loss:0.0014216136187314987, test_loss:2.998866319656372 epoch:1160, train_loss:0.0014104940928518772, test_loss:2.9834399223327637 epoch:1180, train_loss:0.0013995792251080275, test_loss:2.968425989151001 epoch:1200, train_loss:0.001388857257552445, test_loss:2.953805923461914 epoch:1220, train_loss:0.0013782993191853166, test_loss:2.939566135406494 epoch:1240, train_loss:0.0013679120456799865, test_loss:2.9256908893585205 epoch:1260, train_loss:0.001357687870040536, test_loss:2.9121646881103516 epoch:1280, train_loss:0.0013476117746904492, test_loss:2.8989791870117188 epoch:1300, 
train_loss:0.0013376886490732431, test_loss:2.8861305713653564 epoch:1320, train_loss:0.001327910111285746, test_loss:2.873605728149414 epoch:1340, train_loss:0.0013182887341827154, test_loss:2.861402750015259 epoch:1360, train_loss:0.0013088179985061288, test_loss:2.8495066165924072 epoch:1380, train_loss:0.0012995086144655943, test_loss:2.837918519973755 epoch:1400, train_loss:0.0012903454480692744, test_loss:2.8266289234161377 epoch:1420, train_loss:0.0012813681969419122, test_loss:2.81563663482666 epoch:1440, train_loss:0.0012725305277854204, test_loss:2.8049354553222656 epoch:1460, train_loss:0.0012638678308576345, test_loss:2.7945163249969482 epoch:1480, train_loss:0.0012553795240819454, test_loss:2.7843799591064453 epoch:1500, train_loss:0.0012470469810068607, test_loss:2.774510622024536 epoch:1520, train_loss:0.001238883938640356, test_loss:2.764907121658325 epoch:1540, train_loss:0.0012308855075389147, test_loss:2.7555596828460693 epoch:1560, train_loss:0.0012230485444888473, test_loss:2.7464544773101807 epoch:1580, train_loss:0.0012153665302321315, test_loss:2.737583637237549 epoch:1600, train_loss:0.0012078447034582496, test_loss:2.7289376258850098 epoch:1620, train_loss:0.0012004825985059142, test_loss:2.720499277114868 epoch:1640, train_loss:0.001193271717056632, test_loss:2.7122607231140137 epoch:1660, train_loss:0.0011862049577757716, test_loss:2.7042083740234375 epoch:1680, train_loss:0.001179289072751999, test_loss:2.6963284015655518 epoch:1700, train_loss:0.00117252126801759, test_loss:2.6886096000671387 epoch:1720, train_loss:0.0011658959556370974, test_loss:2.681042432785034 epoch:1740, train_loss:0.001159411738626659, test_loss:2.673621892929077 epoch:1760, train_loss:0.0011530747869983315, test_loss:2.6663315296173096 epoch:1780, train_loss:0.0011468705488368869, test_loss:2.6591718196868896 epoch:1800, train_loss:0.0011408105492591858, test_loss:2.6521382331848145 epoch:1820, train_loss:0.001134884194470942, test_loss:2.6452348232269287 
epoch:1840, train_loss:0.0011290841503068805, test_loss:2.638458728790283 epoch:1860, train_loss:0.001123423338867724, test_loss:2.631808280944824 epoch:1880, train_loss:0.0011178869754076004, test_loss:2.625295877456665 epoch:1900, train_loss:0.0011124806478619576, test_loss:2.618925094604492 epoch:1920, train_loss:0.001107192481867969, test_loss:2.6127049922943115 epoch:1940, train_loss:0.0011020202655345201, test_loss:2.6066269874572754 epoch:1960, train_loss:0.001096969353966415, test_loss:2.600712299346924 epoch:1980, train_loss:0.0010920269414782524, test_loss:2.594951629638672 model.eval() prediction = list((model(train_x).data.reshape(-1))*100000) + list((model(test_x).data.reshape(-1))*100000) plt.plot(value[3:], label='True Value') plt.plot(prediction[:91], label='LSTM fit') plt.plot(np.arange(90, 110, 1), prediction[90:], label='LSTM pred') print(len(value[3:])) print(len(prediction[90:])) plt.legend(loc='best') plt.title('Cumulative infections prediction(USA)') plt.xlabel('Day') plt.ylabel('Cumulative Cases') plt.show() 115 20

    df_2 = pd.read_csv('./COVID-19_China.csv') df_2.head() countryEnglishNameconfirmedCountnowconfirmedCountcuredCountdeadCountdeadRatecuredRate0China54449928173.1250005.1470591China63959230172.6604074.6948362China90183936262.8856833.9955603China1377129739412.9774872.8322444China2076197149562.6974952.360308 value = df_2['confirmedCount'].values[20:140] value array([42747, 44765, 59907, 63950, 66581, 68595, 70644, 72533, 74284, 74680, 75571, 76396, 77048, 77269, 77785, 78195, 78631, 78962, 79394, 79972, 80175, 80303, 80424, 80581, 80734, 80815, 80868, 80905, 80932, 80969, 80981, 80995, 81029, 81062, 81099, 81135, 81202, 81264, 81385, 81457, 81566, 81691, 81806, 81896, 82034, 82164, 82282, 82420, 82504, 82600, 82690, 82771, 82857, 82898, 82965, 83038, 83094, 83188, 83263, 83323, 83399, 83522, 83606, 83699, 83751, 83798, 84155, 84185, 84225, 84239, 84278, 84294, 84305, 84313, 84330, 84338, 84341, 84367, 84369, 84373, 84387, 84391, 84393, 84403, 84404, 84407, 84414, 84416, 84416, 84434, 84450, 84451, 84461, 84465, 84471, 84478, 84487, 84494, 84503, 84506, 84522, 84522, 84525, 84536, 84543, 84545, 84547, 84561, 84569, 84572, 84593, 84603, 84602, 84609, 84617, 84624, 84630, 84634], dtype=int64) print(len(value)) x = [] y = [] seq = 3 for i in range(len(value)-seq-1): x.append(value[i:i+seq]) y.append(value[i+seq]) 118 train_x = (torch.tensor(x[0:90]).float()/100000.).reshape(-1, seq, 1) train_y = (torch.tensor(y[0:90]).float()/100000.).reshape(-1, 1) test_x = (torch.tensor(x[90:110]).float()/100000.).reshape(-1, seq, 1) test_y = (torch.tensor(y[90:110]).float()/100000.).reshape(-1, 1) print(test_y) tensor([[0.8446], [0.8447], [0.8448], [0.8449], [0.8449], [0.8450], [0.8451], [0.8452], [0.8452], [0.8453], [0.8454], [0.8454], [0.8454], [0.8455], [0.8456], [0.8457], [0.8457], [0.8459], [0.8460], [0.8460]]) # 模型训练 model = LSTM() optimzer = torch.optim.Adam(model.parameters(), lr=0.005) loss_func = nn.MSELoss() model.train() for epoch in range(2000): output = model(train_x) 
loss = loss_func(output, train_y) optimzer.zero_grad() loss.backward() optimzer.step() if epoch % 20 == 0: tess_loss = loss_func(model(test_x), test_y) print("epoch:{}, train_loss:{}, test_loss:{}".format(epoch, loss, tess_loss)) epoch:0, train_loss:0.3825637102127075, test_loss:0.3793776333332062 epoch:20, train_loss:0.029308181256055832, test_loss:0.01525117177516222 epoch:40, train_loss:0.0014625154435634613, test_loss:0.001359123969450593 epoch:60, train_loss:0.0011256280122324824, test_loss:0.001867746701464057 epoch:80, train_loss:0.0008417891804128885, test_loss:0.0004369680245872587 epoch:100, train_loss:0.0008008113945834339, test_loss:0.0004433710710145533 epoch:120, train_loss:0.0007676983368583024, test_loss:0.0005258770543150604 epoch:140, train_loss:0.0007398930029012263, test_loss:0.0005360021023079753 epoch:160, train_loss:0.0007106903940439224, test_loss:0.0005205624038353562 epoch:180, train_loss:0.000680193246807903, test_loss:0.0005014035850763321 epoch:200, train_loss:0.0006485295598395169, test_loss:0.0004820456088054925 epoch:220, train_loss:0.0006157823954708874, test_loss:0.0004624216235242784 epoch:240, train_loss:0.0005820132209919393, test_loss:0.0004422594793140888 epoch:260, train_loss:0.0005472666234709322, test_loss:0.0004214201180730015 epoch:280, train_loss:0.000511578400619328, test_loss:0.0003998256870545447 epoch:300, train_loss:0.00047498263302259147, test_loss:0.00037742816493846476 epoch:320, train_loss:0.0004375250719022006, test_loss:0.00035419128835201263 epoch:340, train_loss:0.00039927675970830023, test_loss:0.00033008967875503004 epoch:360, train_loss:0.00036035984521731734, test_loss:0.0003051073872484267 epoch:380, train_loss:0.00032097692019306123, test_loss:0.0002792503801174462 epoch:400, train_loss:0.0002814592735376209, test_loss:0.00025257302331738174 epoch:420, train_loss:0.0002423272526357323, test_loss:0.00022522213112097234 epoch:440, train_loss:0.0002043575659627095, test_loss:0.00019748158229049295 
epoch:460, train_loss:0.00016863604832906276, test_loss:0.000169840976013802 epoch:480, train_loss:0.00013653359201271087, test_loss:0.00014304733485914767 epoch:500, train_loss:0.0001095230836654082, test_loss:0.00011809338320745155 epoch:520, train_loss:8.876808715285733e-05, test_loss:9.60855686571449e-05 epoch:540, train_loss:7.458930485881865e-05, test_loss:7.79658803367056e-05 epoch:560, train_loss:6.617035978706554e-05, test_loss:6.417164695449173e-05 epoch:580, train_loss:6.185458914842457e-05, test_loss:5.449241871247068e-05 epoch:600, train_loss:5.988311022520065e-05, test_loss:4.820094909518957e-05 epoch:620, train_loss:5.898631934542209e-05, test_loss:4.4363230699673295e-05 epoch:640, train_loss:5.848926957696676e-05, test_loss:4.211986743030138e-05 epoch:660, train_loss:5.811382652609609e-05, test_loss:4.081864972249605e-05 epoch:680, train_loss:5.7767596445046365e-05, test_loss:4.003203866886906e-05 epoch:700, train_loss:5.742487701354548e-05, test_loss:3.9503516745753586e-05 epoch:720, train_loss:5.707951640943065e-05, test_loss:3.909639417543076e-05 epoch:740, train_loss:5.6730175856500864e-05, test_loss:3.8741745811421424e-05 epoch:760, train_loss:5.637647700496018e-05, test_loss:3.8406542444135994e-05 epoch:780, train_loss:5.601833254331723e-05, test_loss:3.807685061474331e-05 epoch:800, train_loss:5.56554296053946e-05, test_loss:3.774597280425951e-05 epoch:820, train_loss:5.5287899158429354e-05, test_loss:3.741172986337915e-05 epoch:840, train_loss:5.491534466273151e-05, test_loss:3.707344876602292e-05 epoch:860, train_loss:5.4537584219360724e-05, test_loss:3.673039100249298e-05 epoch:880, train_loss:5.4154432291397825e-05, test_loss:3.638227281044237e-05 epoch:900, train_loss:5.3765703341923654e-05, test_loss:3.603072764235549e-05 epoch:920, train_loss:5.3371091780718416e-05, test_loss:3.567387830116786e-05 epoch:940, train_loss:5.297027018968947e-05, test_loss:3.53103423549328e-05 epoch:960, train_loss:5.2563085773726925e-05, 
test_loss:3.494161501294002e-05 epoch:980, train_loss:5.214917837292887e-05, test_loss:3.456864942563698e-05 epoch:1000, train_loss:5.17280786880292e-05, test_loss:3.418769483687356e-05 epoch:1020, train_loss:5.129948112880811e-05, test_loss:3.3801185054471716e-05 epoch:1040, train_loss:5.086324381409213e-05, test_loss:3.34087380906567e-05 epoch:1060, train_loss:5.041856275056489e-05, test_loss:3.300896059954539e-05 epoch:1080, train_loss:4.9965601647272706e-05, test_loss:3.2601448765490204e-05 epoch:1100, train_loss:4.950328002450988e-05, test_loss:3.2187974284170195e-05 epoch:1120, train_loss:4.9031306843971834e-05, test_loss:3.1764171581016853e-05 epoch:1140, train_loss:4.854957660427317e-05, test_loss:3.133570862701163e-05 epoch:1160, train_loss:4.805713615496643e-05, test_loss:3.089520396315493e-05 epoch:1180, train_loss:4.7553505282849073e-05, test_loss:3.044824188691564e-05 epoch:1200, train_loss:4.7038094635354355e-05, test_loss:2.999177922902163e-05 epoch:1220, train_loss:4.651049312087707e-05, test_loss:2.9524149795179255e-05 epoch:1240, train_loss:4.596983490046114e-05, test_loss:2.904612665588502e-05 epoch:1260, train_loss:4.541549060377292e-05, test_loss:2.8558717531268485e-05 epoch:1280, train_loss:4.484672172111459e-05, test_loss:2.8060558179276995e-05 epoch:1300, train_loss:4.4262782466830686e-05, test_loss:2.75501352007268e-05 epoch:1320, train_loss:4.3662937969202176e-05, test_loss:2.7027786927646957e-05 epoch:1340, train_loss:4.3046358769061044e-05, test_loss:2.649312409630511e-05 epoch:1360, train_loss:4.241224087309092e-05, test_loss:2.5944673325284384e-05 epoch:1380, train_loss:4.175960930297151e-05, test_loss:2.5382616513525136e-05 epoch:1400, train_loss:4.108754365006462e-05, test_loss:2.4807826775941066e-05 epoch:1420, train_loss:4.039523264509626e-05, test_loss:2.4217149984906428e-05 epoch:1440, train_loss:3.968160672229715e-05, test_loss:2.3612001314177178e-05 epoch:1460, train_loss:3.8945843698456883e-05, test_loss:2.2989526769379154e-05 
epoch:1480, train_loss:3.818666664301418e-05, test_loss:2.2353377062245272e-05 epoch:1500, train_loss:3.7403442547656596e-05, test_loss:2.1699825083487667e-05 epoch:1520, train_loss:3.659499634522945e-05, test_loss:2.1029607523814775e-05 epoch:1540, train_loss:3.576060043997131e-05, test_loss:2.034137287409976e-05 epoch:1560, train_loss:3.489920709398575e-05, test_loss:1.9637136574601755e-05 epoch:1580, train_loss:3.4010074159596115e-05, test_loss:1.8915245163952932e-05 epoch:1600, train_loss:3.3092539524659514e-05, test_loss:1.8175118384533562e-05 epoch:1620, train_loss:3.214598837075755e-05, test_loss:1.741875894367695e-05 epoch:1640, train_loss:3.1170118745649233e-05, test_loss:1.6644184142933227e-05 epoch:1660, train_loss:3.016489245055709e-05, test_loss:1.5854122466407716e-05 epoch:1680, train_loss:2.9130327675375156e-05, test_loss:1.5049477951833978e-05 epoch:1700, train_loss:2.8066970116924495e-05, test_loss:1.4228941836336162e-05 epoch:1720, train_loss:2.6975792934536003e-05, test_loss:1.3398408555076458e-05 epoch:1740, train_loss:2.5858256776700728e-05, test_loss:1.255551069334615e-05 epoch:1760, train_loss:2.4716307962080464e-05, test_loss:1.170479299617e-05 epoch:1780, train_loss:2.355259857722558e-05, test_loss:1.0850737453438342e-05 epoch:1800, train_loss:2.2370682927430607e-05, test_loss:9.995957043429371e-06 epoch:1820, train_loss:2.117482290486805e-05, test_loss:9.143737770500593e-06 epoch:1840, train_loss:1.9970017092418857e-05, test_loss:8.297822205349803e-06 epoch:1860, train_loss:1.8762442778097466e-05, test_loss:7.465289854735602e-06 epoch:1880, train_loss:1.7558993931743316e-05, test_loss:6.651775947830174e-06 epoch:1900, train_loss:1.6367670468753204e-05, test_loss:5.863402748218505e-06 epoch:1920, train_loss:1.519708621344762e-05, test_loss:5.106710432301043e-06 epoch:1940, train_loss:1.4056418876862153e-05, test_loss:4.386639830045169e-06 epoch:1960, train_loss:1.2955445527040865e-05, test_loss:3.7117310967005324e-06 epoch:1980, 
train_loss:1.1903614904440474e-05, test_loss:3.087144477831316e-06 model.eval() prediction = list((model(train_x).data.reshape(-1))*100000) + list((model(test_x).data.reshape(-1))*100000) plt.plot(value[3:], label='True Value') plt.plot(prediction[:91], label='LSTM fit') plt.plot(np.arange(90, 110, 1), prediction[90:], label='LSTM pred') print(len(value[3:])) print(len(prediction[90:])) plt.legend(loc='best') plt.title('Cumulative infections prediction(China)') plt.xlabel('Day') plt.ylabel('Cumulative Cases') plt.show() 115 20

    import torch from torch.autograd import Variable import torch.nn as nn from graphviz import Digraph from graphviz import Graph def make_dot(var, params=None): """ Produces Graphviz representation of PyTorch autograd graph Blue nodes are the Variables that require grad, orange are Tensors saved for backward in torch.autograd.Function Args: var: output Variable params: dict of (name, Variable) to add names to node that require grad (TODO: make optional) """ if params is not None: assert isinstance(params.values()[0], Variable) param_map = {id(v): k for k, v in params.items()} node_attr = dict(style='filled', shape='box', align='left', fontsize='12', ranksep='0.1', height='0.2') dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12")) seen = set() def size_to_str(size): return '('+(', ').join(['%d' % v for v in size])+')' def add_nodes(var): if var not in seen: if torch.is_tensor(var): dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange') elif hasattr(var, 'variable'): u = var.variable name = param_map[id(u)] if params is not None else '' node_name = '%s\n %s' % (name, size_to_str(u.size())) dot.node(str(id(var)), node_name, fillcolor='lightblue') else: dot.node(str(id(var)), str(type(var).__name__)) seen.add(var) if hasattr(var, 'next_functions'): for u in var.next_functions: if u[0] is not None: dot.edge(str(id(u[0])), str(id(var))) add_nodes(u[0]) if hasattr(var, 'saved_tensors'): for t in var.saved_tensors: dot.edge(str(id(t)), str(id(var))) add_nodes(t) add_nodes(var.grad_fn) return dot net = LSTM() x = train_x y = net(train_x) g = make_dot(y) g.view() 'Digraph.gv.pdf' #g.view(quiet=True,quiet_view=True) params = list(net.parameters()) k = 0 for i in params: l = 1 print("该层的结构:" + str(list(i.size()))) for j in i.size(): l *= j print("该层参数和:" + str(l)) k = k + l print("总参数数量和:" + str(k)) 该层的结构:[64, 1] 该层参数和:64 该层的结构:[64, 16] 该层参数和:1024 该层的结构:[64] 该层参数和:64 该层的结构:[64] 该层参数和:64 该层的结构:[1, 48] 该层参数和:48 该层的结构:[1] 该层参数和:1 总参数数量和:1265 
31+30+19+29 109

    参数列表

    input_size:x的特征维度 hidden_size:隐藏层的特征维度 num_layers:lstm隐层的层数,默认为1 bias:False则bih=0和bhh=0. 默认为True batch_first:True则输入输出的数据格式为 (batch, seq, feature) dropout:除最后一层,每一层的输出都进行dropout,默认为: 0 bidirectional:True则为双向lstm默认为False 输入:input, (h0, c0) 输出:output, (hn,cn) 输入数据格式: input(seq_len, batch, input_size) h0(num_layers * num_directions, batch, hidden_size) c0(num_layers * num_directions, batch, hidden_size)

    输出数据格式: output(seq_len, batch, hidden_size * num_directions) hn(num_layers * num_directions, batch, hidden_size) cn(num_layers * num_directions, batch, hidden_size)

    Pytorch里的LSTM单元接受的输入都必须是3维的张量(Tensors).每一维代表的意思不能弄错。

    第一维体现的是序列(sequence)结构,也就是序列的个数,用文章来说,就是每个句子的长度,因为是喂给网络模型,一般都设定为确定的长度,也就是我们喂给LSTM神经元的每个句子的长度,当然,如果是其他的带有序列形式的数据,则表示一个明确分割单位长度,

    例如是如果是股票数据内,这表示特定时间单位内,有多少条数据。这个参数也就是明确这个层中有多少个确定的单元来处理输入的数据。

    第二维度体现的是batch_size,也就是一次性喂给网络多少条句子,或者股票数据中的,一次性喂给模型多少是个时间单位的数据,具体到每个时刻,也就是一次性喂给特定时刻处理的单元的单词数或者该时刻应该喂给的股票数据的条数

    第三维体现的是输入的元素(elements of input),也就是,每个具体的单词用多少维向量来表示,或者股票数据中每一个具体的时刻采集多少个具体的值,比如最低价,最高价,均价,5日均价,10日均价,等等

    H0-Hn是什么意思呢?就是每个时刻中间神经元应该保存的这一时刻的根据输入和上一时刻的中间状态值应该产生的本时刻的状态值,

    这个数据单元是起的作用就是记录这一时刻之前考虑到所有之前输入的状态值,形状应该是和特定时刻的输出一致

    c0-cn就是开关,决定每个神经元的隐藏状态值是否会影响的下一时刻的神经元的处理,形状应该和h0-hn一致。

    当然如果是双向,和多隐藏层还应该考虑方向和隐藏层的层数。

    df = pd.read_csv('./COVID-19_Italy.csv') df.head() Unnamed: 0countryEnglishNameconfirmedCountnowconfirmedCountcuredCountdeadCountdeadRatecuredRate00Italy00000.00.011Italy00000.00.022Italy00000.00.033Italy00000.00.044Italy00000.00.0 value = df['confirmedCount'].values[20:140] value array([ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 20, 117, 230, 283, 374, 528, 653, 888, 1128, 1694, 2036, 2502, 3089, 3927, 4636, 5883, 7375, 9172, 10283, 12462, 15113, 17660, 21270, 24938, 29022, 31506, 37178, 41035, 47021, 55218, 59514, 64378, 70545, 74386, 81129, 87275, 93051, 97689, 102106, 105792, 110574, 115895, 120281, 125016, 129481, 132810, 135893, 139887, 143626, 148217, 152860, 156673, 159516, 162488, 165155, 168941, 172434, 175925, 178972, 183957, 183957, 187327, 192994, 192994, 195351, 197675, 199414, 201505, 203591, 205463, 207428, 209328, 210717, 211938, 213013, 214457, 215858, 217185, 218268, 219070, 219814, 221216, 222104, 223096, 223885, 224760, 225435, 225886, 226699, 228006, 228658, 229327, 229858, 230158, 230555, 231139, 231732, 232248, 232664, 233019, 233197, 233515, 233836, 234013, 234531, 234801, 234998], dtype=int64) print(len(value)) x = [] y = [] seq = 3 for i in range(len(value)-seq-1): x.append(value[i:i+seq]) y.append(value[i+seq]) 118 train_x = (torch.tensor(x[0:90]).float()/100000.).reshape(-1, seq, 1) train_y = (torch.tensor(y[0:90]).float()/100000.).reshape(-1, 1) test_x = (torch.tensor(x[90:110]).float()/100000.).reshape(-1, seq, 1) test_y = (torch.tensor(y[90:110]).float()/100000.).reshape(-1, 1) print(test_y) tensor([[2.2210], [2.2310], [2.2389], [2.2476], [2.2543], [2.2589], [2.2670], [2.2801], [2.2866], [2.2933], [2.2986], [2.3016], [2.3056], [2.3114], [2.3173], [2.3225], [2.3266], [2.3302], [2.3320], [2.3352]]) # 模型训练 class LSTM(nn.Module): def __init__(self): super(LSTM, self).__init__() self.lstm = nn.LSTM(input_size=1, hidden_size=16, num_layers=1, batch_first=True) self.linear = nn.Linear(16 * seq, 1) def forward(self, x): x, (h, c) = self.lstm(x) x 
= x.reshape(-1, 16 * seq) x = self.linear(x) return x # 模型训练 model = LSTM() optimzer = torch.optim.Adam(model.parameters(), lr=0.005) loss_func = nn.MSELoss() model.train() for epoch in range(2000): output = model(train_x) loss = loss_func(output, train_y) optimzer.zero_grad() loss.backward() optimzer.step() if epoch % 20 == 0: tess_loss = loss_func(model(test_x), test_y) print("epoch:{}, train_loss:{}, test_loss:{}".format(epoch, loss, tess_loss)) epoch:0, train_loss:1.7632023096084595, test_loss:5.223465442657471 epoch:20, train_loss:0.2965010404586792, test_loss:0.12659545242786407 epoch:40, train_loss:0.02264116331934929, test_loss:0.05677216500043869 epoch:60, train_loss:0.006982944440096617, test_loss:0.0008666875073686242 epoch:80, train_loss:0.0018317289650440216, test_loss:0.003289521439000964 epoch:100, train_loss:0.0015176519518718123, test_loss:0.0015515109989792109 epoch:120, train_loss:0.001411953242495656, test_loss:0.0017424819525331259 epoch:140, train_loss:0.0013288009213283658, test_loss:0.002439487725496292 epoch:160, train_loss:0.0012708938447758555, test_loss:0.002725407015532255 epoch:180, train_loss:0.0012296135537326336, test_loss:0.003102920949459076 epoch:200, train_loss:0.0011986845638602972, test_loss:0.003337680594995618 epoch:220, train_loss:0.0011738166213035583, test_loss:0.0035245800390839577 epoch:240, train_loss:0.0011521923588588834, test_loss:0.0036384714767336845 epoch:260, train_loss:0.0011321220081299543, test_loss:0.003699743188917637 epoch:280, train_loss:0.001112664700485766, test_loss:0.0037150864955037832 epoch:300, train_loss:0.0010933543089777231, test_loss:0.0036975769326090813 epoch:320, train_loss:0.0010739663848653436, test_loss:0.0036563656758517027 epoch:340, train_loss:0.0010544119868427515, test_loss:0.003598833456635475 epoch:360, train_loss:0.0010346685303375125, test_loss:0.003530224785208702 epoch:380, train_loss:0.0010147254215553403, test_loss:0.003454281948506832 epoch:400, 
train_loss:0.0009945958154276013, test_loss:0.0033735043834894896 epoch:420, train_loss:0.0009742853580974042, test_loss:0.003289596876129508 epoch:440, train_loss:0.0009538100566715002, test_loss:0.003203604370355606 epoch:460, train_loss:0.0009331702603958547, test_loss:0.003116239095106721 epoch:480, train_loss:0.0009123761556111276, test_loss:0.0030278817284852266 epoch:500, train_loss:0.0008914396748878062, test_loss:0.002938809571787715 epoch:520, train_loss:0.0008703632047399879, test_loss:0.0028492698911577463 epoch:540, train_loss:0.0008491512853652239, test_loss:0.00275938818231225 epoch:560, train_loss:0.0008278173627331853, test_loss:0.0026692496612668037 epoch:580, train_loss:0.0008063663844950497, test_loss:0.002579005667939782 epoch:600, train_loss:0.0007848081295378506, test_loss:0.0024887227918952703 epoch:620, train_loss:0.0007631516200490296, test_loss:0.002398498123511672 epoch:640, train_loss:0.0007414081483148038, test_loss:0.002308483235538006 epoch:660, train_loss:0.0007195929065346718, test_loss:0.0022187510039657354 epoch:680, train_loss:0.0006977192242629826, test_loss:0.0021294010803103447 epoch:700, train_loss:0.0006758036906830966, test_loss:0.0020405412651598454 epoch:720, train_loss:0.0006538643501698971, test_loss:0.0019522873917594552 epoch:740, train_loss:0.0006319198873825371, test_loss:0.0018647522665560246 epoch:760, train_loss:0.0006099938764236867, test_loss:0.0017780527705326676 epoch:780, train_loss:0.0005881070392206311, test_loss:0.001692296122200787 epoch:800, train_loss:0.000566287839319557, test_loss:0.0016075713792815804 epoch:820, train_loss:0.0005445620627142489, test_loss:0.0015240126522257924 epoch:840, train_loss:0.000522959919180721, test_loss:0.0014417333295568824 epoch:860, train_loss:0.0005015107453800738, test_loss:0.0013608544832095504 epoch:880, train_loss:0.00048024847637861967, test_loss:0.0012814635410904884 epoch:900, train_loss:0.0004592059995047748, test_loss:0.0012036816915497184 epoch:920, 
train_loss:0.0004384215862955898, test_loss:0.0011276379227638245 epoch:940, train_loss:0.0004179326933808625, test_loss:0.0010534359607845545 epoch:960, train_loss:0.0003977797459810972, test_loss:0.0009811957133933902 epoch:980, train_loss:0.0003779999096877873, test_loss:0.0009110327227972448 epoch:1000, train_loss:0.00035864009987562895, test_loss:0.000843059562612325 epoch:1020, train_loss:0.00033973553217947483, test_loss:0.0007773857214488089 epoch:1040, train_loss:0.00032133201602846384, test_loss:0.0007141505484469235 epoch:1060, train_loss:0.0003034717810805887, test_loss:0.0006534302374348044 epoch:1080, train_loss:0.00028619240038096905, test_loss:0.0005953567451797426 epoch:1100, train_loss:0.00026953473570756614, test_loss:0.0005400101072154939 epoch:1120, train_loss:0.0002535344392526895, test_loss:0.00048749061534181237 epoch:1140, train_loss:0.00023822428192943335, test_loss:0.00043787891627289355 epoch:1160, train_loss:0.00022363335301633924, test_loss:0.0003912308602593839 epoch:1180, train_loss:0.00020978778775315732, test_loss:0.0003475920238997787 epoch:1200, train_loss:0.00019670836627483368, test_loss:0.00030700559727847576 epoch:1220, train_loss:0.00018440828716848046, test_loss:0.0002694650029297918 epoch:1240, train_loss:0.0001728977804305032, test_loss:0.00023496514768339694 epoch:1260, train_loss:0.00016217738448176533, test_loss:0.0002034622011706233 epoch:1280, train_loss:0.00015224415983539075, test_loss:0.00017491883772891015 epoch:1300, train_loss:0.00014308829850051552, test_loss:0.00014922290574759245 epoch:1320, train_loss:0.00013469379337038845, test_loss:0.00012628933473024517 epoch:1340, train_loss:0.00012703817628789693, test_loss:0.0001059926871675998 epoch:1360, train_loss:0.00012009469355689362, test_loss:8.818962669465691e-05 epoch:1380, train_loss:0.00011383087985450402, test_loss:7.271893991855904e-05 epoch:1400, train_loss:0.0001082118833437562, test_loss:5.941572817391716e-05 epoch:1420, 
train_loss:0.00010319782450096682, test_loss:4.810726750292815e-05 epoch:1440, train_loss:9.874825627775863e-05, test_loss:3.86049687222112e-05 epoch:1460, train_loss:9.481970482738689e-05, test_loss:3.0730687285540625e-05 epoch:1480, train_loss:9.136966400546953e-05, test_loss:2.4305656552314758e-05 epoch:1500, train_loss:8.835419430397451e-05, test_loss:1.9148759747622535e-05 epoch:1520, train_loss:8.573079685447738e-05, test_loss:1.5098817129910458e-05 epoch:1540, train_loss:8.345866081072018e-05, test_loss:1.199172947963234e-05 epoch:1560, train_loss:8.149856876116246e-05, test_loss:9.681958545115776e-06 epoch:1580, train_loss:7.981352973729372e-05, test_loss:8.036873623495921e-06 epoch:1600, train_loss:7.836938311811537e-05, test_loss:6.933561053301673e-06 epoch:1620, train_loss:7.713426020927727e-05, test_loss:6.265192041610135e-06 epoch:1640, train_loss:7.607969018863514e-05, test_loss:5.936094112257706e-06 epoch:1660, train_loss:7.5179836130701e-05, test_loss:5.865499588253442e-06 epoch:1680, train_loss:7.441190973622724e-05, test_loss:5.9840194808202796e-06 epoch:1700, train_loss:7.375545828836039e-05, test_loss:6.233521162357647e-06 epoch:1720, train_loss:7.319300493691117e-05, test_loss:6.566989213752095e-06 epoch:1740, train_loss:7.270894275279716e-05, test_loss:6.945073437236715e-06 epoch:1760, train_loss:7.229032053146511e-05, test_loss:7.33830847821082e-06 epoch:1780, train_loss:7.192560588009655e-05, test_loss:7.723288035776932e-06 epoch:1800, train_loss:7.160565291997045e-05, test_loss:8.082806743914261e-06 epoch:1820, train_loss:7.132247992558405e-05, test_loss:8.406545930483844e-06 epoch:1840, train_loss:7.106920384103432e-05, test_loss:8.6846148406039e-06 epoch:1860, train_loss:7.084051321726292e-05, test_loss:8.915197213354986e-06 epoch:1880, train_loss:7.063188240863383e-05, test_loss:9.095112545765005e-06 epoch:1900, train_loss:7.043959340080619e-05, test_loss:9.22342951525934e-06 epoch:1920, train_loss:7.026109233265743e-05, 
test_loss:9.305442290497012e-06 epoch:1940, train_loss:7.009309774730355e-05, test_loss:9.340597898699343e-06 epoch:1960, train_loss:6.99344091117382e-05, test_loss:9.336043149232864e-06 epoch:1980, train_loss:6.978306191740558e-05, test_loss:9.294308256357908e-06 model.eval() prediction = list((model(train_x).data.reshape(-1))*100000) + list((model(test_x).data.reshape(-1))*100000) plt.plot(value[3:], label='True Value') plt.plot(prediction[:91], label='LSTM fit') plt.plot(np.arange(90, 110, 1), prediction[90:], label='LSTM pred') print(len(value[3:])) print(len(prediction[90:])) plt.legend(loc='best') plt.title('Cumulative infections prediction(Italy)') plt.xlabel('Day') plt.ylabel('Cumulative Cases') plt.show() 115 20

    Processed: 0.016, SQL: 9