続き
/*
 * Discretise a cart-pole state into a Q-table row index ("box").
 *
 * Returns -1 (failure) when x is outside [-5, 5] or theta is outside
 * [-720*2*pai/360, 2*pai] (i.e. [-4*pai, 2*pai] if pai is pi).
 * NOTE(review): the theta limits are asymmetric — confirm this is intended.
 *
 * Otherwise returns the mixed-radix index
 *     box = x_band + 3*x_dot_band + 9*theta_sector + 99*theta_dot_band
 * with x_band, x_dot_band in 0..2, theta_sector in 0..10 and
 * theta_dot_band in 0..9, so 0 <= box <= 989 (every index is distinct).
 */
int get_box(double x,double x_dot,double theta,double theta_dot)
{
    /* cos/sin of 15, 30 and 45 degrees: the theta sector boundaries.
       Written with floating-point literals: 1/2 in C is integer division (== 0),
       and cos(30 deg) is sqrt(3)/2 (~0.866), not 2/sqrt(3) (~1.155 > 1). */
    const double c15 = 0.965925826, s15 = 0.258819045;
    const double c30 = sqrt(3.0) / 2.0, s30 = 0.5;
    const double c45 = 1.0 / sqrt(2.0), s45 = 1.0 / sqrt(2.0);
    double c, s;
    int box = 0;
    int band;

    if (x < -5.0 || x > 5.0 || theta > 2*pai || theta < -720*2*pai/360) return(-1);

    /* x: three bands split at -2 and 0. */
    if (x < -2.0) box = 0;
    else if (x < 0.0) box = 1;
    else box = 2;

    /* x_dot: three bands split at -1.5 and 0. */
    if (x_dot < -1.5) { /* lowest band adds 0 */ }
    else if (x_dot < 0.0) box += 3;
    else box += 6;

    /* theta: ten angular sectors identified by the signs/magnitudes of cos and sin.
       All tests are strict, so angles exactly on a boundary (e.g. theta == 0,
       where cos == 1 and sin == 0) match no sector and keep sector 0. */
    c = cos(theta);
    s = sin(theta);
    if      (c < 0.0 && c > -1.0 && s < 1.0 && s > 0.0)   box += 9;   /* (90, 180) deg */
    else if (c > -1.0 && c < 0.0 && s < 0.0 && s > -1.0)  box += 18;  /* (-180, -90) deg */
    else if (c > 0.0 && c < c45 && s > -1.0 && s < -s45)  box += 27;  /* (-90, -45) deg */
    else if (c > c45 && c < c30 && s > -s45 && s < -s30)  box += 36;  /* (-45, -30) deg */
    else if (c > c30 && c < c15 && s > -s30 && s < -s15)  box += 45;  /* (-30, -15) deg */
    else if (c > c15 && c < 1.0 && s > -s15 && s < 0.0)   box += 54;  /* (-15, 0) deg */
    else if (c < 1.0 && c > c15 && s > 0.0 && s < s15)    box += 63;  /* (0, 15) deg */
    else if (c < c15 && c > c30 && s > s15 && s < s30)    box += 72;  /* (15, 30) deg */
    else if (c < c30 && c > c45 && s > s30 && s < s45)    box += 81;  /* (30, 45) deg */
    else if (c < c45 && c > 0.0 && s > s45 && s < 1.0)    box += 90;  /* (45, 90) deg */

    /* theta_dot: ten half-open bands. Half-open intervals make every value
       (including exact boundaries such as the initial theta_dot == 0) land in
       its neighbouring band instead of the catch-all "> 2*pai" band. Each band
       is a multiple of 99 (= 3 * 3 * 11 combinations of the digits above), so
       bands never overlap; the highest index, 989, stays below the highest
       index (1087) the previous offsets already produced. */
    if      (theta_dot < -2*pai)           band = 0;
    else if (theta_dot < -pai)             band = 1;
    else if (theta_dot < -thirty_degrees)  band = 2;
    else if (theta_dot < -twelve_degrees)  band = 3;
    else if (theta_dot < 0.0)              band = 4;
    else if (theta_dot < twelve_degrees)   band = 5;
    else if (theta_dot < thirty_degrees)   band = 6;
    else if (theta_dot < pai)              band = 7;
    else if (theta_dot < 2*pai)            band = 8;
    else                                   band = 9;
    box += 99 * band;

    return(box);
}
/*
 * One-time OpenGL state setup: configures GL_LIGHT0 and GL_LIGHT1, the global
 * ambient light term, enables lighting and enables depth testing.
 */
