You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
43 lines
1.0 KiB
Python
43 lines
1.0 KiB
Python
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import util
|
|
# --- Module configuration ---------------------------------------------------
# Plain literal constants come first; they have no dependencies, so their
# position relative to the GPU/model setup below does not matter.
NUM_POINT = 512  # points per resample handed to the model by predict_yaw
RESAMPLE_NUM = 10  # number of random resamples built per predict_yaw call
# NOTE(review): relative path, resolved against the current working
# directory rather than this file's location.
model_file = "./algos/models/deep_annotation_inference.h5"

# --- Runtime setup (import-time side effects) --------------------------------
# GPU configuration runs before the model is loaded. util.config_gpu is
# project-local; presumably it sets TensorFlow device/memory options — confirm.
util.config_gpu()

model = tf.keras.models.load_model(model_file)
model.summary()  # prints the layer table to stdout on every import
|
|
|
|
def sample_one_obj(points, num):
    """Return exactly ``num`` rows taken from the point cloud ``points``.

    If the cloud has fewer than ``num`` points, it is returned in its original
    order followed by zero rows, padded up to ``num`` rows. Otherwise ``num``
    distinct rows are chosen at random without replacement, via a shuffled
    index (``np.random.shuffle``).

    Args:
        points: array-like of shape (N, D). Callers in this file pass D == 3
            (xyz).
        num: number of rows in the result.

    Returns:
        ndarray of shape (num, D).
    """
    points = np.asarray(points)
    n_points = points.shape[0]

    if n_points < num:
        # Pad against ``num`` (not a module-level constant), so the result
        # always has the requested length. Trailing dims follow the input, so
        # non-xyz feature widths also work. The float32 zeros keep the
        # original concatenation dtype behaviour.
        padding = np.zeros((num - n_points,) + points.shape[1:], dtype=np.float32)
        return np.concatenate([points, padding], axis=0)

    idx = np.arange(n_points)
    np.random.shuffle(idx)
    return points[idx[:num]]
|
|
|
|
def predict_yaw(points):
    """Predict an object's yaw angle from its point cloud.

    The cloud is resampled ``RESAMPLE_NUM`` times to ``NUM_POINT`` points each
    (see ``sample_one_obj``). All resamples are scored by ``model`` in one
    batch. The per-class scores are then averaged across resamples before the
    arg-max, so every resample contributes to the answer. Taking only the first
    resample's class would make the remaining resamples wasted work.

    Side effects: prints the per-resample classes and the returned rotation.

    Args:
        points: anything ``np.array`` accepts whose total size is a multiple
            of 3; it is reshaped to (-1, 3).

    Returns:
        ``[0, 0, yaw]``, where ``yaw`` is ``(k * 3 + 1.5)`` degrees in radians
        and ``k`` is the winning class index. Presumably the model classifies
        yaw into 3-degree bins and this is the bin centre — confirm.
    """
    points = np.array(points).reshape((-1, 3))
    input_data = np.stack(
        [sample_one_obj(points, NUM_POINT) for _ in range(RESAMPLE_NUM)], axis=0
    )

    pred_val = model.predict(input_data)
    pred_cls = np.argmax(pred_val, axis=-1)
    print(pred_cls)  # per-resample classes, useful for eyeballing agreement

    # Ensemble over the resample axis: average the scores, then pick one class.
    # Assumes pred_val is (RESAMPLE_NUM, num_classes) — TODO confirm.
    best_cls = np.argmax(np.mean(pred_val, axis=0), axis=-1)

    yaw = (best_cls * 3 + 1.5) * np.pi / 180.0
    ret = [0, 0, yaw]
    print(ret)
    return ret
|
|
|
|
# Warm-up: run one throwaway prediction at import time on a random cloud of
# 1000 points with coordinates in [0, 1). Presumably this moves one-off
# TensorFlow initialisation cost out of the first real request — confirm.
predict_yaw(np.random.random(size=(1000, 3)))


if __name__ == "__main__":
    # Smoke test when run as a script. The import-time warm-up above has
    # already run once by this point, so a script run predicts twice.
    demo_cloud = np.random.random(size=(1000, 3))
    predict_yaw(demo_cloud)