-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #37 from allenai/objathor
Refactor to use `objathor` and `black` formatting.
- Loading branch information
Showing
35 changed files
with
6,041 additions
and
3,458 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
name: Continuous integration | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
pull_request: | ||
branches: | ||
- main | ||
|
||
jobs: | ||
lint: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: psf/black@stable | ||
tests: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: ['3.10'] | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install | ||
run: | | ||
python3 -m venv .env | ||
source .env/bin/activate | ||
make install | ||
- name: Unit tests | ||
run: | | ||
source .env/bin/activate | ||
make test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
name: Release | ||
|
||
on: | ||
push: | ||
branches: | ||
- main | ||
jobs: | ||
deploy: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: actions-ecosystem/action-regex-match@v2 | ||
id: regex-match | ||
with: | ||
text: ${{ github.event.head_commit.message }} | ||
regex: '^Release ([^ ]+)' | ||
- name: Set up Python | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: '3.10' | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install setuptools wheel twine | ||
- name: Release | ||
if: ${{ steps.regex-match.outputs.match != '' }} | ||
uses: softprops/action-gh-release@v1 | ||
with: | ||
tag_name: ${{ steps.regex-match.outputs.group1 }} | ||
- name: Build and publish | ||
if: ${{ steps.regex-match.outputs.match != '' }} | ||
env: | ||
TWINE_USERNAME: __token__ | ||
TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }} | ||
run: | | ||
python setup.py sdist bdist_wheel | ||
twine upload dist/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
install: ## [Local development] Upgrade pip, install requirements, install package. | ||
python -m pip install -U pip | ||
python -m pip install -e . | ||
|
||
install-dev: ## [Local development] Install requirements | ||
python -m pip install -r requirements.txt | ||
|
||
black: ## [Local development] Auto-format python code using black | ||
python -m black . | ||
|
||
test: ## [Local development] Run unit tests | ||
python -m pytest -x -s -v tests | ||
|
||
.PHONY: help | ||
|
||
help: # Run `make help` to get help on the make commands | ||
@grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import os | ||
from pathlib import Path | ||
|
||
ABS_PATH_OF_HOLODECK = os.path.abspath(os.path.dirname(Path(__file__))) | ||
|
||
ASSETS_VERSION = os.environ.get("ASSETS_VERSION", "2023_09_23") | ||
HD_BASE_VERSION = os.environ.get("HD_BASE_VERSION", "2023_09_23") | ||
|
||
OBJATHOR_ASSETS_BASE_DIR = os.environ.get( | ||
"OBJATHOR_ASSETS_BASE_DIR", os.path.expanduser(f"~/.objathor-assets") | ||
) | ||
|
||
OBJATHOR_VERSIONED_DIR = os.path.join(OBJATHOR_ASSETS_BASE_DIR, ASSETS_VERSION) | ||
OBJATHOR_ASSETS_DIR = os.path.join(OBJATHOR_VERSIONED_DIR, "assets") | ||
OBJATHOR_FEATURES_DIR = os.path.join(OBJATHOR_VERSIONED_DIR, "features") | ||
OBJATHOR_ANNOTATIONS_PATH = os.path.join(OBJATHOR_VERSIONED_DIR, "annotations.json.gz") | ||
|
||
HOLODECK_BASE_DATA_DIR = os.path.join( | ||
OBJATHOR_ASSETS_BASE_DIR, "holodeck", HD_BASE_VERSION | ||
) | ||
|
||
HOLODECK_THOR_FEATURES_DIR = os.path.join(HOLODECK_BASE_DATA_DIR, "thor_object_data") | ||
HOLODECK_THOR_ANNOTATIONS_PATH = os.path.join( | ||
HOLODECK_BASE_DATA_DIR, "thor_object_data", "annotations.json.gz" | ||
) | ||
|
||
if ASSETS_VERSION > "2023_09_23": | ||
THOR_COMMIT_ID = "8524eadda94df0ab2dbb2ef5a577e4d37c712897" | ||
else: | ||
THOR_COMMIT_ID = "3213d486cd09bcbafce33561997355983bdf8d1a" | ||
|
||
# LLM_MODEL_NAME = "gpt-4-1106-preview" | ||
LLM_MODEL_NAME = "gpt-4o-2024-05-13" | ||
|
||
DEBUGGING = os.environ.get("DEBUGGING", "0").lower() in ["1", "true", "True", "t", "T"] |
Empty file.
107 changes: 68 additions & 39 deletions
107
modules/ceiling_objects.py → ai2holodeck/generation/ceiling_objects.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,114 +1,143 @@ | ||
import re | ||
import copy | ||
import re | ||
|
||
import torch | ||
from colorama import Fore | ||
import torch.nn.functional as F | ||
import modules.prompts as prompts | ||
from langchain import PromptTemplate | ||
from colorama import Fore | ||
from langchain import PromptTemplate, OpenAI | ||
from shapely.geometry import Polygon | ||
|
||
|
||
class CeilingObjectGenerator(): | ||
def __init__(self, llm, object_retriever): | ||
self.json_template = {"assetId": None, "id": None, "kinematic": True, | ||
"position": {}, "rotation": {}, "material": None, "roomId": None} | ||
import ai2holodeck.generation.prompts as prompts | ||
from ai2holodeck.generation.objaverse_retriever import ObjathorRetriever | ||
from ai2holodeck.generation.utils import get_bbox_dims, get_annotations | ||
|
||
|
||
class CeilingObjectGenerator: | ||
def __init__(self, object_retriever: ObjathorRetriever, llm: OpenAI): | ||
self.json_template = { | ||
"assetId": None, | ||
"id": None, | ||
"kinematic": True, | ||
"position": {}, | ||
"rotation": {}, | ||
"material": None, | ||
"roomId": None, | ||
} | ||
self.llm = llm | ||
self.object_retriever = object_retriever | ||
self.database = object_retriever.database | ||
self.ceiling_template = PromptTemplate(input_variables=["input", "rooms", "additional_requirements"], | ||
template=prompts.ceiling_selection_prompt) | ||
|
||
self.ceiling_template = PromptTemplate( | ||
input_variables=["input", "rooms", "additional_requirements"], | ||
template=prompts.ceiling_selection_prompt, | ||
) | ||
|
||
def generate_ceiling_objects(self, scene, additional_requirements_ceiling="N/A"): | ||
room_types = [room["roomType"] for room in scene["rooms"]] | ||
room_types_str = str(room_types).replace("'", "")[1:-1] | ||
ceiling_prompt = self.ceiling_template.format(input=scene["query"], | ||
rooms=room_types_str, | ||
additional_requirements=additional_requirements_ceiling) | ||
ceiling_prompt = self.ceiling_template.format( | ||
input=scene["query"], | ||
rooms=room_types_str, | ||
additional_requirements=additional_requirements_ceiling, | ||
) | ||
|
||
if "raw_ceiling_plan" not in scene: raw_ceiling_plan = self.llm(ceiling_prompt) | ||
else: raw_ceiling_plan = scene["raw_ceiling_plan"] | ||
if "raw_ceiling_plan" not in scene: | ||
raw_ceiling_plan = self.llm(ceiling_prompt) | ||
else: | ||
raw_ceiling_plan = scene["raw_ceiling_plan"] | ||
|
||
print(f"\nUser: {ceiling_prompt}\n") | ||
print(f"{Fore.GREEN}AI: Here is the ceiling plan:\n{raw_ceiling_plan}{Fore.RESET}") | ||
print( | ||
f"{Fore.GREEN}AI: Here is the ceiling plan:\n{raw_ceiling_plan}{Fore.RESET}" | ||
) | ||
|
||
ceiling_objects = [] | ||
parsed_ceiling_plan = self.parse_ceiling_plan(raw_ceiling_plan) | ||
for room_type, ceiling_object_description in parsed_ceiling_plan.items(): | ||
room = self.get_room_by_type(scene["rooms"], room_type) | ||
|
||
if room is None: | ||
print("Room type {} not found in scene.".format(room_type)) | ||
print(f"Room type {room_type} not found in scene.") | ||
continue | ||
|
||
ceiling_object_id = self.select_ceiling_object(ceiling_object_description) | ||
if ceiling_object_id is None: continue | ||
if ceiling_object_id is None: | ||
continue | ||
|
||
# Temporary solution: place at the center of the room | ||
dimension = self.database[ceiling_object_id]['assetMetadata']['boundingBox'] | ||
dimension = get_bbox_dims(self.database[ceiling_object_id]) | ||
|
||
floor_polygon = Polygon(room["vertices"]) | ||
x = floor_polygon.centroid.x | ||
z = floor_polygon.centroid.y | ||
y = scene["wall_height"] - dimension["y"] / 2 | ||
|
||
ceiling_object = copy.deepcopy(self.json_template) | ||
ceiling_object["assetId"] = ceiling_object_id | ||
ceiling_object["id"] = f"ceiling ({room_type})" | ||
ceiling_object["position"] = {"x": x, "y": y, "z": z} | ||
ceiling_object["rotation"] = {"x": 0, "y": 0, "z": 0} | ||
ceiling_object["roomId"] = room["id"] | ||
ceiling_object["object_name"] = self.database[ceiling_object_id]["annotations"]["category"] | ||
ceiling_object["object_name"] = get_annotations( | ||
self.database[ceiling_object_id] | ||
)["category"] | ||
ceiling_objects.append(ceiling_object) | ||
|
||
return raw_ceiling_plan, ceiling_objects | ||
|
||
|
||
def parse_ceiling_plan(self, raw_ceiling_plan): | ||
plans = [plan.lower() for plan in raw_ceiling_plan.split("\n") if "|" in plan] | ||
parsed_plans = {} | ||
for plan in plans: | ||
# remove index | ||
pattern = re.compile(r'^\d+\.\s*') | ||
plan = pattern.sub('', plan) | ||
if plan[-1] == ".": plan = plan[:-1] # remove the last period | ||
pattern = re.compile(r"^\d+\.\s*") | ||
plan = pattern.sub("", plan) | ||
if plan[-1] == ".": | ||
plan = plan[:-1] # remove the last period | ||
|
||
room_type, ceiling_object_description = plan.split("|") | ||
room_type = room_type.strip() | ||
ceiling_object_description = ceiling_object_description.strip() | ||
if room_type not in parsed_plans: # only consider one type of ceiling object for each room | ||
if ( | ||
room_type not in parsed_plans | ||
): # only consider one type of ceiling object for each room | ||
parsed_plans[room_type] = ceiling_object_description | ||
return parsed_plans | ||
|
||
|
||
def get_room_by_type(self, rooms, room_type): | ||
for room in rooms: | ||
if room["roomType"] == room_type: | ||
return room | ||
return None | ||
|
||
|
||
def select_ceiling_object(self, description): | ||
candidates = self.object_retriever.retrieve([f"a 3D model of {description}"], threshold=29) | ||
ceiling_candiates = [candidate for candidate in candidates if self.database[candidate[0]]["annotations"]["onCeiling"] == True] | ||
candidates = self.object_retriever.retrieve( | ||
[f"a 3D model of {description}"], threshold=29 | ||
) | ||
ceiling_candiates = [ | ||
candidate | ||
for candidate in candidates | ||
if get_annotations(self.database[candidate[0]])["onCeiling"] == True | ||
] | ||
|
||
valid_ceiling_candiates = [] | ||
for candidate in ceiling_candiates: | ||
dimension = self.database[candidate[0]]['assetMetadata']['boundingBox'] | ||
if dimension["y"] <= 1.0: valid_ceiling_candiates.append(candidate) | ||
dimension = get_bbox_dims(self.database[candidate[0]]) | ||
if dimension["y"] <= 1.0: | ||
valid_ceiling_candiates.append(candidate) | ||
|
||
if len(valid_ceiling_candiates) == 0: | ||
print("No ceiling object found for description: {}".format(description)) | ||
return None | ||
|
||
selected_ceiling_object_id = self.random_select(valid_ceiling_candiates)[0] | ||
return selected_ceiling_object_id | ||
|
||
|
||
def random_select(self, candidates): | ||
scores = [candidate[1] for candidate in candidates] | ||
scores_tensor = torch.Tensor(scores) | ||
probas = F.softmax(scores_tensor, dim=0) # TODO: consider using normalized scores | ||
probas = F.softmax( | ||
scores_tensor, dim=0 | ||
) # TODO: consider using normalized scores | ||
selected_index = torch.multinomial(probas, 1).item() | ||
selected_candidate = candidates[selected_index] | ||
return selected_candidate | ||
return selected_candidate |
Oops, something went wrong.