#candle #pytorch #lstm

烛-birnn

使用烛实现 PyTorch LSTM 和双向 LSTM

4 个版本

0.2.2 2024 年 8 月 17 日
0.2.1 2024 年 8 月 13 日
0.2.0 2024 年 8 月 12 日
0.1.0 2024 年 8 月 11 日

#313机器学习

Download history 207/week @ 2024-08-10 141/week @ 2024-08-17

每月 348 次下载

自定义许可证

42KB
299

包含 (ZIP 文件, 29KB) bi_lstm_test.pt,(ZIP 文件, 16KB) lstm_test.pt

烛 BiRNN

使用烛实现 PyTorch LSTM 推理,包括双向 LSTM 推理的实现。

测试数据

  1. lstm_test.pt:使用 PyTorch 演示程序生成的结果。代码如下

    import torch
    import torch.nn as nn
    
    rnn = nn.LSTM(10, 20, 1)
    input = torch.randn(5, 3, 10)
    output, (hn, cn) = rnn(input)
    
    state_dict = rnn.state_dict()
    state_dict['input'] = input
    state_dict['output'] = output
    state_dict['hn'] = hn
    state_dict['cn'] = cn
    torch.save(state_dict, "lstm_test.pt")
    
  2. bi_lstm_test.pt:使用 PyTorch 演示程序生成的结果。代码如下

    import torch
    import torch.nn as nn
    
    rnn = nn.LSTM(10, 20, 1, bidirectional=True)
    input = torch.randn(5, 3, 10)
    output, (hn, cn) = rnn(input)
    
    state_dict = rnn.state_dict()
    state_dict['input'] = input
    state_dict['output'] = output
    state_dict['hn'] = hn
    state_dict['cn'] = cn
    torch.save(state_dict, "bi_lstm_test.pt")
    

依赖

~11MB
~218K SLoC