Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _examples/rss-reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func main() {
slog.Debug("Successfully Parsed the Cli Args")

if cli.GlobalConfig.DBExist() {
_, err = database.Init(database.CreateDBDns(cli.GlobalConfig.GetDBPath()), false)
_, err = database.Init(database.CreateDBDsn(cli.GlobalConfig.GetDBPath(), false), false)
if err != nil {
log.Fatalf("Unable to connect to DB at %s: %s", cli.GlobalConfig.GetDBPath(), err.Error())
}
Expand Down
60 changes: 53 additions & 7 deletions database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@ func createTables(db *sql.DB) error {
return nil
}

func CreateDBDns(path string) string {
func CreateDBDsn(path string, inMemory bool) string {
if inMemory {
return fmt.Sprintf("file:%s?_foreign_keys=1&mode=memory", path)
}
return fmt.Sprintf("file:%s?_foreign_keys=1", path)
}

//Create -- Created the database
// Create -- Created the database
func Create(path string) (*sql.DB, error) {
//Create the database file
_, err := os.Create(path)
Expand All @@ -39,7 +42,7 @@ func Create(path string) (*sql.DB, error) {
return db, err
}

//Exist -- checks for the existance of a file
// Exist -- checks for the existance of a file
func Exist(path string) bool {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
Expand All @@ -49,10 +52,54 @@ func Exist(path string) bool {
return true
}

//Init -- Initializes the database. The reset param allows you to recreate the database.
type Connector interface {
Exec(string, ...any) (sql.Result, error)
Ping() error
QueryRow(query string, args ...any) *sql.Row
Close() error
}

type Driver interface {
Open(string, string) (Connector, error)
}

func createTablesV2(db Connector) error {
for _, sql := range sqlFiles {
_, err := db.Exec(sql)
if err != nil {
return fmt.Errorf("Creating tables failed with the following error: %w", err)
}
}
return nil
}

// InitV2 -- Initializes the database. The reset param allows you to recreate the database.
func InitV2(driverName, dataSourceName string, reset bool, open func(string, string) (Connector, error)) (Connector, error) {
DB, err := open(driverName, dataSourceName)
if err != nil {
return DB, fmt.Errorf("Unable to Open Database: %w", err)
}

err = DB.Ping()
if err != nil {
return DB, fmt.Errorf("Unable to ping the Database: %w", err)
}

if reset {
//Drop all the tables and create all the tables again
err = createTablesV2(DB)
if err != nil {
return DB, err
}
}

return DB, err
}

// Init -- Initializes the database. The reset param allows you to recreate the database.
func Init(dsn string, reset bool) (*sql.DB, error) {
var err error

//Prep the connection to the database
DB, err = sql.Open(driver, dsn)
if err != nil {
Expand All @@ -76,8 +123,7 @@ func Init(dsn string, reset bool) (*sql.DB, error) {
return DB, nil
}


//AddFeedFileData -- Adds Feed File Data to the database
// AddFeedFileData -- Adds Feed File Data to the database
func AddFeedFileData(db *sql.DB, fileData []file.Data) (map[int64]file.Data, error) {
var feedID int64
var tagID int64
Expand Down
149 changes: 149 additions & 0 deletions database/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,161 @@ package database

import (
"database/sql"
"errors"
"fmt"
"log"
"os"
"strings"
"testing"

_ "github.com/mattn/go-sqlite3" //Sqlite3 driver
)

type connectorErr struct{}

func (c connectorErr) Exec(query string, args ...any) (sql.Result, error) {
return nil, errors.New("Exec Error")
}

func (c connectorErr) Ping() error {
return errors.New("Ping Error")
}

func (c connectorErr) QueryRow(query string, args ...any) *sql.Row {
return nil
}

func (c connectorErr) Close() error {
return errors.New("Connector failed to close")
}

type driverOpenErr struct{}

func (d driverOpenErr) Open(drivername, dsn string) (Connector, error) {
return nil, errors.New("Open Error")
}

type driverPingErr struct{}

func (d driverPingErr) Open(name, dsn string) (Connector, error) {
return connectorErr{}, nil
}

func TestCreateTablesV2(t *testing.T) {
t.Parallel()

tcs := []struct {
name string
conn Connector
err error
}{
{"Exec error", connectorErr{}, errors.New("Creating tables failed with the following error: Exec Error")},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
err := createTablesV2(tc.conn)
// TODO: Figure out success test cases
if err == nil {
t.Fatalf("Expected err to be nil, but got %s", err.Error())
}
if err != nil {
if err.Error() != tc.err.Error() {
t.Fatalf("Got %q, but expected %q", err.Error(), tc.err.Error())
}
}
})

}

}

func TestInitV2HappyPath(t *testing.T) {
t.Parallel()

driverName := "sqlite3"

tcs := []struct {
name string
dsn string
reset bool
}{
{"In Memory DB", "file:memory.db?_foreign_keys=1&mode=memory", false},
{"In Memory DB Create Tables", "file:memory.db?_foreign_keys=1&mode=memory", true},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
_, err := InitV2(driverName, tc.dsn, false, func(a, b string) (Connector, error) {
conn, err := sql.Open(a, b)
if err != nil {
return conn, err
}
return conn, err
})
if err != nil {
t.Fatalf("Exepected nil, but got %q", err.Error())
}

})
}

}

func TestInitV2(t *testing.T) {
t.Parallel()

tcs := []struct {
name string
driverName string
dsn string
reset bool
driver Driver
err error
}{
{"Open Error", "driverName", "dsn", false, driverOpenErr{}, errors.New("Unable to Open Database: Open Error")},
{"Ping Error", "driverName", "dsn", false, driverPingErr{}, errors.New("Unable to ping the Database: Ping Error")},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
_, err := InitV2(tc.driverName, tc.dsn, tc.reset, tc.driver.Open)

if err == nil {
t.Fatalf("Expected err to be nil, but got %s", err.Error())
}
if err != nil {
if err.Error() != tc.err.Error() {
t.Fatalf("Got %q, but expected %q", err.Error(), tc.err.Error())
}
}
})
}
}

func TestCreateDBDsn(t *testing.T) {
t.Parallel()

tcs := []struct {
name string
path string
inMemory bool
expected string
}{
{"File Path string", "testing.db", false, "file:testing.db?_foreign_keys=1"},
{"In Memory string", "memory.db", true, "file:memory.db?_foreign_keys=1&mode=memory"},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
result := CreateDBDsn(tc.path, tc.inMemory)
if strings.Compare(result, tc.expected) != 0 {
t.Fatalf("Expectd %q, but got %q", tc.expected, result)
}
})
}
}

