-
Notifications
You must be signed in to change notification settings - Fork 0
/
LibDiamond.sol
205 lines (185 loc) · 9.39 KB
/
LibDiamond.sol
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
// SPDX-License-Identifier: CC0-1.0
pragma solidity ^0.8.0;
/******************************************************************************\
* Author: Nick Mudge <nick@perfectabstractions.com>, Twitter/Github: @mudgen
* EIP-2535 Diamonds
/******************************************************************************/
import { IDiamond } from "../interfaces/IDiamond.sol";
import { IDiamondCut } from "../interfaces/IDiamondCut.sol";
// Remember to add the loupe functions from DiamondLoupeFacet to the diamond.
// The loupe functions are required by the EIP-2535 Diamonds standard.
// Custom errors raised by LibDiamond. Names and parameter types determine the error
// selectors seen by off-chain callers, so they are kept exactly as declared.

// Declared but not reverted anywhere in this file; LibDiamond.diamondCut rejects empty
// selector lists with NoSelectorsProvidedForFacetForCut instead.
error NoSelectorsGivenToAdd();
// LibDiamond.enforceIsContractOwner: msg.sender (_user) differs from the stored owner.
error NotContractOwner(address _user, address _contractOwner);
// LibDiamond.diamondCut: a FacetCut has an empty functionSelectors array.
error NoSelectorsProvidedForFacetForCut(address _facetAddress);
// LibDiamond.addFunctions: the facet address is address(0).
error CannotAddSelectorsToZeroAddress(bytes4[] _selectors);
// LibDiamond.enforceHasContractCode: extcodesize(_contractAddress) is zero; _message is
// the caller-supplied context string.
error NoBytecodeAtAddress(address _contractAddress, string _message);
// LibDiamond.diamondCut: the action is none of Add, Replace, Remove.
error IncorrectFacetCutAction(uint8 _action);
// LibDiamond.addFunctions: the selector is already mapped to a facet.
error CannotAddFunctionToDiamondThatAlreadyExists(bytes4 _selector);
// LibDiamond.replaceFunctions: the replacement facet address is address(0).
error CannotReplaceFunctionsFromFacetWithZeroAddress(bytes4[] _selectors);
// LibDiamond.replaceFunctions: the selector is mapped to the diamond itself (address(this)).
error CannotReplaceImmutableFunction(bytes4 _selector);
// LibDiamond.replaceFunctions: the selector already points at the given facet.
error CannotReplaceFunctionWithTheSameFunctionFromTheSameFacet(bytes4 _selector);
// LibDiamond.replaceFunctions: the selector is not mapped to any facet.
// NOTE: the ungrammatical "DoesNotExists" is kept because renaming changes the error selector.
error CannotReplaceFunctionThatDoesNotExists(bytes4 _selector);
// LibDiamond.removeFunctions: a non-zero facet address was passed for a Remove cut.
error RemoveFacetAddressMustBeZeroAddress(address _facetAddress);
// LibDiamond.removeFunctions: the selector is not mapped to any facet.
error CannotRemoveFunctionThatDoesNotExist(bytes4 _selector);
// LibDiamond.removeFunctions: the selector is mapped to the diamond itself (address(this)).
error CannotRemoveImmutableFunction(bytes4 _selector);
// LibDiamond.initializeDiamondCut: the _init delegatecall failed and returned no revert data.
error InitializationFunctionReverted(address _initializationContractAddress, bytes _calldata);
/// @title LibDiamond
/// @notice Internal helpers for an EIP-2535 diamond: shared storage at a fixed slot, owner
///         bookkeeping, and the add/replace/remove logic behind diamondCut.
library LibDiamond {
    // Fixed storage slot at which DiamondStorage is read and written (see diamondStorage()).
    bytes32 constant DIAMOND_STORAGE_POSITION = keccak256("diamond.standard.diamond.storage");

    // Mapping value kept for every registered selector.
    struct FacetAddressAndSelectorPosition {
        // facet implementing the selector; address(this) marks an immutable function
        address facetAddress;
        // index of the selector inside DiamondStorage.selectors
        uint16 selectorPosition;
    }

    // Shared diamond state. The field order IS the storage layout at DIAMOND_STORAGE_POSITION:
    // reordering or inserting fields would corrupt data already written by a deployed diamond.
    struct DiamondStorage {
        // function selector => facet address and selector position in selectors array
        mapping(bytes4 => FacetAddressAndSelectorPosition) facetAddressAndSelectorPosition;
        // all registered selectors; kept in sync so that
        // selectors[facetAddressAndSelectorPosition[s].selectorPosition] == s
        bytes4[] selectors;
        // not read or written by this library; presumably ERC-165 interface ids — confirm with facets
        mapping(bytes4 => bool) supportedInterfaces;
        // owner of the contract
        address contractOwner;
    }

    /// @notice Returns a storage pointer to the DiamondStorage struct at DIAMOND_STORAGE_POSITION.
    function diamondStorage() internal pure returns (DiamondStorage storage ds) {
        bytes32 position = DIAMOND_STORAGE_POSITION;
        // Point the storage reference at the fixed slot instead of a compiler-assigned one.
        assembly {
            ds.slot := position
        }
    }

    event OwnershipTransferred(address indexed previousOwner, address indexed newOwner);

    /// @notice Records `_newOwner` as the contract owner and emits OwnershipTransferred.
    /// @dev No access control and no zero-address check here; callers must authorize the change.
    /// @param _newOwner Address to store as the new owner.
    function setContractOwner(address _newOwner) internal {
        DiamondStorage storage ds = diamondStorage();
        address previousOwner = ds.contractOwner;
        ds.contractOwner = _newOwner;
        emit OwnershipTransferred(previousOwner, _newOwner);
    }

    /// @notice Returns the owner recorded in diamond storage.
    function contractOwner() internal view returns (address contractOwner_) {
        contractOwner_ = diamondStorage().contractOwner;
    }

    /// @notice Reverts with NotContractOwner unless msg.sender is the stored owner.
    function enforceIsContractOwner() internal view {
        // Read the owner once and reuse it for both the comparison and the revert payload.
        address currentOwner = diamondStorage().contractOwner;
        if (msg.sender != currentOwner) {
            revert NotContractOwner(msg.sender, currentOwner);
        }
    }

    event DiamondCut(IDiamondCut.FacetCut[] _diamondCut, address _init, bytes _calldata);

    /// @notice Internal function version of diamondCut: applies every FacetCut in order,
    ///         emits DiamondCut, then runs the optional initializer.
    /// @param _diamondCut Cuts to apply; each must carry at least one selector.
    /// @param _init Contract to delegatecall after the cuts, or address(0) to skip it.
    /// @param _calldata Calldata passed to the `_init` delegatecall.
    function diamondCut(
        IDiamondCut.FacetCut[] memory _diamondCut,
        address _init,
        bytes memory _calldata
    ) internal {
        for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) {
            bytes4[] memory functionSelectors = _diamondCut[facetIndex].functionSelectors;
            address facetAddress = _diamondCut[facetIndex].facetAddress;
            if (functionSelectors.length == 0) {
                revert NoSelectorsProvidedForFacetForCut(facetAddress);
            }
            IDiamondCut.FacetCutAction action = _diamondCut[facetIndex].action;
            if (action == IDiamond.FacetCutAction.Add) {
                addFunctions(facetAddress, functionSelectors);
            } else if (action == IDiamond.FacetCutAction.Replace) {
                replaceFunctions(facetAddress, functionSelectors);
            } else if (action == IDiamond.FacetCutAction.Remove) {
                removeFunctions(facetAddress, functionSelectors);
            } else {
                // Defensive branch; assumes FacetCutAction has only the three members above — confirm in IDiamond.
                revert IncorrectFacetCutAction(uint8(action));
            }
        }
        emit DiamondCut(_diamondCut, _init, _calldata);
        initializeDiamondCut(_init, _calldata);
    }

    /// @notice Maps each selector in `_functionSelectors` to `_facetAddress` and appends it to ds.selectors.
    /// @dev Reverts if `_facetAddress` is address(0), has no code, or any selector is already mapped.
    ///      Positions are uint16: the checked `selectorCount++` panics once a 65536th selector
    ///      would be registered, so the whole cut reverts rather than wrapping a position.
    /// @param _facetAddress Facet that will implement the selectors.
    /// @param _functionSelectors Selectors to register.
    function addFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
        if (_facetAddress == address(0)) {
            revert CannotAddSelectorsToZeroAddress(_functionSelectors);
        }
        DiamondStorage storage ds = diamondStorage();
        uint16 selectorCount = uint16(ds.selectors.length);
        enforceHasContractCode(_facetAddress, "LibDiamondCut: Add facet has no code");
        for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
            if (oldFacetAddress != address(0)) {
                revert CannotAddFunctionToDiamondThatAlreadyExists(selector);
            }
            // The new selector's position is the index it is about to occupy in ds.selectors.
            ds.facetAddressAndSelectorPosition[selector] = FacetAddressAndSelectorPosition(_facetAddress, selectorCount);
            ds.selectors.push(selector);
            selectorCount++;
        }
    }

    /// @notice Re-points each selector in `_functionSelectors` at `_facetAddress`.
    /// @dev Only the facet address changes; the selector's position in ds.selectors is kept.
    ///      Reverts if `_facetAddress` is address(0) or has no code, or if a selector is
    ///      immutable, already on this facet, or not registered (checked in that order).
    /// @param _facetAddress Replacement facet.
    /// @param _functionSelectors Selectors to re-point.
    function replaceFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
        DiamondStorage storage ds = diamondStorage();
        if (_facetAddress == address(0)) {
            revert CannotReplaceFunctionsFromFacetWithZeroAddress(_functionSelectors);
        }
        enforceHasContractCode(_facetAddress, "LibDiamondCut: Replace facet has no code");
        for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
            // can't replace immutable functions -- functions defined directly in the diamond in this case
            if (oldFacetAddress == address(this)) {
                revert CannotReplaceImmutableFunction(selector);
            }
            if (oldFacetAddress == _facetAddress) {
                revert CannotReplaceFunctionWithTheSameFunctionFromTheSameFacet(selector);
            }
            if (oldFacetAddress == address(0)) {
                revert CannotReplaceFunctionThatDoesNotExists(selector);
            }
            // replace old facet address
            ds.facetAddressAndSelectorPosition[selector].facetAddress = _facetAddress;
        }
    }

    /// @notice Unregisters each selector in `_functionSelectors`.
    /// @dev Uses swap-and-pop: the last entry of ds.selectors is moved into the removed
    ///      selector's slot (and its stored position updated) before the array is popped,
    ///      so ds.selectors does not keep its original order.
    /// @param _facetAddress Must be address(0) for a Remove cut.
    /// @param _functionSelectors Selectors to remove; each must be registered and not immutable.
    function removeFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
        DiamondStorage storage ds = diamondStorage();
        uint256 selectorCount = ds.selectors.length;
        if (_facetAddress != address(0)) {
            revert RemoveFacetAddressMustBeZeroAddress(_facetAddress);
        }
        for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
            bytes4 selector = _functionSelectors[selectorIndex];
            FacetAddressAndSelectorPosition memory oldFacetAddressAndSelectorPosition = ds.facetAddressAndSelectorPosition[selector];
            if (oldFacetAddressAndSelectorPosition.facetAddress == address(0)) {
                revert CannotRemoveFunctionThatDoesNotExist(selector);
            }
            // can't remove immutable functions -- functions defined directly in the diamond
            if (oldFacetAddressAndSelectorPosition.facetAddress == address(this)) {
                revert CannotRemoveImmutableFunction(selector);
            }
            // replace selector with last selector
            // selectorCount tracks ds.selectors.length, which shrinks by one per iteration.
            selectorCount--;
            if (oldFacetAddressAndSelectorPosition.selectorPosition != selectorCount) {
                bytes4 lastSelector = ds.selectors[selectorCount];
                ds.selectors[oldFacetAddressAndSelectorPosition.selectorPosition] = lastSelector;
                ds.facetAddressAndSelectorPosition[lastSelector].selectorPosition = oldFacetAddressAndSelectorPosition.selectorPosition;
            }
            // delete last selector
            ds.selectors.pop();
            delete ds.facetAddressAndSelectorPosition[selector];
        }
    }

    /// @notice Delegatecalls `_calldata` on `_init`; does nothing when `_init` is address(0).
    /// @dev On failure, re-raises the initializer's revert data verbatim, or reverts with
    ///      InitializationFunctionReverted when the call failed without return data.
    /// @param _init Initializer contract (must have code when non-zero).
    /// @param _calldata Calldata for the delegatecall.
    function initializeDiamondCut(address _init, bytes memory _calldata) internal {
        if (_init == address(0)) {
            return;
        }
        enforceHasContractCode(_init, "LibDiamondCut: _init address has no code");
        // Named `returnData` rather than `error`, which reads as the custom-error keyword.
        (bool success, bytes memory returnData) = _init.delegatecall(_calldata);
        if (!success) {
            if (returnData.length > 0) {
                // bubble up error: revert with the raw bytes that follow the 32-byte length word
                /// @solidity memory-safe-assembly
                assembly {
                    let returndata_size := mload(returnData)
                    revert(add(32, returnData), returndata_size)
                }
            } else {
                revert InitializationFunctionReverted(_init, _calldata);
            }
        }
    }

    /// @notice Reverts with NoBytecodeAtAddress when `_contract` has no deployed code.
    /// @dev extcodesize is also zero for a contract still running its constructor.
    /// @param _contract Address to check.
    /// @param _errorMessage Context string included in the revert.
    function enforceHasContractCode(address _contract, string memory _errorMessage) internal view {
        uint256 contractSize;
        assembly {
            contractSize := extcodesize(_contract)
        }
        if (contractSize == 0) {
            revert NoBytecodeAtAddress(_contract, _errorMessage);
        }
    }
}