-
Notifications
You must be signed in to change notification settings - Fork 6
/
run.py
74 lines (58 loc) · 2.5 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import pdb
import argparse
import numpy as np
import matplotlib.pyplot as plt
from actions import plan_action
from agents.teleport_agent import TeleportAgent
from agents.panda_agent import PandaAgent
from block_utils import Object, Dimensions, Position, Color, get_adversarial_blocks
from particle_belief import ParticleBelief
from tower_planner import TowerPlanner
def main(args):
    """Estimate each block's centre of mass, then build the tallest tower.

    Steps:
      1. Create ``args.num_blocks`` blocks via ``get_adversarial_blocks``.
      2. Create the requested agent ('teleport' or 'panda').
      3. For each block, run ``NUM_INTERACTIONS`` rounds of: plan an action
         from the block's ParticleBelief, have the agent simulate it, and
         update the belief with the observation. The latest particle set is
         stored on the block as ``block.com_filter``.
      4. Plan a tower with ``TowerPlanner(plan_mode='expectation')`` and have
         the agent simulate it.

    Blocks on ``input('Start?')`` until the user presses Enter.

    Args:
        args: Parsed command-line namespace. Reads ``num_blocks``, ``agent``,
            ``save_tower`` and, if present, ``plot``.

    Raises:
        NotImplementedError: If ``args.agent`` is not a supported agent type.
    """
    NOISE = 0.00005  # shared noise level passed to the agent and every belief
    NUM_INTERACTIONS = 5  # planned actions per block

    blocks = get_adversarial_blocks(num_blocks=args.num_blocks)

    if args.agent == 'teleport':
        agent = TeleportAgent(blocks, NOISE)
    elif args.agent == 'panda':
        agent = PandaAgent(blocks, NOISE, use_platform=True, teleport=False)
    else:
        raise NotImplementedError('Unsupported agent type: {!r}'.format(args.agent))

    # Honour the --plot flag (it used to be parsed but ignored, with plotting
    # hard-coded on). Namespaces lacking a `plot` attribute keep the old
    # always-plot behaviour.
    plot = getattr(args, 'plot', True)
    beliefs = [ParticleBelief(block,
                              N=200,
                              plot=plot,
                              vis_sim=False,
                              noise=NOISE) for block in blocks]

    agent._add_text('Ready?')
    input('Start?')

    # Gain information about the CoM of each block.
    for b_ix, (block, belief) in enumerate(zip(blocks, beliefs)):
        print('Running filter for', block.name)
        for interaction_num in range(NUM_INTERACTIONS):
            print("Interaction number: ", interaction_num)
            agent._add_text('Planning action.')
            action = plan_action(belief, exp_type='reduce_var', action_type='place')
            # b_ix identifies the block being acted on — presumably an index
            # into the `blocks` list the agent was built with; confirm.
            observation = agent.simulate_action(action, b_ix, T=50)
            agent._add_text('Updating particle belief.')
            belief.update(observation)
            block.com_filter = belief.particles
        # Latest CoM estimate next to block.com (presumably the true CoM).
        print(belief.estimated_coms[-1], block.com)

    # Find the tallest tower and execute the resulting plan.
    print('Finding tallest tower.')
    planner = TowerPlanner(plan_mode='expectation')
    tallest_tower = planner.plan(blocks)
    agent.simulate_tower(tallest_tower, vis=True, T=2500, save_tower=args.save_tower)
def _build_arg_parser():
    """Build the parser for this script's command-line options.

    --debug       enter pdb after parsing, before ``main`` runs
    --plot        store_true flag, exposed as ``args.plot``
    --num-blocks  number of blocks to generate (default 3)
    --save-tower  store_true flag, exposed as ``args.save_tower``
    --agent       'teleport' (default) or 'panda'
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--debug', action='store_true')
    cli.add_argument('--plot', action='store_true')
    cli.add_argument('--num-blocks', type=int, default=3)
    cli.add_argument('--save-tower', action='store_true')
    cli.add_argument('--agent', choices=['teleport', 'panda'], default='teleport')
    return cli


if __name__ == '__main__':
    args = _build_arg_parser().parse_args()
    if args.debug:
        # Pause in the debugger before any blocks or agents are created.
        pdb.set_trace()
    main(args)