Augmented RealityMachine LearningUnity 3D

สร้าง Machine Learning ด้วย TensorFlow และ Unity ทำ Object Detection

บทเรียนการพัฒนาแอปพลิเคชันสำหรับ Object Detection ด้วย ML ยอดฮิตอย่าง TensorFlow

มาลองทำ Object Detection หรือการตรวจจับวัตถุว่าเป็นอะไร โดยให้ Machine Learning ได้ประมวลผลทำกับ Unity ร่วมกับ TensorFlow API for .NET หรือ TensorFlowSharp

API การตรวจจับวัตถุ หรือ Object Detection ของ TensorFlow นับว่าเป็นเครื่องมือที่ทรงพลังตัวหนึ่งที่ทุกคนสามารถเปิดใช้งานการประมวลผล Machie Learning ได้อย่างรวดเร็ว โดยเฉพาะผู้ที่ไม่มีพื้นฐานการเรียนรู้ของการทำงานด้าน AI และ Machine Learning เพื่อสร้างและปรับใช้ซอฟต์แวร์ของพวกประมวลผล หรือจดจำรูปภาพให้ทำงานได้อย่างรวดเร็ว และเป็นประโยชน์

บทเรียนนี้เราจำเป็นต้องใช้ Library Assets ของ Unity ตัวหนึ่งที่ชื่อว่า TensorFlowSharp

หรือแบบที่พร้อมใช้งานเลยถ้าเข้าใจแล้ว ดาวน์โหลด (สำหรับสายขี้เกียจ) เป็น Template Project นั้นทางผมได้เตรียมให้แล้วที่ Github ไปดาวน์โหลดมาได้เลย

เริ่มต้นพัฒนา

ทำการเปิด Project ของผมที่เราได้ git clone https://github.com/banyapon/TensorFowUnitySample ลงมาบน Unity หลังจากนั้นติดตั้ง ML Kit ของ TensorFlow ด้วยการ Import Assets ของ TensotFlowSharp Unity Package ที่ดาวน์โหลดมาที่เมนู

Assets->Import Package->Custom Package

รอจนกว่าระบบจะประมวลผลเสร็จก็เรียบร้อย

ขั้นตอนต่อมาเราจำเป็นจะต้องเปลี่ยนแพลตฟอร์มเป็น Android หลังจากนั้น ตั้งค่า Build Setting -> Player Setting โดย

ใน Other Setting ให้เราไปเพิ่ม ENABLE_TENSORFLOW ในช่อง Scripting Define Symbols

เมื่อเสร็จแล้วไปที่ Project -> Assets หา Folder ที่ชื่อว่า “ML-Agents” เลือก Plugins->Android เราจะเห็นไฟล์นามสกุล .dll มากมายปรากฏอยู่ให้ทำการลบ .dll ทุกไฟล์ เหลือไว้เพียง

  • Java.Interop
  • Mono.Android
  • System.Linq
  • TensorFlowSharp
  • TensorFlowSharp.Android

ทีนี้สังเกตการทำงานของ Unity ของเราคือ Object Detect เราจะมีการบังคับเปิดกล้องมือถือผ่าน C# ที่ชื่อว่า PhoneCamera.cs ไปวางใน MainCamera ซึ่งจะมี Mode ให้เราเลือกคือ Detector จะไปเรียก Class ของ Detector.cs อีกที

using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.UI;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using TFClassify;
using System.Diagnostics;
using System.Threading.Tasks;
using Debug = UnityEngine.Debug;
using TensorFlow;


public enum Mode
{
    Detect,
    Classify,
}

public class PhoneCamera : MonoBehaviour
{
    private const int detectImageSize = 300;
    private const int classifyImageSize = 224;

    private static Texture2D boxOutlineTexture;
    private static GUIStyle labelStyle;
    
    private bool camAvailable;
    private WebCamTexture backCamera;
    private Texture defaultBackground;
    
    private Classifier classifier;
    private Detector detector;

    private List<BoxOutline> boxOutlines;
    private Vector2 backgroundSize;
    private Vector2 backgroundOrigin;


    public Mode mode;
    public RawImage background;
    public AspectRatioFitter fitter;
    public TextAsset modelFile;
    public TextAsset labelsFile;
    public Text uiText;
    