void init(void)
{
    /* Row n holds the parameters for light GL_LIGHT0 + n. */
    GLfloat position[2][4] = {
        {0.0, 0.0, 1.0, 0.0},       /* light 0: w = 0, directional */
        {100.0, 100.0, 0.0, 1.0},   /* light 1: w = 1, positional */
    };
    GLfloat diffuse[2][4] = {
        {0.8, 0.8, 0.8, 1.0},
        {0.5, 0.5, 0.5, 1.0},
    };
    GLfloat specular[] = {0.2, 0.2, 0.2, 1.0};  /* shared by both lights */
    GLfloat ambient[] = {0.1, 0.1, 0.1, 1.0};   /* global ambient term */
    int n;

    /* OpenGL guarantees GL_LIGHTi == GL_LIGHT0 + i. */
    for (n = 0; n < 2; n++) {
        glLightfv(GL_LIGHT0 + n, GL_POSITION, position[n]);
        glLightfv(GL_LIGHT0 + n, GL_DIFFUSE, diffuse[n]);
        glLightfv(GL_LIGHT0 + n, GL_SPECULAR, specular);
    }
    glLightModelfv(GL_LIGHT_MODEL_AMBIENT, ambient);

    glEnable(GL_LIGHTING);
    for (n = 0; n < 2; n++)
        glEnable(GL_LIGHT0 + n);

    /* Depth test; GL_LEQUAL also accepts fragments at equal depth. */
    glDepthFunc(GL_LEQUAL);
    glEnable(GL_DEPTH_TEST);
}
/*
 * GLUT display callback.
 *
 * Draws the cart (red box) translated to (q, 0, r) and, attached to it, the
 * pole (green box) rotated by `shoulder` degrees about the z axis.
 * All model transforms are wrapped in a push/pop pair, so the modelview matrix
 * set up by reshape() is left unchanged from frame to frame.
 */
void display(void)
{
    GLfloat material_color0[4] = {1.0,0.0,0.0,1.0};   /* cart: red */
    GLfloat material_color1[4] = {0.0,1.0,0.0,1.0};   /* pole: green */
    GLfloat material_specular[4] = {0.2,0.2,0.2,1.0};

    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
    glMaterialfv(GL_FRONT, GL_SPECULAR, material_specular);

    glPushMatrix();
    glTranslatef(q, 0.0, r);

    /* Cart */
    glMaterialfv(GL_FRONT, GL_DIFFUSE, material_color0);
    glMaterialf(GL_FRONT, GL_SHININESS, 10.0);
    glPushMatrix();
    glScalef(0.8, 0.8, 0.5);
    glutSolidCube(1.0);
    glPopMatrix();

    /* Pole: rotate about the cart centre, then offset along -y so the box
       starts near the pivot and extends away from it. */
    glMaterialfv(GL_FRONT, GL_DIFFUSE, material_color1);
    glMaterialf(GL_FRONT, GL_SHININESS, 10.0);
    glRotatef((GLfloat) shoulder, 0.0, 0.0, 1.0);
    glTranslatef(0.0, -1.4, 0.0);
    glPushMatrix();
    glScalef(0.5, 2.9, 0.5);
    glutSolidCube(1.0);
    glPopMatrix();

    glPopMatrix();
    glutSwapBuffers();
}
/*
 * GLUT reshape callback: sets the viewport, a 65-degree perspective projection
 * (near 1, far 20) and a fixed camera (moved back 7 units, down 1, tilted 40
 * degrees about x).
 */
void reshape (int w, int h)
{
    /* A zero height (e.g. a minimised window) would make the aspect ratio
       below divide by zero. */
    if (h <= 0) h = 1;

    /* NOTE(review): the 10-pixel x offset shifts the viewport right, cutting
       10 px off the right edge of the scene — confirm it is intended. */
    glViewport(10, 0, (GLsizei) w, (GLsizei) h);
    glMatrixMode(GL_PROJECTION);
    glLoadIdentity();
    gluPerspective(65.0, (GLfloat) w / (GLfloat) h, 1.0, 20.0);
    glMatrixMode(GL_MODELVIEW);
    glLoadIdentity();
    glTranslatef(0.0, -1.0, -7.0);
    glRotatef(40, 1.0, 0.0, 0.0);
}
/*
 * GLUT keyboard callback: replays the states recorded during learning.
 *
 *   '3' : show the next recorded frame of trial `trial` (1-based, i.e. row
 *         data[trial-1]) and advance `success` (the frame index). When the
 *         recording ends, the last frame is shown again and playback moves on
 *         to the next trial from frame 0.
 *   'x' : ask on stdin which trial to replay and restart it from frame 0.
 *
 * x, y (mouse position supplied by GLUT) are unused.
 */
