Skip to content

Commit

Permalink
fix: broadcast시 nparray 인코딩
Browse files Browse the repository at this point in the history
- 이미지 업로드시 location name NONE
- Server 구동시 환경에 따라 CPU/GPU 설정 및 로그 출력
  • Loading branch information
sukkyun2 committed Aug 19, 2024
1 parent cf9de63 commit 353097d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
6 changes: 3 additions & 3 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from app.config import settings
from app.connection_manager import ConnectionManager
from app.history import async_save_history
from model.detect import detect, track, estimate_distance, DetectionResult
from model.detect import detect, estimate_distance, DetectionResult
from model.video_recorder import VideoRecorder

app = FastAPI()
Expand All @@ -36,7 +36,7 @@ async def detect_image(file: UploadFile = File(...)) -> ApiResponse:
return ApiResponse.bad_request(str(err))

result = detect(np.array(img))
await async_save_history(result)
await async_save_history(result, "NONE")

return ApiResponse.ok()

Expand Down Expand Up @@ -65,7 +65,7 @@ async def websocket_publisher(websocket: WebSocket, location_name: str):
if video_recorder.is_recording:
video_recorder.record_frame(result.plot_image)

await manager.broadcast(location_name, result.plot_image.tobytes())
await manager.broadcast(location_name, result.get_encoded_nparr().tobytes())
except WebSocketDisconnect:
manager.disconnect(location_name)
print("Publisher disconnected")
Expand Down
6 changes: 6 additions & 0 deletions model/detect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
from typing import List, Dict, Tuple

import cv2
import numpy as np
import torch
from PIL import Image as img
from numpy import ndarray
from ultralytics import YOLO
Expand All @@ -11,6 +13,10 @@

model = YOLO(settings.yolo_weight_path)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
print(f"Model run on the {device}")

tracked_objects: Dict[int, TrackedObject] = {}


Expand Down
6 changes: 6 additions & 0 deletions model/schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import List, Optional

import cv2
import numpy as np
from PIL import Image as img
from numpy import ndarray


class TrackedObject:
Expand Down Expand Up @@ -37,3 +39,7 @@ def __init__(self, plot_image: np.ndarray, detections: List[Detection]):

def get_image(self) -> img:
return img.fromarray(self.plot_image[..., ::-1])

def get_encoded_nparr(self) -> ndarray:
_, encoded_nparr = cv2.imencode('.jpg', self.plot_image)
return encoded_nparr

0 comments on commit 353097d

Please sign in to comment.