    private void Start()
    {
        LoadWorker();

        defaultBackground = background.texture;
        WebCamDevice[] devices = WebCamTexture.devices;

        if(devices.Length == 0)
        {
            this.uiText.text = "No camera detected";
            camAvailable = false;

            return;
        }

        for(int i = 0; i < devices.Length; i++)
        {
            if(!devices[i].isFrontFacing)
            {
                this.backCamera = new WebCamTexture(devices[i].name, Screen.width, Screen.height);
            }
        }

        if(backCamera == null)
        {
            this.uiText.text = "Unable to find back camera";
            
            return;
        }


        this.backCamera.Play();
        this.background.texture = this.backCamera;
        this.backgroundSize = new Vector2(this.backCamera.width, this.backCamera.height);
        camAvailable = true;

        string func = mode == Mode.Classify ? nameof(TFClassify) : nameof(TFDetect);
        InvokeRepeating(func, 1f, 1f);
    }


    private void Update()
    {
        if(!this.camAvailable)
        {
            return;
        }

        float ratio = (float)backCamera.width / (float)backCamera.height;
        fitter.aspectRatio = ratio;

        float scaleY = backCamera.videoVerticallyMirrored ? -1f : 1f;
        background.rectTransform.localScale = new Vector3(1f, scaleY, 1f);

        int orient = -backCamera.videoRotationAngle;
        background.rectTransform.localEulerAngles = new Vector3(0, 0, orient);
    }


    public void OnGUI()
    {
        if (this.boxOutlines != null && this.boxOutlines.Any())
        {
            foreach (var outline in this.boxOutlines)
            {
                DrawBoxOutline(outline);
            }
        }
    }


    private void LoadWorker()
    {
        try
        {
            if (mode == Mode.Classify)
            {
                LoadClassifier();
            }
            else
            {
                LoadDetector();
            }
        }
        catch (TFException ex)
        {
            if (ex.Message.EndsWith("is up to date with your GraphDef-generating binary.)."))
            {
                this.uiText.text = "Error: TFException. Make sure you model trained with same version of TensorFlow as in Unity plugin.";
            }
            
            throw;
        }
    }


    private void LoadClassifier()
    {
        this.classifier = new Classifier(
            this.modelFile.bytes,
            Regex.Split(this.labelsFile.text, "\n|\r|\r\n")
                .Where(s => !String.IsNullOrEmpty(s)).ToArray(),
            classifyImageSize);
    }


    private void LoadDetector()
    {
        this.detector = new Detector(
            this.modelFile.bytes,
            Regex.Split(this.labelsFile.text, "\n|\r|\r\n")
                .Where(s => !String.IsNullOrEmpty(s)).ToArray(),
            detectImageSize);
    }


    private async void TFClassify()
    {
        var snap = TakeTextureSnap();
        var scaled = Scale(snap, classifyImageSize);
        var rotated = await RotateAsync(scaled.GetPixels32(), scaled.width, scaled.height);

        try
        {
            var probabilities = await this.classifier.ClassifyAsync(rotated);
            this.uiText.text = String.Empty;

            for(int i = 0; i < 3; i++)
            {
                this.uiText.text += probabilities[i].Key + ": " + String.Format("{0:0.000}%", probabilities[i].Value) + "\n";
            }
        }
        catch (NullReferenceException)
        {
            this.uiText.text = "Error: NullReferenceException. Make sure you set correct INPUT_NAME and OUTPUT_NAME";
        }
        finally
        {
            Destroy(snap);
            Destroy(scaled);
        }
    }


    private async void TFDetect()
    {
        UpdateBackgroundOrigin();

        var snap = TakeTextureSnap();
        var scaled = Scale(snap, detectImageSize);
        var rotated = await RotateAsync(scaled.GetPixels32(), scaled.width, scaled.height);
        this.boxOutlines = await this.detector.DetectAsync(rotated);

        Destroy(snap);
        Destroy(scaled);
    }

    
    private void UpdateBackgroundOrigin()
    {
        var backgroundPosition = this.background.transform.position;
        this.backgroundOrigin = new Vector2(backgroundPosition.x - this.backgroundSize.x / 2, 
                                            backgroundPosition.y - this.backgroundSize.y / 2);
    }


