-
Notifications
You must be signed in to change notification settings - Fork 30
/
mqtt.py
119 lines (100 loc) · 5.66 KB
/
mqtt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import uvicorn
import json
import base64
import logging
import threading
import time
import cv2
import paho.mqtt.client as mqtt
import numpy as np
from streamer import Streamer
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
class MQTT():
    """Publish DOODS detection streams to an MQTT broker.

    ``run`` connects to the broker and starts one thread per configured request.
    Each thread runs a detection stream and publishes every result as JSON under
    ``doods/detect/...``. When binary images are requested, the image goes
    separately under ``doods/image/...``.
    """

    def __init__(self, config, doods, metrics_server_config=None):
        """Store configuration and create the (not yet connected) MQTT client.

        :param config: MQTT configuration; ``run`` reads its ``broker``,
            ``requests``, ``api`` and ``metrics`` attributes.
        :param doods: detector handle handed to ``Streamer``.
        :param metrics_server_config: object with ``host``/``port`` used by
            ``metrics_server`` when metrics are enabled.
        """
        self.config = config
        self.doods = doods
        self.metrics_server_config = metrics_server_config
        self.mqtt_client = mqtt.Client()
        self.logger = logging.getLogger("doods.mqtt")

    @staticmethod
    def _topic(kind, request, detection=None):
        """Build ``doods/<kind>/<request id>[/<region id>]/<label or 'object'>``.

        The per-detection suffix is only added when ``detection`` is given.
        """
        topic = f"doods/{kind}/{request.id}"
        if detection is not None:
            # f-string interpolation works for any region id type (the original
            # string concatenation raised TypeError for non-str ids).
            if detection.region_id is not None:
                topic += f"/{detection.region_id}"
            topic += f"/{detection.label or 'object'}"
        return topic

    def _publish(self, topic, payload):
        """Publish ``payload`` to ``topic`` with QoS 0 and no retain flag."""
        self.mqtt_client.publish(topic, payload=payload, qos=0, retain=False)

    def _publish_separate(self, request, response):
        """Publish every detection of ``response`` as its own MQTT message."""
        want_image = bool(request.image)
        frame, height, width = None, 0, 0
        if want_image and request.crop:
            # Decode the frame once per response rather than once per detection.
            frame = cv2.imdecode(np.frombuffer(response.image, dtype=np.uint8), cv2.IMREAD_COLOR)
            height, width = frame.shape[:2]

        for detection in response.detections:
            if want_image:
                if request.crop:
                    # Box coordinates are scaled by the frame size, so they are
                    # presumably fractions of the image — confirm against Streamer.
                    cropped = frame[
                        int(detection.top * height):int(detection.bottom * height),
                        int(detection.left * width):int(detection.right * width)]
                    # request.image is passed to cv2.imencode as the encoding
                    # extension. tobytes() replaces the deprecated tostring().
                    image = cv2.imencode(request.image, cropped)[1].tobytes()
                else:
                    image = response.image
                if request.binary_images:
                    # Binary images travel on their own topic.
                    self._publish(self._topic("image", request, detection), image)
                else:
                    detection.image = base64.b64encode(image).decode('utf-8')
            self._publish(self._topic("detect", request, detection),
                          json.dumps(detection.asdict(include_none=False)))

    def _publish_combined(self, request, response):
        """Publish all detections of ``response`` together in one message."""
        if request.image:
            if request.binary_images:
                # Detach the raw image so it is left out of the JSON payload
                # (asdict(include_none=False)), then publish it on its own topic.
                image, response.image = response.image, None
                self._publish(self._topic("image", request), image)
            else:
                # Otherwise, include the base64-encoded image in the response.
                response.image = base64.b64encode(response.image).decode('utf-8')
        self._publish(self._topic("detect", request),
                      json.dumps(response.asdict(include_none=False)))

    def stream(self, mqtt_detect_request='{}'):
        """Run a detection stream for one request and publish every result.

        :param mqtt_detect_request: request object; its ``separate_detections``,
            ``image``, ``crop``, ``binary_images`` and ``id`` attributes are read.
            NOTE(review): the string default is kept only for backward
            compatibility and cannot actually be streamed.
        """
        streamer = None
        try:
            streamer = Streamer(self.doods).start_stream(mqtt_detect_request)
            for detect_response in streamer:
                if mqtt_detect_request.separate_detections:
                    self._publish_separate(mqtt_detect_request, detect_response)
                else:
                    self._publish_combined(mqtt_detect_request, detect_response)
        finally:
            if streamer is not None:
                try:
                    streamer.send(True)  # Ask the streamer generator to stop.
                except StopIteration:
                    pass

    def on_message(self, client, userdata, msg):
        """Handle a message on the API request topic (currently only logged)."""
        self.logger.info("received message on topic %s: %r", msg.topic, msg.payload)

    def metrics_server(self, config):
        """Serve Prometheus metrics on ``config.host``:``config.port``.

        ``uvicorn.run`` blocks until the server shuts down.
        """
        app = FastAPI()
        self.instrumentator = Instrumentator(
            excluded_handlers=["/metrics"],
        )
        self.instrumentator.instrument(app).expose(app)
        uvicorn.run(app, host=config.host, port=config.port, log_config=None)

    def run(self):
        """Connect to the broker, start the stream threads and the optional API/metrics."""
        broker = self.config.broker
        if broker.user:
            self.mqtt_client.username_pw_set(broker.user, broker.password)
        self.mqtt_client.connect(broker.host, broker.port, 60)
        # paho only processes CONNACK/incoming messages, sends keepalive pings and
        # reconnects while a network loop runs; start it in a background thread.
        self.mqtt_client.loop_start()

        for request in self.config.requests:
            threading.Thread(target=self.stream, args=(request,)).start()

        if self.config.api:
            # Register the handler before subscribing so no message arrives without it.
            self.mqtt_client.on_message = self.on_message
            self.mqtt_client.subscribe(self.config.api.request_topic)
            self.logger.info('listening on mqtt topic %s for requests', self.config.api.request_topic)

        if self.config.metrics:
            self.logger.info("starting metrics server")
            self.metrics_server(self.metrics_server_config)