1717Use this script to compress images with pre-trained models as published. See the
1818'models' subcommand for a list of available models.
1919
20- Currently, this script requires tensorflow-compression v1.3 .
20+ This script requires TFC v2 (`pip install tensorflow-compression==2.*`) .
2121"""
2222
2323import argparse
2424import os
2525import sys
2626import urllib
27-
2827from absl import app
2928from absl .flags import argparse_flags
30- import tensorflow .compat .v1 as tf
31-
29+ import tensorflow as tf
3230import tensorflow_compression as tfc # pylint:disable=unused-import
3331
32+
3433# Default URL to fetch metagraphs from.
3534URL_PREFIX = "https://storage.googleapis.com/tensorflow_compression/metagraphs"
3635# Default location to store cached metagraphs.
3736METAGRAPH_CACHE = "/tmp/tfc_metagraphs"
3837
3938
4039def read_png (filename ):
41- """Creates graph to load a PNG image file."""
40+ """Loads a PNG image file."""
4241 string = tf .io .read_file (filename )
4342 image = tf .image .decode_image (string )
44- image = tf .expand_dims (image , 0 )
45- return image
43+ return tf .expand_dims (image , 0 )
4644
4745
4846def write_png (filename , image ):
49- """Creates graph to write a PNG image file."""
47+ """Writes a PNG image file."""
5048 image = tf .squeeze (image , 0 )
5149 if image .dtype .is_floating :
5250 image = tf .round (image )
5351 if image .dtype != tf .uint8 :
5452 image = tf .saturate_cast (image , tf .uint8 )
5553 string = tf .image .encode_png (image )
56- return tf .io .write_file (filename , string )
54+ tf .io .write_file (filename , string )
5755
5856
5957def load_cached (filename ):
@@ -63,9 +61,9 @@ def load_cached(filename):
6361 with tf .io .gfile .GFile (pathname , "rb" ) as f :
6462 string = f .read ()
6563 except tf .errors .NotFoundError :
66- url = URL_PREFIX + "/" + filename
64+ url = f"{ URL_PREFIX } /{ filename } "
65+ request = urllib .request .urlopen (url )
6766 try :
68- request = urllib .request .urlopen (url )
6967 string = request .read ()
7068 finally :
7169 request .close ()
@@ -75,50 +73,29 @@ def load_cached(filename):
7573 return string
7674
7775
78- def import_metagraph (model ):
79- """Imports a trained model metagraph into the current graph ."""
76+ def instantiate_model_signature (model , signature ):
77+ """Imports a trained model and returns one of its signatures as a function ."""
8078 string = load_cached (model + ".metagraph" )
81- metagraph = tf .MetaGraphDef ()
79+ metagraph = tf .compat . v1 . MetaGraphDef ()
8280 metagraph .ParseFromString (string )
83- tf .train .import_meta_graph (metagraph )
84- return metagraph .signature_def
85-
86-
87- def instantiate_signature (signature_def ):
88- """Fetches tensors defined in a signature from the graph."""
89- graph = tf .get_default_graph ()
90- inputs = {
91- k : graph .get_tensor_by_name (v .name )
92- for k , v in signature_def .inputs .items ()
93- }
94- outputs = {
95- k : graph .get_tensor_by_name (v .name )
96- for k , v in signature_def .outputs .items ()
97- }
98- return inputs , outputs
81+ wrapped_import = tf .compat .v1 .wrap_function (
82+ lambda : tf .compat .v1 .train .import_meta_graph (metagraph ), [])
83+ graph = wrapped_import .graph
84+ inputs = metagraph .signature_def [signature ].inputs
85+ outputs = metagraph .signature_def [signature ].outputs
86+ inputs = [graph .as_graph_element (inputs [k ].name ) for k in sorted (inputs )]
87+ outputs = [graph .as_graph_element (outputs [k ].name ) for k in sorted (outputs )]
88+ return wrapped_import .prune (inputs , outputs )
9989
10090
10191def compress_image (model , input_image ):
102- """Compresses an image array into a bitstring."""
103- with tf .Graph ().as_default ():
104- # Load model metagraph.
105- signature_defs = import_metagraph (model )
106- inputs , outputs = instantiate_signature (signature_defs ["sender" ])
107-
108- # Just one input tensor.
109- inputs = inputs ["input_image" ]
110- # Multiple output tensors, ordered alphabetically, without names.
111- outputs = [outputs [k ] for k in sorted (outputs ) if k .startswith ("channel:" )]
112-
113- # Run encoder.
114- with tf .Session () as sess :
115- arrays = sess .run (outputs , feed_dict = {inputs : input_image })
116-
117- # Pack data into bitstring.
118- packed = tfc .PackedTensors ()
119- packed .model = model
120- packed .pack (outputs , arrays )
121- return packed .string
92+ """Compresses an image tensor into a bitstring."""
93+ sender = instantiate_model_signature (model , "sender" )
94+ tensors = sender (input_image )
95+ packed = tfc .PackedTensors ()
96+ packed .model = model
97+ packed .pack (tensors )
98+ return packed .string
12299
123100
124101def compress (model , input_file , output_file , target_bpp = None , bpp_strict = False ):
@@ -127,10 +104,8 @@ def compress(model, input_file, output_file, target_bpp=None, bpp_strict=False):
127104 output_file = input_file + ".tfci"
128105
129106 # Load image.
130- with tf .Graph ().as_default ():
131- with tf .Session () as sess :
132- input_image = sess .run (read_png (input_file ))
133- num_pixels = input_image .shape [- 2 ] * input_image .shape [- 3 ]
107+ input_image = read_png (input_file )
108+ num_pixels = input_image .shape [- 2 ] * input_image .shape [- 3 ]
134109
135110 if not target_bpp :
136111 # Just compress with a specific model.
@@ -175,27 +150,12 @@ def decompress(input_file, output_file):
175150 """Decompresses a TFCI file and writes a PNG file."""
176151 if not output_file :
177152 output_file = input_file + ".png"
178-
179- with tf .Graph ().as_default ():
180- # Unserialize packed data from disk.
181- with tf .io .gfile .GFile (input_file , "rb" ) as f :
182- packed = tfc .PackedTensors (f .read ())
183-
184- # Load model metagraph.
185- signature_defs = import_metagraph (packed .model )
186- inputs , outputs = instantiate_signature (signature_defs ["receiver" ])
187-
188- # Multiple input tensors, ordered alphabetically, without names.
189- inputs = [inputs [k ] for k in sorted (inputs ) if k .startswith ("channel:" )]
190- # Just one output operation.
191- outputs = write_png (output_file , outputs ["output_image" ])
192-
193- # Unpack data.
194- arrays = packed .unpack (inputs )
195-
196- # Run decoder.
197- with tf .Session () as sess :
198- sess .run (outputs , feed_dict = dict (zip (inputs , arrays )))
153+ with tf .io .gfile .GFile (input_file , "rb" ) as f :
154+ packed = tfc .PackedTensors (f .read ())
155+ receiver = instantiate_model_signature (packed .model , "receiver" )
156+ tensors = packed .unpack ([t .dtype for t in receiver .inputs ])
157+ output_image , = receiver (* tensors )
158+ write_png (output_file , output_image )
199159
200160
201161def list_models ():
0 commit comments