4 个版本
0.2.2 | 2024 年 8 月 17 日 |
---|---|
0.2.1 | 2024 年 8 月 13 日 |
0.2.0 | 2024 年 8 月 12 日 |
0.1.0 | 2024 年 8 月 11 日 |
#313 在 机器学习 中
每月 348 次下载
42KB
299 行
包含 (ZIP 文件, 29KB) bi_lstm_test.pt,(ZIP 文件, 16KB) lstm_test.pt
烛 BiRNN
使用烛实现 PyTorch LSTM 推理,包括双向 LSTM 推理的实现。
测试数据
-
lstm_test.pt:使用 PyTorch 演示程序生成的结果。代码如下
import torch import torch.nn as nn rnn = nn.LSTM(10, 20, 1) input = torch.randn(5, 3, 10) output, (hn, cn) = rnn(input) state_dict = rnn.state_dict() state_dict['input'] = input state_dict['output'] = output state_dict['hn'] = hn state_dict['cn'] = cn torch.save(state_dict, "lstm_test.pt")
-
bi_lstm_test.pt:使用 PyTorch 演示程序生成的结果。代码如下
import torch import torch.nn as nn rnn = nn.LSTM(10, 20, 1, bidirectional=True) input = torch.randn(5, 3, 10) output, (hn, cn) = rnn(input) state_dict = rnn.state_dict() state_dict['input'] = input state_dict['output'] = output state_dict['hn'] = hn state_dict['cn'] = cn torch.save(state_dict, "bi_lstm_test.pt")
依赖
~11MB
~218K SLoC