-
Notifications
You must be signed in to change notification settings - Fork 116
/
torchvision_schema.py
59 lines (42 loc) · 2.41 KB
/
torchvision_schema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from abc import abstractmethod
from ...operators.torch.base import OperatorConverter
class TorchVisionPsRoiAlignSchema(OperatorConverter):
    """Converter schema base for the ``torchvision::ps_roi_align`` operator."""

    @abstractmethod
    def parse(self, node, attrs, args, graph_converter):
        """torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)

        Abstract hook; a concrete converter for this operator overrides it.
        The schema on the first line gives the operator's inputs in call
        order; the operator yields two tensors.

        Args (presumed meaning -- confirm against OperatorConverter):
            node: the graph node being converted.
            attrs: attributes attached to ``node``.
            args: the node's argument values, in schema order.
            graph_converter: the converter driving the whole graph.
        """
        ...
class TorchVisionRoiAlignSchema(OperatorConverter):
    """Converter schema base for the ``torchvision::roi_align`` operator."""

    @abstractmethod
    def parse(self, node, attrs, args, graph_converter):
        """torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> (Tensor)

        Abstract hook; a concrete converter for this operator overrides it.
        The schema on the first line gives the operator's inputs in call
        order; the operator yields a single tensor.

        Args (presumed meaning -- confirm against OperatorConverter):
            node: the graph node being converted.
            attrs: attributes attached to ``node``.
            args: the node's argument values, in schema order.
            graph_converter: the converter driving the whole graph.
        """
        ...
class TorchVisionPsRoiPoolSchema(OperatorConverter):
    """Converter schema base for the ``torchvision::ps_roi_pool`` operator."""

    @abstractmethod
    def parse(self, node, attrs, args, graph_converter):
        """torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)

        Abstract hook; a concrete converter for this operator overrides it.
        The schema on the first line gives the operator's inputs in call
        order; the operator yields two tensors.

        Args (presumed meaning -- confirm against OperatorConverter):
            node: the graph node being converted.
            attrs: attributes attached to ``node``.
            args: the node's argument values, in schema order.
            graph_converter: the converter driving the whole graph.
        """
        ...
class TorchVisionDeformConv2dSchema(OperatorConverter):
    """Converter schema base for the ``torchvision::deform_conv2d`` operator."""

    @abstractmethod
    def parse(self, node, attrs, args, graph_converter):
        """torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor)

        Abstract hook; a concrete converter for this operator overrides it.
        The schema on the first line gives the operator's inputs in call
        order (note stride/padding/dilation arrive as separate h/w ints);
        the operator yields a single tensor.

        Args (presumed meaning -- confirm against OperatorConverter):
            node: the graph node being converted.
            attrs: attributes attached to ``node``.
            args: the node's argument values, in schema order.
            graph_converter: the converter driving the whole graph.
        """
        ...
class TorchVisionInterpolateBilinear2dAaSchema(OperatorConverter):
    """Converter schema base for ``torchvision::_interpolate_bilinear2d_aa``."""

    @abstractmethod
    def parse(self, node, attrs, args, graph_converter):
        """torchvision::_interpolate_bilinear2d_aa(Tensor input, int[] output_size, bool align_corners) -> (Tensor)

        Abstract hook; a concrete converter for this operator overrides it.
        The schema on the first line gives the operator's inputs in call
        order; the operator yields a single tensor.

        Args (presumed meaning -- confirm against OperatorConverter):
            node: the graph node being converted.
            attrs: attributes attached to ``node``.
            args: the node's argument values, in schema order.
            graph_converter: the converter driving the whole graph.
        """
        ...
class TorchVisionInterpolateBicubic2dAaSchema(OperatorConverter):
    """Converter schema base for ``torchvision::_interpolate_bicubic2d_aa``."""

    @abstractmethod
    def parse(self, node, attrs, args, graph_converter):
        """torchvision::_interpolate_bicubic2d_aa(Tensor input, int[] output_size, bool align_corners) -> (Tensor)

        Abstract hook; a concrete converter for this operator overrides it.
        The schema on the first line gives the operator's inputs in call
        order; the operator yields a single tensor.

        Args (presumed meaning -- confirm against OperatorConverter):
            node: the graph node being converted.
            attrs: attributes attached to ``node``.
            args: the node's argument values, in schema order.
            graph_converter: the converter driving the whole graph.
        """
        ...
class TorchVisionNmsSchema(OperatorConverter):
    """Converter schema base for the ``torchvision::nms`` operator."""

    @abstractmethod
    def parse(self, node, attrs, args, graph_converter):
        """torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> (Tensor)

        Abstract hook; a concrete converter for this operator overrides it.
        The schema on the first line gives the operator's inputs in call
        order; the operator yields a single tensor.

        Args (presumed meaning -- confirm against OperatorConverter):
            node: the graph node being converted.
            attrs: attributes attached to ``node``.
            args: the node's argument values, in schema order.
            graph_converter: the converter driving the whole graph.
        """
        ...
class TorchVisionRoiPoolSchema(OperatorConverter):
    """Converter schema base for the ``torchvision::roi_pool`` operator."""

    @abstractmethod
    def parse(self, node, attrs, args, graph_converter):
        """torchvision::roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)

        Abstract hook; a concrete converter for this operator overrides it.
        The schema on the first line gives the operator's inputs in call
        order; the operator yields two tensors.

        Args (presumed meaning -- confirm against OperatorConverter):
            node: the graph node being converted.
            attrs: attributes attached to ``node``.
            args: the node's argument values, in schema order.
            graph_converter: the converter driving the whole graph.
        """
        ...