Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Before you begin, you'll of course need the Go programming language installed. Y

## Usage

`gococo -dir=<model folder> -jpg=<input.jpg> [-out=<output.jpg>] [-labels=<labels.txt>]`
`gococo -dir=<model folder> -jpg=<input-or-url.jpg> [-out=<output.jpg>] [-labels=<labels.txt>]`

## Using Pre-Trained Models with TensorFlow in Go

Expand Down Expand Up @@ -65,7 +65,7 @@ Unfortunately none of these questions have easy or well-documented answers!

### Step 1: Identify the input and output nodes of the graph

For my [GopherCon demo](https://www.youtube.com/watch?v=oiorteQg9n0), I joked that I looked through the protocol buffer directly to find the names of the TensorFlow nodes for my model, and that if you were smart, you might print out the names from Python before you exported it, or dump them to a file on disk.
For my [GopherCon demo](https://www.youtube.com/watch?v=oiorteQg9n0), I joked that I looked through the protocol buffer directly to find the names of the TensorFlow nodes for my model, and that if you were smart, you might print out the names from Python before you exported it, or dump them to a file on disk.

Shockingly, this is still one possibly effective strategy here. Combined with some Googling around the subject, digging through source code, you will find that the nodes for this model are as follows:

Expand All @@ -75,7 +75,7 @@ Shockingly, this is still one possibly effective strategy here. Combined with so
| detection_boxes | Output | [?][4] | Array of boxes for each detected object in the format [yMin, xMin, yMax, xMax] |
| detection_scores | Output | [?] | Array of probability scores for each detected object between 0..1 |
| detection_classes | Output | [?] | Array of object class indices for each object detected based on COCO objects |
| num_detections | Output | [1] | Number of detections |
| num_detections | Output | [1] | Number of detections |

I would suggest that it would be best practice when publishing models to include this information as part of the documentation. Once you get the names of the nodes associated with the input/output, you can use the `Shape` method to display the shape of these inputs. In our case the input shape is similar to the one used in the Inception example referred to earlier.

Expand Down Expand Up @@ -131,7 +131,7 @@ There are a couple of notes to keep in mind when parsing the results:

### Step 4: Visualizing the output

Printing out a list of probabilities, class indices and box dimensions isn't really that interesting so we'll also extend our little CLI to output a version of the image with the results of the model rendered into it. Like many of the existing examples, we'll draw bounding boxes and label them with the confidence percentage.
Printing out a list of probabilities, class indices and box dimensions isn't really that interesting so we'll also extend our little CLI to output a version of the image with the results of the model rendered into it. Like many of the existing examples, we'll draw bounding boxes and label them with the confidence percentage.

We're just going to use the built-in `image` package to do some basic rendering, as well as the built-in font for rendering the labels.

Expand Down
Binary file added bin/macOS/gococo
Binary file not shown.
204 changes: 142 additions & 62 deletions gococo.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ package main
import (
"bufio"
"bytes"
"encoding/json"
"flag"
"fmt"
"image"
Expand All @@ -35,8 +36,11 @@ import (
"image/jpeg"
"io/ioutil"
"log"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"

"golang.org/x/image/colornames"

Expand Down Expand Up @@ -78,43 +82,56 @@ func Rect(img *image.RGBA, x1, y1, x2, y2, width int, col color.Color) {
}

// TENSOR UTILITY FUNCTIONS
func makeTensorFromImage(filename string) (*tf.Tensor, image.Image, error) {
b, err := ioutil.ReadFile(filename)
func makeTensorFromImage(filenameOrUrl string) (tensorImage, error) {
var b []byte
var err error
if strings.Index(filenameOrUrl, "http") == 0 {
resp, err := http.Get(filenameOrUrl)
if err != nil {
return tensorImage{}, err
}
defer resp.Body.Close()
b, err = ioutil.ReadAll(resp.Body)
} else {
b, err = ioutil.ReadFile(filenameOrUrl)
}

if err != nil {
return nil, nil, err
return tensorImage{}, err
}

r := bytes.NewReader(b)
img, _, err := image.Decode(r)

if err != nil {
return nil, nil, err
return tensorImage{}, err
}

// DecodeJpeg uses a scalar String-valued tensor as input.
tensor, err := tf.NewTensor(string(b))
if err != nil {
return nil, nil, err
return tensorImage{}, err
}
// Creates a tensorflow graph to decode the jpeg image
graph, input, output, err := decodeJpegGraph()
if err != nil {
return nil, nil, err
return tensorImage{}, err
}
// Execute that graph to decode this one image
session, err := tf.NewSession(graph, nil)
if err != nil {
return nil, nil, err
return tensorImage{}, err
}
defer session.Close()
normalized, err := session.Run(
map[tf.Output]*tf.Tensor{input: tensor},
[]tf.Output{output},
nil)
nil,
)
if err != nil {
return nil, nil, err
return tensorImage{}, err
}
return normalized[0], img, nil

return tensorImage{Tensor: normalized[0], Image: &img, Input: filenameOrUrl}, nil
}

func decodeJpegGraph() (graph *tf.Graph, input, output tf.Output, err error) {
Expand Down Expand Up @@ -166,52 +183,19 @@ func addLabel(img *image.RGBA, x, y, class int, label string) {
d.DrawString(label)
}

func main() {
// Parse flags
modeldir := flag.String("dir", "", "Directory containing COCO trained model files. Assumes model file is called frozen_inference_graph.pb")
jpgfile := flag.String("jpg", "", "Path of a JPG image to use for input")
outjpg := flag.String("out", "output.jpg", "Path of output JPG for displaying labels. Default is output.jpg")
labelfile := flag.String("labels", "labels.txt", "Path to file of COCO labels, one per line")
flag.Parse()
if *modeldir == "" || *jpgfile == "" {
flag.Usage()
return
}

// Load the labels
loadLabels(*labelfile)

// Load a frozen graph to use for queries

modelpath := filepath.Join(*modeldir, "frozen_inference_graph.pb")
model, err := ioutil.ReadFile(modelpath)
if err != nil {
log.Fatal(err)
}

// Construct an in-memory graph from the serialized form.
graph := tf.NewGraph()
if err := graph.Import(model, ""); err != nil {
log.Fatal(err)
}

// Create a session for inference over graph.
session, err := tf.NewSession(graph, nil)
if err != nil {
log.Fatal(err)
}
defer session.Close()
type tensorImage struct {
Tensor *tf.Tensor
Image *image.Image
Input string
}

// DecodeJpeg uses a scalar String-valued tensor as input.
tensor, i, err := makeTensorFromImage(*jpgfile)
if err != nil {
log.Fatal(err)
}
func predictImage(timg tensorImage, outputDir string, outputCount int, textOnly bool, graph *tf.Graph, session *tf.Session) (string, error) {
var buffer bytes.Buffer

// Transform the decoded YCbCr JPG image into RGBA
b := i.Bounds()
b := (*timg.Image).Bounds()
img := image.NewRGBA(b)
draw.Draw(img, b, i, b.Min, draw.Src)
draw.Draw(img, b, (*timg.Image), b.Min, draw.Src)

// Get all the input and output operations
inputop := graph.Operation("image_tensor")
Expand All @@ -224,17 +208,18 @@ func main() {
// Execute COCO Graph
output, err := session.Run(
map[tf.Output]*tf.Tensor{
inputop.Output(0): tensor,
inputop.Output(0): timg.Tensor,
},
[]tf.Output{
o1.Output(0),
o2.Output(0),
o3.Output(0),
o4.Output(0),
},
nil)
nil,
)
if err != nil {
log.Fatal(err)
return "", err
}

// Outputs
Expand All @@ -252,24 +237,119 @@ func main() {
y1 := float32(img.Bounds().Max.Y) * boxes[curObj][0]
y2 := float32(img.Bounds().Max.Y) * boxes[curObj][2]

Rect(img, int(x1), int(y1), int(x2), int(y2), 4, colornames.Map[colornames.Names[int(classes[curObj])]])
addLabel(img, int(x1), int(y1), int(classes[curObj]), getLabel(curObj, probabilities, classes))
labelString := getLabel(curObj, probabilities, classes)
if textOnly {
buffer.WriteString(labelString)
buffer.WriteString("\n")
} else {
Rect(img, int(x1), int(y1), int(x2), int(y2), 4, colornames.Map[colornames.Names[int(classes[curObj])]])
addLabel(img, int(x1), int(y1), int(classes[curObj]), labelString)
}

curObj++
}

// Output JPG file
outfile, err := os.Create(*outjpg)
if !textOnly {
outfile, err := os.Create(outputDir + "/output-" + strconv.Itoa(outputCount) + ".jpg")
if err != nil {
return "", err
}

var opt jpeg.Options
opt.Quality = 80

err = jpeg.Encode(outfile, img, &opt)
if err != nil {
return "", err
}
}

return buffer.String(), nil
}

func main() {
results := map[string]string{}

// Parse flags
modeldir := flag.String("dir", "", "Directory containing COCO trained model files. Assumes model file is called frozen_inference_graph.pb")
textonly := flag.Bool("textonly", false, "Output text labels instead of the image")
outdir := flag.String("outdir", "~", "Path of output directory to put JPG in for displaying labels.")
labelfile := flag.String("labels", "labels.txt", "Path to file of COCO labels, one per line")
flag.Parse()
if *modeldir == "" {
flag.Usage()
return
}

jpgFiles := flag.Args()
if len(jpgFiles) < 1 {
fmt.Println("Please specify one or more jpg urls/files to use as inputs")
return
}

// Load the labels
loadLabels(*labelfile)

// Load a frozen graph to use for queries
modelpath := filepath.Join(*modeldir, "frozen_inference_graph.pb")
model, err := ioutil.ReadFile(modelpath)
if err != nil {
log.Fatal(err)
}

// Construct an in-memory graph from the serialized form.
graph := tf.NewGraph()
if err := graph.Import(model, ""); err != nil {
log.Fatal(err)
}

// Create a session for inference over graph.
session, err := tf.NewSession(graph, nil)
if err != nil {
log.Fatal(err)
}
defer session.Close()

// load all images at once
tensorc, errc := make(chan tensorImage), make(chan error)
for _, jpgFile := range jpgFiles {
go func(jpf string) {
timg, err := makeTensorFromImage(jpf)
if err != nil {
errc <- err
return
}
tensorc <- timg
}(jpgFile)
}

var opt jpeg.Options
tensorImgs := map[string]tensorImage{}

opt.Quality = 80
for i := 0; i < len(jpgFiles); i++ {
select {
case timg := <-tensorc:
tensorImgs[timg.Input] = timg
case err := <-errc:
fmt.Errorf("ERROR %v\n", err)
}
}

// run sessions in serial
i := 0
for _, timg := range tensorImgs {
prediction, err := predictImage(timg, *outdir, i, *textonly, graph, session)
if err != nil {
log.Fatal(err)
}
results[timg.Input] = prediction
i++
}

err = jpeg.Encode(outfile, img, &opt)
// print output in json
output, err := json.Marshal(results)
if err != nil {
log.Fatal(err)
}
fmt.Println(string(output))
}