Tensorflow: setting default session using as_defau

Posted 2019-07-07 17:08

Question:

I wrote:

import tensorflow as tf
x = tf.Session()
x.as_default().__enter__()
print(tf.get_default_session()) # prints "None"

Why doesn't this make x the default session? I know I could just do it inside a "with" block, but I'm wondering why this isn't working.

Note that if I write this instead, it works:

import tensorflow as tf
with tf.Session().as_default():
    print(tf.get_default_session()) # shows <tensorflow.python.client.session.Session object at 0x114217a90> 

What do these two pieces of code do differently?

Also, if I just write

import tensorflow as tf
tf.Session()

will that create a nameless session that I have no way to close, so it will just run until I restart my kernel? Is there a way to check which sessions are currently open?

Answer 1:

Simple fix:

import tensorflow as tf
x = tf.Session().__enter__()  # Session.__enter__ makes the session the default and returns it
print(tf.get_default_session())

Result:

<tensorflow.python.client.session.Session object at 0x7f6855cbafd0>
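Why this works: in the TensorFlow 1.x source I have looked at, Session.__enter__() creates the as_default() context manager, stores it as an attribute on the session, and then enters it, so the manager lives as long as the session does. The sketch below models that in plain Python. FakeSession, _default_stack and get_default_session here are hypothetical stand-ins I made up for illustration, not TensorFlow's classes or code.

```python
import contextlib

# Hypothetical stand-in for TensorFlow's internal default-session stack.
_default_stack = []


def get_default_session():
    """Return the innermost default session, or None (toy version)."""
    return _default_stack[-1] if _default_stack else None


class FakeSession:
    """Toy model of how tf.Session handles the default session (not the real class)."""

    def __init__(self):
        self._default_cm = None

    @contextlib.contextmanager
    def as_default(self):
        _default_stack.append(self)
        try:
            yield self
        finally:
            # Cleanup: runs when the context exits or its generator is destroyed.
            _default_stack.remove(self)

    def __enter__(self):
        # Keeping the context manager on `self` stops it from being
        # garbage-collected, so the session stays on the stack.
        self._default_cm = self.as_default()
        return self._default_cm.__enter__()

    def __exit__(self, exc_type, exc, tb):
        self._default_cm.__exit__(exc_type, exc, tb)
        self._default_cm = None
        return False


x = FakeSession().__enter__()
print(get_default_session() is x)  # True
```

Here x = FakeSession().__enter__() mirrors the simple fix above: the session object holds the only reference to its context manager, so the default stays set until __exit__ is called.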

Cause:

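as_default() returns a context manager, not a session, so your code calls __enter__() on a _GeneratorContextManager object rather than on the Session itself. Entering that manager does make x the default for a moment, but nothing keeps a reference to it: the temporary is garbage-collected as soon as the statement finishes, and destroying its generator runs the cleanup that removes x from the default-session stack again. Session.__enter__() avoids this because the session keeps its own reference to the context manager.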

>>> tf.Session().as_default()
<contextlib._GeneratorContextManager object at 0x7f6820805a58>
>>> tf.Session()
<tensorflow.python.client.session.Session object at 0x7f6820805898>
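The repr above shows that as_default() hands back a contextlib generator-based context manager. The toy model below (a hypothetical stand-in, not TensorFlow's code) reproduces both behaviours in plain Python. The first call enters a temporary manager and keeps no reference to it; in CPython the manager is freed immediately, closing its generator runs the finally block, and the stack is empty again. The second call keeps a reference, which is in effect what Session.__enter__() does.

```python
import contextlib
import gc

# Hypothetical stand-in for TensorFlow's internal default-session stack.
_default_stack = []


@contextlib.contextmanager
def as_default(session):
    """Make `session` the default until the context exits (toy version)."""
    _default_stack.append(session)
    try:
        yield session
    finally:
        # Runs on a normal exit AND when the generator is closed by the GC.
        _default_stack.remove(session)


def get_default_session():
    return _default_stack[-1] if _default_stack else None


session = object()  # placeholder for a tf.Session

# 1) Like the question: enter a temporary manager and keep no reference to it.
as_default(session).__enter__()
gc.collect()  # CPython has already freed it; this just makes the point explicit
print(get_default_session())  # None

# 2) Keep a reference: the session stays the default until we exit.
cm = as_default(session)
cm.__enter__()
print(get_default_session() is session)  # True
cm.__exit__(None, None, None)
print(get_default_session())  # None
```

By the same reasoning, cm = x.as_default() followed by cm.__enter__() should also work in TensorFlow, as long as cm stays alive.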

Update

To answer your (initially perplexing) follow-up question:

What you are doing with the with statement is entering and exiting the context manager. That sets and then unsets the default session, but it does not open or close the session itself (this confused me at first too; I only saw it after some experimentation). Try this code to see it in action:

>>> print(tf.get_default_session())
None

>>> x = tf.Session()
>>> print(tf.get_default_session())
None

>>> with x.as_default():
...     print(tf.get_default_session())
... 
<tensorflow.python.client.session.Session object at 0x7f09eb9fb550>
>>> print(x)
<tensorflow.python.client.session.Session object at 0x7f09eb9fb550>

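At the end, x is still the same Session object and was never closed: nothing in as_default() calls close() (print(x) on its own cannot show whether a session is closed). Along the way, the default session is set inside the with block and unset again when the block exits, as expected.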

Using the session itself in a standard with statement (with tf.Session() as sess:) both sets/unsets the default session and closes the session when the block exits.

>>> with tf.Session() as sess:
...     print(tf.get_default_session())
... 
<tensorflow.python.client.session.Session object at 0x7f09eb9fbbe0>
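A plain-Python sketch of the difference between the two with forms. FakeSession, _default_stack and the closed flag are hypothetical stand-ins, not TensorFlow's classes: as_default() only pushes and pops the default, while the session's own __exit__ additionally calls close(), mirroring what is described above for tf.Session.

```python
import contextlib

# Hypothetical stand-in for TensorFlow's internal default-session stack.
_default_stack = []


def get_default_session():
    return _default_stack[-1] if _default_stack else None


class FakeSession:
    """Toy model: as_default() only sets the default; `with sess:` also closes."""

    def __init__(self):
        self.closed = False
        self._default_cm = None

    @contextlib.contextmanager
    def as_default(self):
        _default_stack.append(self)
        try:
            yield self
        finally:
            _default_stack.remove(self)

    def close(self):
        self.closed = True

    def __enter__(self):
        self._default_cm = self.as_default()
        return self._default_cm.__enter__()

    def __exit__(self, exc_type, exc, tb):
        self._default_cm.__exit__(exc_type, exc, tb)
        self._default_cm = None
        self.close()  # the step that as_default() never performs
        return False


x = FakeSession()
with x.as_default():
    print(get_default_session() is x)  # True
print(get_default_session(), x.closed)  # None False

with FakeSession() as sess:
    print(get_default_session() is sess)  # True
print(get_default_session(), sess.closed)  # None True
```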