PyTorch control flow
if statement
We don’t support control flow. For the following code, only the branch that is executed during tracing is put into the graph.
if self.training:
    ...
else:
    ...
As a consequence, model training and validation will use exactly the same code path: the one taken during tracing.
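The effect can be pictured with a toy recorder. This is only a sketch of the general idea of tracing by execution, not nnscaler's tracer; `ToyTracer`, `ToyModule`, and `trace` are hypothetical names introduced for illustration.

```python
class ToyTracer:
    """Hypothetical recorder: logs the name of each op that actually runs."""

    def __init__(self):
        self.graph = []

    def record(self, name):
        self.graph.append(name)


class ToyModule:
    def __init__(self):
        self.training = True

    def forward(self, tracer):
        # `self.training` is a plain Python bool, so the `if` itself is
        # never recorded; only the ops in the executed branch are.
        if self.training:
            tracer.record("dropout")
        else:
            tracer.record("identity")
        tracer.record("linear")


def trace(module):
    tracer = ToyTracer()
    module.forward(tracer)
    return tracer.graph


module = ToyModule()
graph = trace(module)      # traced while module.training is True
module.training = False    # switching mode afterwards does not change the graph
print(graph)               # ['dropout', 'linear']
```

In this toy setup the recorded graph keeps the training branch even after `training` is set to `False`, which mirrors the behavior described above.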
if expression
Some torch operations use an if expression to select between argument values, for example:
torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None,
    dropout_p=self.dropout if self.training else 0,
    is_causal=self.is_causal
)
To support that, we provide limited support for if expressions
by converting them to function calls.
For example, we will convert
x = a if self.training else b
to
x = nnscaler.runtime.function.ifexpr(self.training, a, b)
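Because the converted form is an ordinary function call, Python computes all of its arguments before the call happens. The sketch below uses a local stand-in named `ifexpr` (an assumption for illustration; the real `nnscaler.runtime.function.ifexpr` may be implemented differently) and a hypothetical `keep_prob` value to show the difference from a native if expression.

```python
def ifexpr(cond, true_value, false_value):
    # Illustrative stand-in; the real nnscaler.runtime.function.ifexpr may differ.
    return true_value if cond else false_value


training = False
keep_prob = 0.0  # hypothetical: zero here, so dividing by it is invalid

# Native if expression: `1.0 / keep_prob` is skipped because training is False.
native = 1.0 / keep_prob if training else 1.0

# Function-call form: both branch values are computed before ifexpr runs,
# so the division by zero is hit even though its result would be discarded.
try:
    converted = ifexpr(training, 1.0 / keep_prob, 1.0)
except ZeroDivisionError as exc:
    converted = repr(exc)

print(native)
print(converted)
```

The native form yields `1.0`, while the function-call form fails while evaluating the unused branch.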
This trick is not free. It introduces two side effects:
Short-circuit evaluation is not supported. Both branches are evaluated, so you must make sure that both branches are valid and have no side effects. To limit the risk, we check the true and false expressions and require that neither contains a function call, so the following code will not be converted:
x = f(a) if self.training else b
We will convert an if expression only if its condition is self.training. So if a non-module class has a training attribute, the if expressions in its member functions will also be converted if their condition is self.training.
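The two rules above can be sketched as a source-level rewrite with Python's standard `ast` module. This is a minimal illustration under stated assumptions, not nnscaler's actual implementation: `IfExprRewriter`, `rewrite`, and the helper functions are hypothetical names, the match on `self.training` is purely syntactic, and nested if expressions are handled naively. It requires Python 3.9 or later for `ast.unparse`.

```python
import ast


def _is_self_training(node):
    # Matches exactly the expression `self.training`, whatever `self` is.
    return (
        isinstance(node, ast.Attribute)
        and node.attr == "training"
        and isinstance(node.value, ast.Name)
        and node.value.id == "self"
    )


def _has_call(node):
    # True if any function call appears anywhere inside the expression.
    return any(isinstance(n, ast.Call) for n in ast.walk(node))


class IfExprRewriter(ast.NodeTransformer):
    def visit_IfExp(self, node):
        convertible = (
            _is_self_training(node.test)
            and not _has_call(node.body)
            and not _has_call(node.orelse)
        )
        self.generic_visit(node)
        if not convertible:
            return node
        func = ast.parse("nnscaler.runtime.function.ifexpr", mode="eval").body
        return ast.Call(
            func=func,
            args=[node.test, node.body, node.orelse],
            keywords=[],
        )


def rewrite(source):
    tree = IfExprRewriter().visit(ast.parse(source))
    return ast.unparse(ast.fix_missing_locations(tree))


print(rewrite("x = a if self.training else b"))
# x = nnscaler.runtime.function.ifexpr(self.training, a, b)
print(rewrite("x = f(a) if self.training else b"))
# x = f(a) if self.training else b
```

Since the sketch only looks at the text of the condition, it would rewrite `self.training` in any class, which is consistent with the note about non-module classes.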
Note that you can always use register_op to define a custom op that handles the if expression.
For example, you can rewrite the scaled_dot_product_attention call above as:
import nnscaler
import torch

@nnscaler.register_op('?, ? -> ?')
def get_dropout(training, dropout):
    # the if expression is evaluated inside the custom op
    return dropout if training else 0

torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None,
    # pass the training flag, not the module itself
    dropout_p=get_dropout(self.training, self.dropout),
    is_causal=self.is_causal
)
self.training as a parameter
Passing self.training directly as an argument is well supported.
For example:
torch.nn.functional.dropout(x, 0.1, self.training)
# the generated code will be exactly the same as the original code:
# torch.nn.functional.dropout(x, 0.1, self.training)
But be careful: if you use self.training in a boolean operation,
the generated code may not be what you expect, because:
We don’t trace boolean operations.
Boolean operations are short-circuit evaluated, so only one operand will be kept in the generated code.
For example:
torch.nn.functional.dropout(x, 0.1, global_setting.enable_dropout or self.training)
# if global_setting.enable_dropout is True, the generated code will be
# torch.nn.functional.dropout(x, 0.1, True)
# if global_setting.enable_dropout is False, the generated code will be
# torch.nn.functional.dropout(x, 0.1, self.training)
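The two outcomes above follow from how Python's `or` works: it returns one of its operands rather than building a combined value. A small sketch, using a hypothetical `Traced` class as a stand-in for a value a tracer would want to follow:

```python
class Traced:
    """Hypothetical stand-in for a value that a tracer tracks."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


training = Traced("self.training")

# Left operand is truthy: `or` returns it and never looks at `training`.
enabled = True or training

# Left operand is falsy: `or` returns the right operand object itself.
disabled = False or training

print(enabled)   # True
print(disabled)  # self.training
```

In both cases the `or` itself leaves no trace in the result: only a constant or the single surviving operand remains.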