
Merge pull request #108 from GreenTeaProgrammers/feature/machine-learning/refactor

refactor(ml)
lovelovetrb authored Feb 18, 2024
2 parents 8d105df + d808296 commit 02d14fd
Showing 9 changed files with 69 additions and 83 deletions.
@@ -15,6 +15,8 @@
load_cascade,
)

from face_detect_model.gcp_util import get_bucket, get_blobs, init_client

load_dotenv("secrets/.env")

logging.basicConfig(
@@ -76,46 +78,6 @@ def init_save_dir(save_dir_path: str):
os.remove(file_path)


def init_client():
# NOTE: authenticate in advance with gcloud auth application-default login
credential = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
PROJECT_ID = os.environ.get("PROJECT_ID")

client = gcs.Client(PROJECT_ID, credentials=credential)
if client is None:
logger.error("Failed to initialize client.")
exit(1)
else:
return client


def get_bucket(client: gcs.Client):
# NOTE: get the bucket name from the environment variable
BUCKET_NAME = os.environ.get("BUCKET_NAME")
bucket = client.bucket(BUCKET_NAME)

if bucket.exists():
return bucket
else:
logger.error(f"Failed to {BUCKET_NAME} does not exist.")
exit(1)


def get_blobs(bucket: Bucket, blob_name: str):
blobs = list(bucket.list_blobs(prefix=blob_name))

# Error handling for the contents of blobs
try:
if len(blobs) == 0: # if there is no first element, the iterator is empty
logger.error(f"No blobs found with prefix '{blob_name}' in the bucket.")
exit(1)
else:
return blobs
except Exception as e:
logger.error(f"Failed to get blobs from '{blob_name}' due to an error: {e}")
exit(1)


def save_face_image_to_local(face: np.ndarray, save_dir: str, save_file_name: str):
"""クリップされた顔画像を保存する"""
os.makedirs(save_dir, exist_ok=True)
@@ -156,7 +118,8 @@ def detect_face_and_clip(args: argparse.Namespace, config: dict):
# Connect to GCS
if args.env == "remote":
client = init_client()
bucket = get_bucket(client)
BUCKET_NAME = os.environ.get("BUCKET_NAME")
bucket = get_bucket(client, BUCKET_NAME)

# Load the Haar Cascade
face_cascade = load_cascade(face_cascade_path)
@@ -1,5 +1,4 @@
import cv2
import os


def load_cascade(cascade_path):
@@ -5,12 +5,11 @@
from torchvision import transforms

from face_detect_model.util import (
init_client,
get_bucket,
get_blobs,
load_image_from_remote,
)

from face_detect_model.gcp_util import get_bucket, get_blobs

# TODO: split GCS-related processing out into a separate file


35 changes: 35 additions & 0 deletions machine_learning/src/face_detect_model/gcp_util.py
@@ -0,0 +1,35 @@
import google.cloud.storage as gcs
from google.cloud.storage import Bucket
import os


def init_client():
# NOTE: authenticate in advance with gcloud auth application-default login
PROJECT_ID = os.environ.get("PROJECT_ID")

client = gcs.Client(PROJECT_ID)
if client is None:
raise RuntimeError("Failed to initialize client.")
else:
return client


def get_bucket(client: gcs.Client, bucket_name: str):
bucket = client.bucket(bucket_name)

if bucket.exists():
return bucket
else:
raise ValueError(f"Failed to {bucket_name} does not exist.")


def get_blobs(bucket: Bucket, blob_name: str):
# Error handling for the contents of blobs
try:
blobs = list(bucket.list_blobs(prefix=blob_name))
if len(blobs) == 0: # if there is no first element, the iterator is empty
raise ValueError(f"No blobs found with prefix '{blob_name}' in the bucket.")
else:
return blobs
except Exception as e:
raise ValueError(f"Failed to get blobs from '{blob_name}' due to an error: {e}")
3 changes: 2 additions & 1 deletion machine_learning/src/face_detect_model/main.py
@@ -5,7 +5,8 @@
from face_detect_model.data.faceDetectDataset import FaceDetectDataset
from face_detect_model.model.faceDetectModel import FaceDetectModel
from face_detect_model.trainer import Trainer
from face_detect_model.util import logger, init_client
from face_detect_model.util import logger
from face_detect_model.gcp_util import init_client
from dotenv import load_dotenv

load_dotenv("secrets/.env")
4 changes: 4 additions & 0 deletions machine_learning/src/face_detect_model/pred.py
@@ -0,0 +1,4 @@
def main():
gray = cv2.cvtColor(captureBuffer, cv2.COLOR_BGR2GRAY)
pred_child_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
return pred_child_ids
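
As committed, pred.py is a stub: cv2 is never imported, captureBuffer is undefined, and main() takes no parameters even though server.py below calls it as pred_fn(args). A minimal sketch that would at least run under that calling convention, assuming args.video_chunk carries encoded image bytes (an assumption, not something the diff specifies):

import argparse

import cv2
import numpy as np


def main(args: argparse.Namespace):
    # Assumption: args.video_chunk holds encoded image bytes; decode them into a BGR frame.
    frame = cv2.imdecode(np.frombuffer(args.video_chunk, dtype=np.uint8), cv2.IMREAD_COLOR)

    # Grayscale conversion mirrors the committed stub; real inference would go here.
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Placeholder output until the trained FaceDetectModel is wired in.
    pred_child_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    return pred_child_ids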
3 changes: 2 additions & 1 deletion machine_learning/src/face_detect_model/trainer.py
@@ -6,7 +6,8 @@

from face_detect_model.data.faceDetectDataset import FaceDetectDataset
from face_detect_model.model.faceDetectModel import FaceDetectModel
from face_detect_model.util import logger, get_bucket, save_model_to_gcs
from face_detect_model.util import logger, save_model_to_gcs
from face_detect_model.gcp_util import get_bucket


class Trainer:
34 changes: 0 additions & 34 deletions machine_learning/src/face_detect_model/util.py
@@ -3,7 +3,6 @@

import random
import torch
import google.cloud.storage as gcs
from google.cloud.storage import Bucket
import os
import numpy as np
@@ -25,39 +24,6 @@ def set_seed(seed):
torch.cuda.manual_seed_all(seed)


def init_client():
# NOTE: authenticate in advance with gcloud auth application-default login
PROJECT_ID = os.environ.get("PROJECT_ID")

client = gcs.Client(PROJECT_ID)
if client is None:
logger.error("Failed to initialize client.")
exit(1)
else:
return client


def get_bucket(client: gcs.Client, bucket_name: str):
bucket = client.bucket(bucket_name)

if bucket.exists():
return bucket
else:
raise ValueError(f"Failed to {bucket_name} does not exist.")


def get_blobs(bucket: Bucket, blob_name: str):
# Error handling for the contents of blobs
try:
blobs = list(bucket.list_blobs(prefix=blob_name))
if len(blobs) == 0: # if there is no first element, the iterator is empty
raise ValueError(f"No blobs found with prefix '{blob_name}' in the bucket.")
else:
return blobs
except Exception as e:
raise ValueError(f"Failed to get blobs from '{blob_name}' due to an error: {e}")


def get_child_id(blob_name: str):
# Regex pattern for a UUID
uuid_pattern = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
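
The body of get_child_id is cut off in this view; a plausible sketch, assuming the function simply extracts the first UUID found in the blob name (the regex is from util.py, the extraction logic is a guess):

import re


def get_child_id(blob_name: str):
    # Regex pattern for a UUID (as defined in util.py)
    uuid_pattern = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"

    # Assumption: the child id is the first UUID embedded in the blob name.
    match = re.search(uuid_pattern, blob_name)
    return match.group(0) if match else None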
22 changes: 20 additions & 2 deletions machine_learning/src/proto-gen/machine_learning/v1/server.py
@@ -21,6 +21,9 @@
from face_detect_model.main import (
main as train_fn,
)
from face_detect_model.pred import (
main as pred_fn,
)


class HealthCheckServiceServer(
@@ -37,8 +40,23 @@ class MachineLearningServiceServicer(
machine_learning_pb2_grpc.MachineLearningServiceServicer
):
# TODO: implement Predict
def Predict(self, request: machine_learning_pb2.PredRequest, context):
pass
def Predict(self, request_iterator: machine_learning_pb2.PredRequest, context):
for req in request_iterator:
parser = argparse.ArgumentParser()
args = parser.parse_args()

args.bus_id = req.bus_id
args.bus_type = req.bus_type
args.video_type = req.video_type
args.video_chunk = req.video_chunk
args.timestamp = req.timestamp

try:
child_ids = pred_fn(args)
except Exception as e:
logging.error(e)
child_ids = []
yield machine_learning_pb2.PredResponse(child_ids=child_ids)

def Train(self, request: machine_learning_pb2.TrainRequest, context):
parser = argparse.ArgumentParser()
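
A hedged client-side sketch of the new streaming Predict RPC. The generated-module import path, the MachineLearningServiceStub name, the localhost:50051 address, and every PredRequest field value below are assumptions based on the servicer code above, not something this diff defines:

import grpc

# Assumption: the generated modules are importable like this; adjust to the proto-gen layout.
import machine_learning_pb2
import machine_learning_pb2_grpc


def predict_example():
    # Assumption: the server listens on localhost:50051.
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = machine_learning_pb2_grpc.MachineLearningServiceStub(channel)

        def requests():
            # One request per video chunk; field values here are illustrative only.
            yield machine_learning_pb2.PredRequest(
                bus_id="bus-1",
                bus_type="morning",
                video_type="clip",
                video_chunk=b"",
                timestamp="2024-02-18T00:00:00Z",
            )

        # Predict streams one PredResponse (with child_ids) back per incoming request.
        for response in stub.Predict(requests()):
            print(list(response.child_ids))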
