gel.scripts.get_displacement_from_gpr

Within the FM-TRACK environment, evaluate saved Gaussian process regression (GPR) models at vertices read from a text file and write the predicted displacements to a text file.

 1#!/usr/bin/env python3
 2"""Within the FM-TRACK environment, perform GPR-interpolation with text
 3input and output."""
 4import argparse
 5import numpy as np
 6import os
 7import pickle
 8
 9
# Unpickled (gp_U, gp_V, gp_W, scaler) tuples keyed by GPR directory path, so
# each directory is read from disk at most once per process.
_cache = {}


def get_predicted_u(gpr_path, vertices, *, batch_size=310000):
    """Evaluate the saved U/V/W GPR models at ``vertices``.

    The three GPR models and their input scaler are unpickled from
    ``gpr_path`` on first use and cached in the module-level ``_cache``
    (keyed by ``gpr_path``); later calls with the same path reuse them
    without touching the files again.

    Parameters
    ----------
    gpr_path : str
        Directory containing ``gp_U_cleaned.sav``, ``gp_V_cleaned.sav``,
        ``gp_W_cleaned.sav`` and ``scaler_cleaned.sav``.
    vertices : array_like
        Points to evaluate; passed unchanged to ``scaler.transform``.
        Presumably shape ``(n, 3)`` -- confirm against the trained scaler.
    batch_size : int, optional
        Maximum number of points handed to a single ``predict`` call
        (default 310000). Inputs larger than this are evaluated slice by
        slice.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n, 3)`` whose columns are the ``gp_U``,
        ``gp_V`` and ``gp_W`` predictions, in that order.

    Raises
    ------
    ValueError
        If ``batch_size`` is less than 1.
    OSError
        If one of the model files cannot be opened.
    """
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    models = _cache.get(gpr_path)
    if models is None:
        loaded = []
        for fname in ('gp_U_cleaned.sav', 'gp_V_cleaned.sav',
                      'gp_W_cleaned.sav', 'scaler_cleaned.sav'):
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file -- only point gpr_path at trusted model directories.
            # The context manager closes each handle (the original leaked them).
            with open(os.path.join(gpr_path, fname), 'rb') as fh:
                loaded.append(pickle.load(fh))
        models = tuple(loaded)
        _cache[gpr_path] = models
    gp_U, gp_V, gp_W, scaler = models

    input_arr = scaler.transform(vertices)
    n_points = len(input_arr)

    # One row per point; columns hold the U, V and W predictions.
    disp = np.empty((n_points, 3))
    # Predict in slices of at most batch_size points (presumably to bound the
    # memory one GPR predict call needs). When n_points <= batch_size a single
    # slice covers the whole input, so small and large inputs share this path.
    for start in range(0, n_points, batch_size):
        stop = start + batch_size
        chunk = input_arr[start:stop]
        disp[start:stop] = np.column_stack(
            (gp_U.predict(chunk), gp_V.predict(chunk), gp_W.predict(chunk))
        )
    return disp
