👨‍💻 Deep Learning/Deep learning

Tensorflow? 텐서플로우?

Data_novice 2024. 8. 1. 17:10

Tensorflow

 

텐서플로우는 Google Brain 팀에서 개발한 오픈 소스 프레임워크입니다. 복잡한 데이터 흐름 그래프(Data Flow Graph)를 기반으로 딥러닝 모델을 구축하고 훈련할 수 있도록 하는 프레임워크로 다양한 기계 학습 작업을 지원하고 여러 분야에서 사용되고 있습니다.

 

아마 딥러닝을 하지 않더라도 머신러닝에 관심이 있는 분이라면 한번쯤 들어보셨을 것 같습니다. 저도 계속해서 듣기만했지 처음으로 써보려고 합니다.

 

TensorFlow 설치

pip install --upgrade tensorflow

# 버전 확인
import tensorflow as tf
tf.__version__

 

 

개념

 

1. 텐서(Tensor) : 텐서플로우에서 데이터를 나타내는 기본 단위라고 볼 수 있습니다. 다차원 배열로 이루어져 있으며 숫자의 배열로 생각할 수 있습니다. 또한 텐서는 데이터를 저장하고 연산하는 데 사용됩니다.

 

또한 텐서는 Shape, Type, Rank라는 속성을 가집니다.

 

속성 설명 코드 예제
Shape 텐서의 각 차원의 크기를 나타내는 튜플 (3,), (2,2), (3,4,5) ...  tf.constant([1,2,3]).shape
> (3,)

tf.constant([[1,2,3], [4,5,6]]).shape
> (2, 3)
Rank 텐서의 차원 수를 나타냄 0 : 스칼라, 1 : 벡터, 2 : 행렬, 3 : 3D 텐서 ... tf.rank(tf.constant([[1,2], [3,4]])
> 2
Type 텐서의 데이터 타입을 나타냄 tf.int32
tf.float32
tf.string
tf.constant(3.0).dtype
> tf.float32

 

2. 데이터 흐름 그래프(Data Flow Graph) : 텐서플로우는 데이터 흐름 그래프를 사용하여 계산을 수행합니다. 노드(Node)는 연산(Operation)을 나타내고, 엣지(Edge)는 데이터(tensor)의 흐름을 나타냅니다. 해당 구조는 복잡한 계산을 효율적으로 표현하고 실행할 수 있도록 도와줍니다.

 

그럼 코드로 한번 살펴봅시다.

 

상수 텐서

# 실제로는 문자열이 들어간 노드를 만드는것

hello = tf.constant("Hello, TensorFlow!")

# run the op and get result
print(hello.numpy()) # b'String'일 때 b는 Bytes literals라는 뜻

> b'Hello, TensorFlow!'

 

위와 같이 constant는 말 그대로 상수 텐서를 생성하는 함수 입니다. 상수 텐서는 초기화 시점에 값을 설정하고, 값이 변경되지 않는 텐서입니다. 문자열 뿐 아닌 수치형 자료도 가능합니다.

 

또한 아래와 같이 연산도 가능합니다.

 

# 결과값이 바로 나오는게 아님
node1 = tf.constant(3.0, tf.float32)
node2 = tf.constant(4.0)
node3 = tf.add(node1, node2)

print("node1:", node1, "node2:", node2)
print("node3:", node3)
node1: tf.Tensor(3.0, shape=(), dtype=float32) node2: tf.Tensor(4.0, shape=(), dtype=float32)
node3: tf.Tensor(7.0, shape=(), dtype=float32)

 

단, 위와 같이 바로 출력값으로 나오는 것이 아닙니다. Tensorflow 1.x 버전에서는 session을 따로 만들어주어 진행해야 했지만, 2.x버전 부터는 굳이 session이 필요하지 않고 단순히 .numpy()를 가지고 출력이 가능합니다.

 

print("(node1, node2) :", [node1.numpy(), node2.numpy()])
print("(node3) :", node3.numpy())
(node1, node2) : [3.0, 4.0]
(node3) : 7.0