diff --git a/_examples/rss-reader.go b/_examples/rss-reader.go index 34ddbc1..9338c7f 100644 --- a/_examples/rss-reader.go +++ b/_examples/rss-reader.go @@ -23,7 +23,7 @@ func main() { slog.Debug("Successfully Parsed the Cli Args") if cli.GlobalConfig.DBExist() { - _, err = database.Init(database.CreateDBDns(cli.GlobalConfig.GetDBPath()), false) + _, err = database.Init(database.CreateDBDsn(cli.GlobalConfig.GetDBPath(), false), false) if err != nil { log.Fatalf("Unable to connect to DB at %s: %s", cli.GlobalConfig.GetDBPath(), err.Error()) } diff --git a/database/db.go b/database/db.go index 1799c36..9682bc7 100644 --- a/database/db.go +++ b/database/db.go @@ -21,11 +21,14 @@ func createTables(db *sql.DB) error { return nil } -func CreateDBDns(path string) string { +func CreateDBDsn(path string, inMemory bool) string { + if inMemory { + return fmt.Sprintf("file:%s?_foreign_keys=1&mode=memory", path) + } return fmt.Sprintf("file:%s?_foreign_keys=1", path) } -//Create -- Created the database +// Create -- Created the database func Create(path string) (*sql.DB, error) { //Create the database file _, err := os.Create(path) @@ -39,7 +42,7 @@ func Create(path string) (*sql.DB, error) { return db, err } -//Exist -- checks for the existance of a file +// Exist -- checks for the existance of a file func Exist(path string) bool { if _, err := os.Stat(path); err != nil { if os.IsNotExist(err) { @@ -49,10 +52,54 @@ func Exist(path string) bool { return true } -//Init -- Initializes the database. The reset param allows you to recreate the database. +type Connector interface { + Exec(string, ...any) (sql.Result, error) + Ping() error + QueryRow(query string, args ...any) *sql.Row + Close() error +} + +type Driver interface { + Open(string, string) (Connector, error) +} + +func createTablesV2(db Connector) error { + for _, sql := range sqlFiles { + _, err := db.Exec(sql) + if err != nil { + return fmt.Errorf("Creating tables failed with the following error: %w", err) + } + } + return nil +} + +// InitV2 -- Initializes the database. The reset param allows you to recreate the database. +func InitV2(driverName, dataSourceName string, reset bool, open func(string, string) (Connector, error)) (Connector, error) { + DB, err := open(driverName, dataSourceName) + if err != nil { + return DB, fmt.Errorf("Unable to Open Database: %w", err) + } + + err = DB.Ping() + if err != nil { + return DB, fmt.Errorf("Unable to ping the Database: %w", err) + } + + if reset { + //Drop all the tables and create all the tables again + err = createTablesV2(DB) + if err != nil { + return DB, err + } + } + + return DB, err +} + +// Init -- Initializes the database. The reset param allows you to recreate the database. func Init(dsn string, reset bool) (*sql.DB, error) { var err error - + //Prep the connection to the database DB, err = sql.Open(driver, dsn) if err != nil { @@ -76,8 +123,7 @@ func Init(dsn string, reset bool) (*sql.DB, error) { return DB, nil } - -//AddFeedFileData -- Adds Feed File Data to the database +// AddFeedFileData -- Adds Feed File Data to the database func AddFeedFileData(db *sql.DB, fileData []file.Data) (map[int64]file.Data, error) { var feedID int64 var tagID int64 diff --git a/database/db_test.go b/database/db_test.go index 0806c0c..17ff3ae 100644 --- a/database/db_test.go +++ b/database/db_test.go @@ -2,12 +2,161 @@ package database import ( "database/sql" + "errors" "fmt" "log" "os" + "strings" "testing" + + _ "github.com/mattn/go-sqlite3" //Sqlite3 driver ) +type connectorErr struct{} + +func (c connectorErr) Exec(query string, args ...any) (sql.Result, error) { + return nil, errors.New("Exec Error") +} + +func (c connectorErr) Ping() error { + return errors.New("Ping Error") +} + +func (c connectorErr) QueryRow(query string, args ...any) *sql.Row { + return nil +} + +func (c connectorErr) Close() error { + return errors.New("Connector failed to close") +} + +type driverOpenErr struct{} + +func (d driverOpenErr) Open(drivername, dsn string) (Connector, error) { + return nil, errors.New("Open Error") +} + +type driverPingErr struct{} + +func (d driverPingErr) Open(name, dsn string) (Connector, error) { + return connectorErr{}, nil +} + +func TestCreateTablesV2(t *testing.T) { + t.Parallel() + + tcs := []struct { + name string + conn Connector + err error + }{ + {"Exec error", connectorErr{}, errors.New("Creating tables failed with the following error: Exec Error")}, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + err := createTablesV2(tc.conn) + // TODO: Figure out success test cases + if err == nil { + t.Fatalf("Expected err to be nil, but got %s", err.Error()) + } + if err != nil { + if err.Error() != tc.err.Error() { + t.Fatalf("Got %q, but expected %q", err.Error(), tc.err.Error()) + } + } + }) + + } + +} + +func TestInitV2HappyPath(t *testing.T) { + t.Parallel() + + driverName := "sqlite3" + + tcs := []struct { + name string + dsn string + reset bool + }{ + {"In Memory DB", "file:memory.db?_foreign_keys=1&mode=memory", false}, + {"In Memory DB Create Tables", "file:memory.db?_foreign_keys=1&mode=memory", true}, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + _, err := InitV2(driverName, tc.dsn, false, func(a, b string) (Connector, error) { + conn, err := sql.Open(a, b) + if err != nil { + return conn, err + } + return conn, err + }) + if err != nil { + t.Fatalf("Exepected nil, but got %q", err.Error()) + } + + }) + } + +} + +func TestInitV2(t *testing.T) { + t.Parallel() + + tcs := []struct { + name string + driverName string + dsn string + reset bool + driver Driver + err error + }{ + {"Open Error", "driverName", "dsn", false, driverOpenErr{}, errors.New("Unable to Open Database: Open Error")}, + {"Ping Error", "driverName", "dsn", false, driverPingErr{}, errors.New("Unable to ping the Database: Ping Error")}, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + _, err := InitV2(tc.driverName, tc.dsn, tc.reset, tc.driver.Open) + + if err == nil { + t.Fatalf("Expected err to be nil, but got %s", err.Error()) + } + if err != nil { + if err.Error() != tc.err.Error() { + t.Fatalf("Got %q, but expected %q", err.Error(), tc.err.Error()) + } + } + }) + } +} + +func TestCreateDBDsn(t *testing.T) { + t.Parallel() + + tcs := []struct { + name string + path string + inMemory bool + expected string + }{ + {"File Path string", "testing.db", false, "file:testing.db?_foreign_keys=1"}, + {"In Memory string", "memory.db", true, "file:memory.db?_foreign_keys=1&mode=memory"}, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + result := CreateDBDsn(tc.path, tc.inMemory) + if strings.Compare(result, tc.expected) != 0 { + t.Fatalf("Expectd %q, but got %q", tc.expected, result) + } + }) + } +} + func createTestDB(file string) *sql.DB { testDB := fmt.Sprintf("file:%s%s", file, foreignKeySupport) diff --git a/database/feed.go b/database/feed.go index bc0d780..76a14c8 100644 --- a/database/feed.go +++ b/database/feed.go @@ -12,7 +12,7 @@ import ( "github.com/crazcalm/go-rss-reader/file" ) -//Feed -- Data structure used to hold a feed +// Feed -- Data structure used to hold a feed type Feed struct { ID int64 URL string @@ -21,7 +21,32 @@ type Feed struct { Data *gofeed.Feed } -//GetFeedDataFromSite -- gets the feed data from the feed url and returns it + +// GetFeedDataFromSiteV2 -- gets the feed data from the feed url and returns it +func GetFeedDataFromSiteV2(url string, reader func(io.Reader) ([]byte, error)) (string, error) { + resp, err := http.Get(url) + if err != nil { + return "", fmt.Errorf("Error trying to get the raw feed data from %s: %s", url, err.Error()) + } + defer func() { + if err = resp.Body.Close(); err != nil { + err = fmt.Errorf("Errorr occurred while closing the response body: %s", err.Error()) + } + }() + + if resp.StatusCode >= 300 { + return "", fmt.Errorf("url %q returned a status code of %v", url, resp.StatusCode) + } + + body, err := reader(resp.Body) + if err != nil { + return "", fmt.Errorf("Unable to read response body: %w", err) + } + return string(body), err +} + + +// GetFeedDataFromSite -- gets the feed data from the feed url and returns it func GetFeedDataFromSite(url string) (string, error) { resp, err := http.Get(url) if err != nil { @@ -40,8 +65,8 @@ func GetFeedDataFromSite(url string) (string, error) { return string(body), err } -//NewFeed -- Used to create a new Feed. Id the id is equal to -1, then -//all of the database interactions will not happen +// NewFeed -- Used to create a new Feed. Id the id is equal to -1, then +// all of the database interactions will not happen func NewFeed(id int64, fileData file.Data) (*Feed, error) { var db *sql.DB var err error diff --git a/database/feed_test.go b/database/feed_test.go index 3a65e92..5af267f 100644 --- a/database/feed_test.go +++ b/database/feed_test.go @@ -1,10 +1,67 @@ package database import ( + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" "strings" "testing" ) +func TestGetFeedDataFromSiteV2(t *testing.T) { + t.Parallel() + + response_message := "hello world" + + good_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, response_message) + })) + + bad_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "TeaPot Error", 418) + })) + not_up_server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, %s", r.Proto) + })) + + t.Cleanup(func() { + bad_server.Close() + good_server.Close() + }) + + tcs := []struct { + name string + url string + reader func(io.Reader) ([]byte, error) + err error + }{ + {"Not Okay Status Code", bad_server.URL, io.ReadAll, errors.New("returned a status code of 418")}, + {"Read Response Error", good_server.URL, func(_ io.Reader) ([]byte, error) { + return nil, errors.New("Read Error") + }, errors.New("Unable to read response body: Read Error")}, + {"No Server Running", not_up_server.URL, io.ReadAll, errors.New("Error trying to get the raw feed data from : Get \"\": unsupported protocol scheme \"\"")}, + {"Happy Path", good_server.URL, io.ReadAll, nil}, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + result, err := GetFeedDataFromSiteV2(tc.url, tc.reader) + if err != nil { + if strings.Contains(err.Error(), tc.err.Error()) == false { + t.Fatalf("Expected %q, but got %q", tc.err.Error(), err.Error()) + } + } else { + if strings.Compare(result, response_message) != 0 { + t.Fatalf("Expected %q, but got %q", result, response_message) + } + } + }) + } + +} + func TestGetFeedDataFromSite(t *testing.T) { tests := []string{ "http://www.leoville.tv/podcasts/sn.xml", diff --git a/database/test_data/database/feeds.db b/database/test_data/database/feeds.db index 53ba71a..9dd1cfd 100644 Binary files a/database/test_data/database/feeds.db and b/database/test_data/database/feeds.db differ diff --git a/database/testing/init_test_file.db b/database/testing/init_test_file.db index 05c9424..4e74a85 100644 Binary files a/database/testing/init_test_file.db and b/database/testing/init_test_file.db differ