    private void DrawBoxOutline(BoxOutline outline)
    {
        var xMin = outline.XMin * this.backgroundSize.x + this.backgroundOrigin.x;
        var xMax = outline.XMax * this.backgroundSize.x + this.backgroundOrigin.x;
        var yMin = outline.YMin * this.backgroundSize.y + this.backgroundOrigin.y;
        var yMax = outline.YMax * this.backgroundSize.y + this.backgroundOrigin.y;

        DrawRectangle(new Rect(xMin, yMin, xMax - xMin, yMax - yMin), 4, Color.green);
        DrawLabel(new Rect(xMin + 10, yMin + 10, 200, 20), $"{outline.Label}: {(int)(outline.Score * 100)}%");
    }


    public static void DrawRectangle(Rect area, int frameWidth, Color color)
    {
        // Create a one pixel texture with the right color
        if (boxOutlineTexture == null)
        {
            var texture = new Texture2D(1, 1);
            texture.SetPixel(0, 0, color);
            texture.Apply();
            boxOutlineTexture = texture;
        }
        
        Rect lineArea = area;
        lineArea.height = frameWidth;
        GUI.DrawTexture(lineArea, boxOutlineTexture); // Top line

        lineArea.y = area.yMax - frameWidth; 
        GUI.DrawTexture(lineArea, boxOutlineTexture); // Bottom line

        lineArea = area;
        lineArea.width = frameWidth;
        GUI.DrawTexture(lineArea, boxOutlineTexture); // Left line

        lineArea.x = area.xMax - frameWidth;
        GUI.DrawTexture(lineArea, boxOutlineTexture); // Right line
    }


    private static void DrawLabel(Rect position, string text)
    {
        if (labelStyle == null)
        {
            var style = new GUIStyle();
            style.fontSize = 50;
            style.normal.textColor = Color.red;
            labelStyle = style;
        }

        GUI.Label(position, text, labelStyle);
    }


    private Texture2D TakeTextureSnap()
    {
        var smallest = backCamera.width < backCamera.height ?
            backCamera.width : backCamera.height;
        var snap = TextureTools.CropWithRect(backCamera,
             new Rect(0, 0, smallest, smallest),
            TextureTools.RectOptions.Center, 0, 0);

        return snap;
    }


    private Texture2D Scale(Texture2D texture, int imageSize)
    {
        var scaled = TextureTools.scaled(texture, imageSize, imageSize, FilterMode.Bilinear);
        
        return scaled;
    }


    private Task<Color32[]> RotateAsync(Color32[] pixels, int width, int height)
    {
        return Task.Run(() =>
        {
            return TextureTools.RotateImageMatrix(
                    pixels, width, height, -90);
        });
    }

    private Task<Texture2D> RotateAsync(Texture2D texture)
    {
        return Task.Run(() =>
        {
            return TextureTools.RotateTexture(texture, -90);
        });
    }


    private void SaveToFile(Texture2D texture)
    {
        File.WriteAllBytes(
            Application.persistentDataPath + "/" +
            "snap.png", texture.EncodeToPNG());
    }
}

เราจะเรียก LoadDetector() และ TFClassify() จาก Library ร่วมกับคลาสในการส่งไฟล์รูปภาพจากกล้องมาประมวลผลเพื่อวาดกรอบข้อมูล พร้อม Label ใน

public static void DrawRectangle(Rect area, int frameWidth, Color color)
    {
        // Create a one pixel texture with the right color
        if (boxOutlineTexture == null)
        {
            var texture = new Texture2D(1, 1);
            texture.SetPixel(0, 0, color);
            texture.Apply();
            boxOutlineTexture = texture;
        }
        
        Rect lineArea = area;
        lineArea.height = frameWidth;
        GUI.DrawTexture(lineArea, boxOutlineTexture); // Top line

        lineArea.y = area.yMax - frameWidth; 
        GUI.DrawTexture(lineArea, boxOutlineTexture); // Bottom line

        lineArea = area;
        lineArea.width = frameWidth;
        GUI.DrawTexture(lineArea, boxOutlineTexture); // Left line

        lineArea.x = area.xMax - frameWidth;
        GUI.DrawTexture(lineArea, boxOutlineTexture); // Right line
    }


    private static void DrawLabel(Rect position, string text)
    {
        if (labelStyle == null)
        {
            var style = new GUIStyle();
            style.fontSize = 50;
            style.normal.textColor = Color.red;
            labelStyle = style;
        }

        GUI.Label(position, text, labelStyle);
    }

