PyTorch: Writing a GRU by Hand

    Tech, 2025-08-30

    At first I wanted to work straight from the PyTorch source, but look at what the source actually contains:

    class GRU(RNNBase):
        def __init__(self, *args, **kwargs):
            super(GRU, self).__init__('GRU', *args, **kwargs)

        @torch._jit_internal._overload_method  # noqa: F811
        def forward(self, input, hx=None):  # noqa: F811
            # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
            pass

        @torch._jit_internal._overload_method  # noqa: F811
        def forward(self, input, hx=None):  # noqa: F811
            # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
            pass

        def forward(self, input, hx=None):  # noqa: F811
            orig_input = input
            # xxx: isinstance check needs to be in conditional for TorchScript to compile
            if isinstance(orig_input, PackedSequence):
                input, batch_sizes, sorted_indices, unsorted_indices = input
                max_batch_size = batch_sizes[0]
                max_batch_size = int(max_batch_size)
            else:
                batch_sizes = None
                max_batch_size = input.size(0) if self.batch_first else input.size(1)
                sorted_indices = None
                unsorted_indices = None

            if hx is None:
                num_directions = 2 if self.bidirectional else 1
                hx = torch.zeros(self.num_layers * num_directions,
                                 max_batch_size, self.hidden_size,
                                 dtype=input.dtype, device=input.device)
            else:
                # Each batch of the hidden state should match the input sequence that
                # the user believes he/she is passing in.
                hx = self.permute_hidden(hx, sorted_indices)

            self.check_forward_args(input, hx, batch_sizes)
            if batch_sizes is None:
                result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
                                 self.dropout, self.training, self.bidirectional, self.batch_first)
            else:
                result = _VF.gru(input, batch_sizes, hx, self._flat_weights, self.bias,
                                 self.num_layers, self.dropout, self.training, self.bidirectional)
            output = result[0]
            hidden = result[1]

            # xxx: isinstance check needs to be in conditional for TorchScript to compile
            if isinstance(orig_input, PackedSequence):
                output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
                return output_packed, self.permute_hidden(hidden, unsorted_indices)
            else:
                return output, self.permute_hidden(hidden, unsorted_indices)

    https://discuss.pytorch.org/t/where-to-find-torch-c-variablefunctions-module/41305/5

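    https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp

    Right, so the Python class only dispatches to `_VF.gru`, and the real computation is written in C++. Fine. In that case I'll just write one from scratch; it isn't hard anyway. On to the main part: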
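As a point of reference, here is a minimal sketch of what a hand-written GRU forward pass can look like. It is an illustration, not the author's code. It assumes a single-layer, unidirectional GRU with time-major input, and it relies on the gate layout documented for `torch.nn.GRU`: the stacked weights are ordered reset, update, new. The function name `manual_gru` is hypothetical. The weights are borrowed from an `nn.GRU` instance so the result can be compared against the built-in module.

```python
import torch
import torch.nn as nn


def manual_gru(x, h0, w_ih, w_hh, b_ih, b_hh):
    """Single-layer, unidirectional GRU (hypothetical helper).

    x:  (seq_len, batch, input_size), time-major
    h0: (batch, hidden_size)
    w_ih, w_hh, b_ih, b_hh: stacked gate parameters in (reset, update, new) order
    """
    h = h0
    outputs = []
    for x_t in x:  # one time step at a time
        gi = x_t @ w_ih.T + b_ih  # input contribution, (batch, 3 * hidden_size)
        gh = h @ w_hh.T + b_hh    # hidden contribution, (batch, 3 * hidden_size)
        i_r, i_z, i_n = gi.chunk(3, dim=1)
        h_r, h_z, h_n = gh.chunk(3, dim=1)
        r = torch.sigmoid(i_r + h_r)   # reset gate
        z = torch.sigmoid(i_z + h_z)   # update gate
        n = torch.tanh(i_n + r * h_n)  # candidate state
        h = (1 - z) * n + z * h        # blend candidate with previous state
        outputs.append(h)
    return torch.stack(outputs), h


torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 5, 3, 4, 6
gru = nn.GRU(input_size, hidden_size)
x = torch.randn(seq_len, batch, input_size)
h0 = torch.zeros(batch, hidden_size)

with torch.no_grad():
    ref_out, ref_h = gru(x, h0.unsqueeze(0))
    out, h_last = manual_gru(
        x, h0,
        gru.weight_ih_l0, gru.weight_hh_l0,
        gru.bias_ih_l0, gru.bias_hh_l0,
    )

# Largest elementwise gap between the hand-written loop and nn.GRU
max_diff = (out - ref_out).abs().max().item()
print(f"max abs difference vs nn.GRU: {max_diff:.2e}")
```

The difference should be at floating-point noise level. Multiple layers, bidirectionality, dropout and `PackedSequence` input are all left out of this sketch.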
