본문 바로가기
IT/머신러닝

[section_2_lab] 머신러닝 Linear Regression 실습

by 빨강자몽 2018. 6. 1.

Cost(w,b)가 최소가 되는 w,b 구하기(상수 이용) 

  • x, y가 각 각 1, 2, 3(상수)일때 Cost(w,b)가 최소가 되는 w,b를 구한다.
  • tf.placeholder을 변수로, tf.Variable을 trainable로 이해하면 좋다.
# -*- coding: utf-8 -*-
import tensorflow as tf

# X and Y data : 넣고자 하는 데이터 set
x_train = [1, 2, 3]
y_train = [1, 2, 3]

# 최종적으로 구하고자 하는 w,b 변수
W = tf.Variable(tf.random_normal([1]), name='weight')
b = tf.Variable(tf.random_normal([1]), name='bias')

# Our hypothesis XW+b
hypothesis = x_train * W + b

# cost/loss function : 각 점에서의 cost/loss 정의
cost = tf.reduce_mean(tf.square(hypothesis - y_train))

# Minimize cost(w,b)가 최소가되는 w,b를 구한다.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)

# sess을 생성한다.
sess = tf.Session()
# variable(trainable)을 생성한다.
sess.run(tf.global_variables_initializer())

# Fit the line : 2000번 반복시킨다. -> 100번에 한 번씩 현재 값들을 출력한다.
for step in range(2001):
	sess.run(train)
	if step % 100 == 0:
		print(step, sess.run(cost), sess.run(W), sess.run(b))


Cost(w,b)가 최소가 되는 w,b 구하기(변수 이용) 

# -*- coding: utf-8 -*-
import tensorflow as tf

# 구하고자 하는 variable(trainable) w,b
W = tf.Variable(tf.random_normal([1]), name='weight')
b = tf.Variable(tf.random_normal([1]), name='bias')
# x,y값을 임의로 넣기위한 placeholder(변수)
X = tf.placeholder(tf.float32, shape=[None])
Y = tf.placeholder(tf.float32, shape=[None])

# Our hypothesis XW+b
hypothesis = X * W + b
# cost/loss function : 각 점에서의 cost/loss 정의
cost = tf.reduce_mean(tf.square(hypothesis - Y))
# Minimize cost(w,b)가 최소가되는 w,b를 구한다.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)

# sess을 생성한다.
sess = tf.Session()
# variable(trainable)을 생성한다.
sess.run(tf.global_variables_initializer())

# 2000번 반복시킨다. -> 100번에 한 번씩 현재 값들을 출력한다.
for step in range(2001):
   cost_val, W_val, b_val, _ = sess.run([cost, W, b, train],
       feed_dict={X: [1, 2, 3, 4, 5],
                  Y: [2.1, 3.1, 4.1, 5.1, 6.1]})
   if step % 20 == 0:
       print(step, cost_val, W_val, b_val)