Skip to content

Commit

Permalink
add site_id
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Sep 12, 2023
1 parent 7cbf992 commit 10e720b
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions hivemind_core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class HiveMindClientConnection:
blacklist: List[str] = field(default_factory=list) # list of ovos message_type to never be sent to this client
allowed_types: List[str] = field(default_factory=list) # list of ovos message_type to allow to be sent from this client
binarize: bool = False
site_id: str = "unknown"

@property
def peer(self) -> str:
Expand Down Expand Up @@ -191,11 +192,12 @@ def handle_internal_mycroft(self, message: str):

new_sess = Session.from_message(message)
for peer, client in self.clients.items():
# ovos-core decides the contents of the Session,
# ovos-core decides the runtime contents of the Session,
# let's sync any internal changes
if new_sess.session_id == client.sess.session_id:
LOG.debug(f"syncing session from ovos with {peer}")
client.sess = Session.from_message(message)
client.sess.site_id = client.site_id

if peer in target_peers:
# forward internal messages to clients if they are the target
Expand Down Expand Up @@ -231,13 +233,19 @@ def bind(self, websocket, bus):
self.internal_protocol = HiveMindListenerInternalProtocol(bus)
self.internal_protocol.register_bus_handlers()

def get_bus(self, client: HiveMindClientConnection):
# allow subclasses to use dedicated bus per client
return self.internal_protocol.bus

def handle_new_client(self, client: HiveMindClientConnection):
LOG.debug(f"new client: {client.peer}")
self.clients[client.peer] = client
message = Message("hive.client.connect",
{"ip": client.ip, "session_id": client.sess.session_id},
{"source": client.peer})
self.internal_protocol.bus.emit(message)

bus = self.get_bus(client)
bus.emit(message)

min_version = ProtocolVersion.ONE if client.crypto_key is None and self.require_crypto \
else ProtocolVersion.ZERO
Expand Down Expand Up @@ -276,21 +284,24 @@ def handle_client_disconnected(self, client: HiveMindClientConnection):
message = Message("hive.client.disconnect",
{"ip": client.ip},
{"source": client.peer, "session": client.sess.serialize()})
self.internal_protocol.bus.emit(message)
bus = self.get_bus(client)
bus.emit(message)

def handle_invalid_key_connected(self, client: HiveMindClientConnection):
LOG.error("Client provided an invalid api key")
message = Message("hive.client.connection.error",
{"error": "invalid api key", "peer": client.peer},
{"source": client.peer})
self.internal_protocol.bus.emit(message)
bus = self.get_bus(client)
bus.emit(message)

def handle_invalid_protocol_version(self, client: HiveMindClientConnection):
LOG.error("Client does not satisfy protocol requirements")
message = Message("hive.client.connection.error",
{"error": "protocol error", "peer": client.peer},
{"source": client.peer})
self.internal_protocol.bus.emit(message)
bus = self.get_bus(client)
bus.emit(message)

def handle_message(self, message: HiveMessage, client: HiveMindClientConnection):
"""
Expand Down Expand Up @@ -340,6 +351,8 @@ def handle_handshake_message(self, message: HiveMessage,
client: HiveMindClientConnection):
LOG.debug("handshake received, generating session key")
payload = message.payload
if "site_id" in payload:
client.sess.site_id = client.site_id = payload["site_id"]
if "pubkey" in payload and client.handshake is not None:
pub = payload.pop("pubkey")
payload["envelope"] = client.handshake.generate_handshake(pub)
Expand Down Expand Up @@ -438,7 +451,8 @@ def handle_propagate_message(self, message: HiveMessage,
{"destination": "hive",
"source": self.peer,
"session": client.sess.serialize()})
self.internal_protocol.bus.emit(message)
bus = self.get_bus(client)
bus.emit(message)

def handle_escalate_message(self, message: HiveMessage,
client: HiveMindClientConnection):
Expand All @@ -461,7 +475,8 @@ def handle_escalate_message(self, message: HiveMessage,
{"destination": "hive",
"source": self.peer,
"session": client.sess.serialize()})
self.internal_protocol.bus.emit(message)
bus = self.get_bus(client)
bus.emit(message)

# HiveMind mycroft bus messages - from slave -> master
def update_slave_session(self, message: Message, client: HiveMindClientConnection):
Expand Down Expand Up @@ -498,7 +513,8 @@ def handle_inject_mycroft_msg(self, message: Message, client: HiveMindClientConn
# validate slave session
message = self.update_slave_session(message, client)

self.internal_protocol.bus.emit(message)
bus = self.get_bus(client)
bus.emit(message)

if self.mycroft_bus_callback:
self.mycroft_bus_callback(message)
Expand Down

0 comments on commit 10e720b

Please sign in to comment.