40
41
def get_u_from_gpr_main():
    """Command-line entry point: evaluate GPR models at vertices from a file.

    Parses ``--vertices-file``, ``--dest`` and ``--gpr-dir``, loads the
    vertices with ``numpy.loadtxt``, evaluates them with
    ``get_predicted_u`` and writes the result to ``--dest`` with
    ``numpy.savetxt``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate GPR model at provided vertices, for use "
        "in FM-Track environment"
    )
    # All three options are needed; without required=True a missing one
    # surfaced later as an obscure None-related error.
    parser.add_argument(
        "-v",
        "--vertices-file",
        type=str,
        metavar="VERTICES_FILE",
        required=True,
        help="text file of vertex coordinates, one vertex per row "
        "(read with numpy.loadtxt)",
    )
    parser.add_argument(
        "-d",
        "--dest",
        type=str,
        metavar="DEST",
        required=True,
        help="output text file for the predicted displacements, "
        "one row per vertex",
    )
    parser.add_argument(
        "-g",
        "--gpr-dir",
        type=str,
        metavar="GPR_DIR",
        required=True,
        help="directory containing gp_U_cleaned.sav, gp_V_cleaned.sav, "
        "gp_W_cleaned.sav and scaler_cleaned.sav",
    )
    args = parser.parse_args()

    # ndmin=2 keeps a single-vertex file as a (1, n_cols) array instead of a
    # 1-D array, which the scaler would presumably reject as a single sample.
    vertices = np.loadtxt(args.vertices_file, ndmin=2)
    u = get_predicted_u(args.gpr_dir, vertices)
    np.savetxt(args.dest, u)


if __name__ == "__main__":
    get_u_from_gpr_main()
def get_predicted_u(gpr_path, vertices):
def get_predicted_u(gpr_path, vertices, *, batch_size=310000):
    """Evaluate the saved U/V/W GPR models at ``vertices``.

    The three GPR models and their input scaler are unpickled from
    ``gpr_path`` on first use and cached in the module-level ``_cache``
    (keyed by ``gpr_path``); later calls with the same path reuse them
    without touching the files again.

    Parameters
    ----------
    gpr_path : str
        Directory containing ``gp_U_cleaned.sav``, ``gp_V_cleaned.sav``,
        ``gp_W_cleaned.sav`` and ``scaler_cleaned.sav``.
    vertices : array_like
        Points to evaluate; passed unchanged to ``scaler.transform``.
        Presumably shape ``(n, 3)`` -- confirm against the trained scaler.
    batch_size : int, optional
        Maximum number of points handed to a single ``predict`` call
        (default 310000). Inputs larger than this are evaluated slice by
        slice.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n, 3)`` whose columns are the ``gp_U``,
        ``gp_V`` and ``gp_W`` predictions, in that order.

    Raises
    ------
    ValueError
        If ``batch_size`` is less than 1.
    OSError
        If one of the model files cannot be opened.
    """
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    models = _cache.get(gpr_path)
    if models is None:
        loaded = []
        for fname in ('gp_U_cleaned.sav', 'gp_V_cleaned.sav',
                      'gp_W_cleaned.sav', 'scaler_cleaned.sav'):
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file -- only point gpr_path at trusted model directories.
            # The context manager closes each handle (the original leaked them).
            with open(os.path.join(gpr_path, fname), 'rb') as fh:
                loaded.append(pickle.load(fh))
        models = tuple(loaded)
        _cache[gpr_path] = models
    gp_U, gp_V, gp_W, scaler = models

    input_arr = scaler.transform(vertices)
    n_points = len(input_arr)

    # One row per point; columns hold the U, V and W predictions.
    disp = np.empty((n_points, 3))
    # Predict in slices of at most batch_size points (presumably to bound the
    # memory one GPR predict call needs). When n_points <= batch_size a single
    # slice covers the whole input, so small and large inputs share this path.
    for start in range(0, n_points, batch_size):
        stop = start + batch_size
        chunk = input_arr[start:stop]
        disp[start:stop] = np.column_stack(
            (gp_U.predict(chunk), gp_V.predict(chunk), gp_W.predict(chunk))
        )
    return disp
def get_u_from_gpr_main():
def get_u_from_gpr_main():
    """Command-line entry point: evaluate GPR models at vertices from a file.

    Parses ``--vertices-file``, ``--dest`` and ``--gpr-dir``, loads the
    vertices with ``numpy.loadtxt``, evaluates them with
    ``get_predicted_u`` and writes the result to ``--dest`` with
    ``numpy.savetxt``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate GPR model at provided vertices, for use "
        "in FM-Track environment"
    )
    # All three options are needed; without required=True a missing one
    # surfaced later as an obscure None-related error.
    parser.add_argument(
        "-v",
        "--vertices-file",
        type=str,
        metavar="VERTICES_FILE",
        required=True,
        help="text file of vertex coordinates, one vertex per row "
        "(read with numpy.loadtxt)",
    )
    parser.add_argument(
        "-d",
        "--dest",
        type=str,
        metavar="DEST",
        required=True,
        help="output text file for the predicted displacements, "
        "one row per vertex",
    )
    parser.add_argument(
        "-g",
        "--gpr-dir",
        type=str,
        metavar="GPR_DIR",
        required=True,
        help="directory containing gp_U_cleaned.sav, gp_V_cleaned.sav, "
        "gp_W_cleaned.sav and scaler_cleaned.sav",
    )
    args = parser.parse_args()

    # ndmin=2 keeps a single-vertex file as a (1, n_cols) array instead of a
    # 1-D array, which the scaler would presumably reject as a single sample.
    vertices = np.loadtxt(args.vertices_file, ndmin=2)
    u = get_predicted_u(args.gpr_dir, vertices)
    np.savetxt(args.dest, u)