void keyboard (unsigned char key, int x, int y)
{
    /* Number of trial rows. Assumes `data` is a true 2-D array (not a
       pointer) — TODO confirm against its declaration. */
    const int num_trials = (int)(sizeof data / sizeof data[0]);
    int row;
    int ch;

    (void)x;
    (void)y;

    switch (key) {
    case '3':
        row = trial - 1;
        if (row < 0 || row >= num_trials) {
            printf("無効な試行番号です: %d\n", trial);
            break;
        }
        /* A recording ends after MAX_STEPS frames, or at the first all-zero
           record after frame 0. The MAX_STEPS test comes first so we never
           index data[row][MAX_STEPS]. */
        if (success >= MAX_STEPS ||
            (success > 0 &&
             data[row][success].x == 0 && data[row][success].x_dot == 0 &&
             data[row][success].theta == 0 && data[row][success].theta_dot == 0)) {
            q = data[row][success-1].x;
            shoulder = -180 - data[row][success-1].theta/pai*180;
            glutPostRedisplay();
            trial = trial + 1;
            success = 0;   /* the next trial starts from its first frame */
            break;
        }
        q = data[row][success].x;
        shoulder = -180 - data[row][success].theta/pai*180;
        printf("theta=%lf,x=%lf,x_dot=%lf,theta_dot=%lf\n",data[row][success].theta/pai*180,
               data[row][success].x,data[row][success].x_dot,data[row][success].theta_dot);
        glutPostRedisplay();
        success = success + 1;
        break;

    case 'x':
        printf("何回目?");
        fflush(stdout);   /* the prompt has no newline, so flush it explicitly */
        if (scanf("%d",&trial) != 1) {
            /* Not a number: discard the rest of the line and keep the current trial. */
            while ((ch = getchar()) != '\n' && ch != EOF) {
            }
        }
        success = 0;
        break;

    default:
        break;
    }
}
/*
 * Entry point.
 *
 * 1. Runs epsilon-greedy Q-learning on the cart-pole simulation until one
 *    trial survives MAX_STEPS steps. Each step's state is stored in
 *    data[trial][step] for later replay. If more than 499 trials fail, the
 *    program prints "失敗" and exits.
 * 2. Asks on stdin which trial (1-based) to replay, then starts the GLUT
 *    viewer. Its keyboard callback steps through that recording.
 */
int main(int argc, char** argv)
{
    int box, i, j;
    /* `failures` holds the reward passed to Q_update for the current step. */
    double x, x_dot, theta, theta_dot, failures;

    printf("学習率 ALPHA %.3f\n", ALPHA);
    printf("割引率 GAMMA %.3f\n", GAMMA);
    x = x_dot = theta_dot = 0;
    theta = 0;
    srand(time(NULL));
    success = 0;
    trial = 0;
    failures = 0.0;
    reset_controller();
    for (i = 0; i < NUM_BOXES; i++)
        for (j = 0; j < 2; j++)
            q_val[i][j] = 0.0;

    while (success < MAX_STEPS) {
        /* State index before the step: computed fresh at the start of a trial,
           otherwise reused from the previous step's resulting box.
           NOTE(review): relies on first_time2 starting at 1 (defined elsewhere)
           so `box` is assigned before it is first read — confirm. */
        if (first_time2 == 1) {
            first_time2 = 0;
            prev_state = get_box(x, x_dot, theta, theta_dot);
        } else {
            prev_state = box;
        }
        prev_action = get_action(x, x_dot, theta, theta_dot);
        cart_pole(prev_action, &x, &x_dot, &theta, &theta_dot);
        box = get_box(x, x_dot, theta, theta_dot);

        /* Record the post-step state for replay. */
        data[trial][success].x = x;
        data[trial][success].x_dot = x_dot;
        data[trial][success].theta = theta;
        data[trial][success].theta_dot = theta_dot;

        if (box == -1) {
            /* Failure: large negative reward, terminal update, then start a new trial. */
            failures = -1000.0;
            Q_update(prev_state, box, prev_action, cur_action, failures);
            reset_controller();
            x = x_dot = theta_dot = 0;
            theta = 0;
            trial++;
            /* Report theta of the last step before the failure. If the very
               first step failed there is no earlier step, so report the
               failing step instead of reading index -1. */
            printf("At %d success ,try %d trials, %f\n", success, trial,
                   data[trial-1][success > 0 ? success - 1 : 0].theta);
            success = 0;
            first_time2 = 1;
            if (trial > 499) {
                printf("失敗\n");
                return 0;
            }
        } else {
            /* Reward: 1 + cos(theta)/2 while the pole is above horizontal,
               -1 otherwise. Previously, an exact cos(theta) == 0 left the
               previous step's value in place (possibly the -1000 failure reward). */
            failures = (cos(theta) > 0) ? (1 + cos(theta) / 2) : -1.0;
            cur_action = top_action(x, x_dot, theta, theta_dot);
            Q_update(prev_state, box, prev_action, cur_action, failures);
            success++;
        }
    }

    printf("\n");
    printf("%d回目でsuccess=%d回に達しました。\n", trial, success);
    printf("何回目?");
    fflush(stdout);   /* the prompt has no newline, so flush it explicitly */
    if (scanf("%d", &trial) != 1) {
        /* No number read: replay the successful trial (1-based number
           trial + 1) instead of trial - 1, which could be data[-1]. */
        trial = trial + 1;
    }
    /* Start the replay at frame 0. The loop above leaves success == MAX_STEPS,
       a frame index main never recorded. */
    success = 0;

    glutInit(&argc, argv);
    glutInitDisplayMode(GLUT_DOUBLE | GLUT_RGB | GLUT_DEPTH);
    glutInitWindowSize(1000, 320);
    glutInitWindowPosition(100, 100);
    glutCreateWindow("robot2");
    init();
    glutDisplayFunc(display);
    glutReshapeFunc(reshape);
    glutKeyboardFunc(keyboard);
    glutMainLoop();
    return 0;
}
/*
 * Epsilon-greedy action selection for the given state.
 *
 * Side effects: stores the state's box index in the global `cur_state` and
 * the chosen action in the global `action`.
 * Returns 0 or 1. With probability about EPSILON the action is rand() % 2.
 * Otherwise it is the action with the larger q_val, with ties going to 1.
 */