func createTestDB(file string) *sql.DB {
testDB := fmt.Sprintf("file:%s%s", file, foreignKeySupport)

Expand Down
33 changes: 29 additions & 4 deletions database/feed.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/crazcalm/go-rss-reader/file"
)

//Feed -- Data structure used to hold a feed
// Feed -- Data structure used to hold a feed
type Feed struct {
ID int64
URL string
Expand All @@ -21,7 +21,32 @@ type Feed struct {
Data *gofeed.Feed
}

//GetFeedDataFromSite -- gets the feed data from the feed url and returns it

// GetFeedDataFromSiteV2 -- gets the feed data from the feed url and returns it
func GetFeedDataFromSiteV2(url string, reader func(io.Reader) ([]byte, error)) (string, error) {
resp, err := http.Get(url)
if err != nil {
return "", fmt.Errorf("Error trying to get the raw feed data from %s: %s", url, err.Error())
}
defer func() {
if err = resp.Body.Close(); err != nil {
err = fmt.Errorf("Errorr occurred while closing the response body: %s", err.Error())
}
}()

if resp.StatusCode >= 300 {
return "", fmt.Errorf("url %q returned a status code of %v", url, resp.StatusCode)
}

body, err := reader(resp.Body)
if err != nil {
return "", fmt.Errorf("Unable to read response body: %w", err)
}
return string(body), err
}


// GetFeedDataFromSite -- gets the feed data from the feed url and returns it
func GetFeedDataFromSite(url string) (string, error) {
resp, err := http.Get(url)
if err != nil {
Expand All @@ -40,8 +65,8 @@ func GetFeedDataFromSite(url string) (string, error) {
return string(body), err
}

//NewFeed -- Used to create a new Feed. Id the id is equal to -1, then
//all of the database interactions will not happen
// NewFeed -- Used to create a new Feed. Id the id is equal to -1, then
// all of the database interactions will not happen
func NewFeed(id int64, fileData file.Data) (*Feed, error) {
var db *sql.DB
var err error
Expand Down
57 changes: 57 additions & 0 deletions database/feed_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,67 @@
package database

import (
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

func TestGetFeedDataFromSiteV2(t *testing.T) {
t.Parallel()

response_message := "hello world"

good_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, response_message)
}))

bad_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "TeaPot Error", 418)
}))
not_up_server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, %s", r.Proto)
}))

t.Cleanup(func() {
bad_server.Close()
good_server.Close()
})

tcs := []struct {
name string
url string
reader func(io.Reader) ([]byte, error)
err error
}{
{"Not Okay Status Code", bad_server.URL, io.ReadAll, errors.New("returned a status code of 418")},
{"Read Response Error", good_server.URL, func(_ io.Reader) ([]byte, error) {
return nil, errors.New("Read Error")
}, errors.New("Unable to read response body: Read Error")},
{"No Server Running", not_up_server.URL, io.ReadAll, errors.New("Error trying to get the raw feed data from : Get \"\": unsupported protocol scheme \"\"")},
{"Happy Path", good_server.URL, io.ReadAll, nil},
}

for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
result, err := GetFeedDataFromSiteV2(tc.url, tc.reader)
if err != nil {
if strings.Contains(err.Error(), tc.err.Error()) == false {
t.Fatalf("Expected %q, but got %q", tc.err.Error(), err.Error())
}
} else {
if strings.Compare(result, response_message) != 0 {
t.Fatalf("Expected %q, but got %q", result, response_message)
}
}
})
}

}

func TestGetFeedDataFromSite(t *testing.T) {
tests := []string{
"http://www.leoville.tv/podcasts/sn.xml",
Expand Down
Binary file modified database/test_data/database/feeds.db
Binary file not shown.
Binary file modified database/testing/init_test_file.db
Binary file not shown.