ไป implement C# เพิ่มอีกไฟล์คือ Detector.cs ดังนี้:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using TensorFlow;
using UnityEngine;


namespace TFClassify
{
    public class BoxOutline
    {
        public float YMin { get; set; } = 0;
        public float XMin { get; set; } = 0;
        public float YMax { get; set; } = 0;
        public float XMax { get; set; } = 0;
        public string Label { get; set; }
        public float Score { get; set; }
    }

    public class Detector
    {
        private static int IMAGE_MEAN = 117;
        private static float IMAGE_STD = 1;
        
        // Minimum detection confidence to track a detection.
        private static float MINIMUM_CONFIDENCE = 0.6f;

        private int inputSize;
        private TFGraph graph;
        private string[] labels;

        public Detector(byte[] model, string[] labels, int inputSize)
        {
#if UNITY_ANDROID
            TensorFlowSharp.Android.NativeBinding.Init();
#endif
            this.labels = labels;
            this.inputSize = inputSize;
            this.graph = new TFGraph();
            this.graph.Import(new TFBuffer(model));
        }


        public Task<List<BoxOutline>> DetectAsync(Color32[] data)
        {
            return Task.Run(() =>
            {
                using (var session = new TFSession(this.graph))
                using (var tensor = TransformInput(data, this.inputSize, this.inputSize))
                {
                    var runner = session.GetRunner();
                    runner.AddInput(this.graph["image_tensor"][0], tensor)
                          .Fetch(this.graph["detection_boxes"][0],
                                 this.graph["detection_scores"][0],
                                 this.graph["detection_classes"][0],
                                 this.graph["num_detections"][0]);
                    var output = runner.Run();

                    var boxes = (float[,,])output[0].GetValue(jagged: false);
                    var scores = (float[,])output[1].GetValue(jagged: false);
                    var classes = (float[,])output[2].GetValue(jagged: false);
                        
                    foreach(var ts in output)
                    {
                        ts.Dispose();
                    }

                    return GetBoxes(boxes, scores, classes, MINIMUM_CONFIDENCE);
                }
            });
        }


        public static TFTensor TransformInput(Color32[] pic, int width, int height)
        {
            byte[] floatValues = new byte[width * height * 3];

            for (int i = 0; i < pic.Length; ++i)
            {
                var color = pic[i];

                floatValues [i * 3 + 0] = (byte)((color.r - IMAGE_MEAN) / IMAGE_STD);
                floatValues [i * 3 + 1] = (byte)((color.g - IMAGE_MEAN) / IMAGE_STD);
                floatValues [i * 3 + 2] = (byte)((color.b - IMAGE_MEAN) / IMAGE_STD);
            }

            TFShape shape = new TFShape(1, width, height, 3);

            return TFTensor.FromBuffer(shape, floatValues, 0, floatValues.Length);
        }


        private List<BoxOutline> GetBoxes(float[,,] boxes, float[,] scores, float[,] classes, double minScore)
        {
            var x = boxes.GetLength(0);
            var y = boxes.GetLength(1);
            var z = boxes.GetLength(2);

            float ymin = 0, xmin = 0, ymax = 0, xmax = 0;
            var results = new List<BoxOutline>();

            for (int i = 0; i < x; i++) 
            {
                for (int j = 0; j < y; j++) 
                {
                    if (scores [i, j] < minScore) continue;

                    for (int k = 0; k < z; k++) 
                    {
                        var box = boxes [i, j, k];
                        switch (k) {
                        case 0:
                            ymin = box;
                            break;
                        case 1:
                            xmin = box;
                            break;
                        case 2:
                            ymax = box;
                            break;
                        case 3:
                            xmax = box;
                            break;
                        }
                    }

                    int value = Convert.ToInt32(classes[i, j]);
                    var label = this.labels[value];
                    var boxOutline = new BoxOutline
                    {
                        YMin = ymin,
                        XMin = xmin,
                        YMax = ymax,
                        XMax = xmax,
                        Label = label,
                        Score = scores[i, j],
                    };

                    results.Add(boxOutline);
                }
            }

            return results;
        }
    }
}

เพิ่ม Class อีกตัวที่น่าสนใจคือ Classifier สร้าง C# ใหม่ขึ้นมาว่า Classifier.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using TensorFlow;
using UnityEngine;