int get_action(double x,
               double x_dot,
               double theta,
               double theta_dot)
{
    /* Uniform draw in [0, 1] used to decide between exploring and exploiting. */
    const double explore_draw = (double)rand() / (double)RAND_MAX;

    cur_state = get_box(x, x_dot, theta, theta_dot);

    if (!(explore_draw <= EPSILON)) {
        /* Exploit: pick action 1 unless action 0 is strictly better. */
        action = (q_val[cur_state][0] <= q_val[cur_state][1]) ? 1 : 0;
    } else {
        /* Explore: uniformly random action. */
        action = rand() % 2;
    }
    return action;
}
/*
 * Greedy action for the given (next) state, used as the bootstrap action in
 * the Q-learning target.
 *
 * Side effects: stores the state's box index in the global `top_state` and
 * the result in the global `next_action`.
 * Returns 1 when q_val[state][0] <= q_val[state][1], otherwise 0.
 *
 * Previously the function computed next_action but returned the global
 * `action` (the last epsilon-greedy choice), so the target did not use the
 * greedy value of the next state.
 */
int top_action(double x,
               double x_dot,
               double theta,
               double theta_dot)
{
    top_state = get_box(x, x_dot, theta, theta_dot);

    /* get_box returns -1 for an out-of-range state, which has no q_val row. */
    if (top_state < 0) {
        next_action = 0;
        return next_action;
    }

    next_action = (q_val[top_state][0] <= q_val[top_state][1]) ? 1 : 0;
    return next_action;
}
/*
 * One Q-learning update:
 *     Q[s][a] += ALPHA * (reward + GAMMA * V(s') - Q[s][a])
 *
 * prev_state, prev_action : the (s, a) pair being updated.
 * box                     : index of the resulting state s'; -1 marks failure
 *                           (terminal, V(s') = 0).
 * cur_action              : action whose value q_val[box][cur_action] is used
 *                           as V(s'). If it is not 0 or 1 (e.g. the -1
 *                           placeholder set by reset_controller), the greedy
 *                           value max(q_val[box][0], q_val[box][1]) is used,
 *                           instead of reading an uninitialised value.
 * failures                : the reward for this step.
 */
void Q_update(int prev_state,int box,int prev_action,int cur_action,double failures)
{
    double predicted_value;

    if (box < 0) {
        predicted_value = 0.0;
    } else if (cur_action == 0 || cur_action == 1) {
        predicted_value = q_val[box][cur_action];
    } else {
        predicted_value = (q_val[box][0] > q_val[box][1]) ? q_val[box][0] : q_val[box][1];
    }

    q_val[prev_state][prev_action]
        += ALPHA * (failures + GAMMA * predicted_value
                    - q_val[prev_state][prev_action]);
}
/*
 * Reset the controller bookkeeping at the start of a trial:
 * state indices become 0 and actions become -1 (no action chosen yet).
 */
void reset_controller(void)
{
    prev_state = 0;
    cur_state = prev_state;
    prev_action = -1;
    cur_action = prev_action;
}
お礼
ご回答ありがとうございます。 >巨大な配列に依存しないプログラムを考えたほうが良いのではないかと思います。 確かに、おっしゃる通りです。 巨大な配列に依存しないプログラムを考えてみます。