namespace TFClassify
{
    public class Classifier
    {
        private static int IMAGE_MEAN = 117;
        private static float IMAGE_STD = 1;
        private static string INPUT_NAME = "input";
        private static string OUTPUT_NAME = "output";

        private int inputSize;
        private TFGraph graph;
        private string[] labels;

        
        public Classifier(byte[] model, string[] labels, int inputSize)
        {
#if UNITY_ANDROID
            TensorFlowSharp.Android.NativeBinding.Init();
#endif
            this.labels = labels;
            this.inputSize = inputSize;
            this.graph = new TFGraph();
            this.graph.Import(model, "");
        }


        public Task<List<KeyValuePair<string, float>>> ClassifyAsync(Color32[] data)
        {
            return Task.Run(() =>
            {
                var map = new List<KeyValuePair<string, float>>();

                using (var session = new TFSession(this.graph))
                using (var tensor = TransformInput(data, this.inputSize, this.inputSize))
                {
                    var runner = session.GetRunner();
                    runner.AddInput(graph[INPUT_NAME][0], tensor).Fetch(graph[OUTPUT_NAME][0]);
                    var output = runner.Run();
                    
                    // output[0].Value() is a vector containing probabilities of
                    // labels for each image in the "batch". The batch size was 1.
                    // Find the most probably label index.

                    var result = output[0];
                    var rshape = result.Shape;
                    
                    if (result.NumDims != 2 || rshape [0] != 1)
                    {
                        var shape = "";
                        foreach (var d in rshape)
                        {
                            shape += $"{d} ";
                        }
                        
                        shape = shape.Trim ();
                        Debug.Log("Error: expected to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape [{shape}]");
                        Environment.Exit (1);
                    }

                    var probabilities = ((float[][])result.GetValue(jagged: true))[0];

                    for (int i = 0; i < labels.Length; i++)
                    {
                        map.Add(new KeyValuePair<string, float>(labels[i], probabilities[i] * 100));
                    }

                    foreach (var ts in output)
                    {
                        ts.Dispose();
                    }
                }

                return map.OrderByDescending(x => x.Value).ToList();
            });
        }


        public static TFTensor TransformInput(Color32[] pic, int width, int height)
        {
            float[] floatValues = new float[width * height * 3];

            for (int i = 0; i < pic.Length; ++i)
            {
                var color = pic[i];

                floatValues [i * 3 + 0] = (color.r - IMAGE_MEAN) / IMAGE_STD;
                floatValues [i * 3 + 1] = (color.g - IMAGE_MEAN) / IMAGE_STD;
                floatValues [i * 3 + 2] = (color.b - IMAGE_MEAN) / IMAGE_STD;
            }

            TFShape shape = new TFShape(1, width, height, 3);

            return TFTensor.FromBuffer(shape, floatValues, 0, floatValues.Length);
        }
    }
}

เสียบสาย USB เข้ากับเครื่องคอมพิวเตอร์ของเราหลังจากนั้นให้ Build & run ตัว Android ของเราลงสมาร์ตโฟน เพื่อเริ่มต้นทดสอบ

จะเห็นว่าเราสามารถทำ Object Detector ง่ายๆ ด้วย TensorFlowSharp กับ Unity ได้แล้ว

สำหรับคนที่มี Model ของ TensorFlow เป็นของตัวเอง Model ที่เลือกมาต้องถูก Trained ด้วย TensorFlow 1.4 ขึ้นไปนะครับ เปลี่ยนนามสกุลไฟล์จาก .pb เป็น .bytes ด้วย

Asst. Prof. Banyapon Poolsawas

อาจารย์ประจำสาขาวิชาการออกแบบเชิงโต้ตอบ และการพัฒนาเกม วิทยาลัยครีเอทีฟดีไซน์ & เอ็นเตอร์เทนเมนต์เทคโนโลยี มหาวิทยาลัยธุรกิจบัณฑิตย์ ผู้ก่อตั้ง บริษัท Daydev Co., Ltd, (เดย์เดฟ จำกัด)

Related Articles

Back to top button

Adblock Detected

เราตรวจพบว่าคุณใช้ Adblock บนบราวเซอร์ของคุณ,กรุณาปิดระบบ Adblock ก่อนเข้าอ่าน Content ของเรานะครับ, ถือว่าช่วยเหลือกัน