diff --git a/test_tmdb.py b/test_tmdb.py index e3b4ac3..2b9f5d5 100644 --- a/test_tmdb.py +++ b/test_tmdb.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 import unittest -from tmdb.tmdb_api import Movie +from tmdb.tmdb_api import Movie, TvShow from tmdb.tmdb import TmdbBot from tmdb.database import Database from sqlalchemy import create_engine @@ -54,5 +54,33 @@ class TestTmdbMethods(unittest.TestCase): db.set_language('@testuser:example.com', 'en') self.assertEqual(str(db.get_language('@testuser:example.com')), 'en') + def test_id_lookup(self): + movie = Movie() + movie.query_details('2108') + self.assertEqual('The Breakfast Club', movie.title) + + def test_search_fails(self): + movie = Movie() + id = movie.search_title('Breakfast Club 2019') + self.assertEqual(id, None) + self.assertEqual(None, movie.title) + + # TV Shows + def test_search_tvshow(self): + movie = TvShow() + id = movie.search_title('The Flash') + self.assertEqual(id, 60735) + + def test_tv_title(self): + movie = TvShow() + movie.search_title('The Flash') + self.assertEqual('The Flash', movie.title) + + def test_cast(self): + movie = TvShow() + movie.search_title('The Flash') + self.assertEqual('Grant Gustin', movie.cast[0]) + self.assertEqual('Carlos Valdes', movie.cast[2]) + if __name__ == '__main__': unittest.main() diff --git a/tmdb/tmdb_api.py b/tmdb/tmdb_api.py index e9529b0..b2013b7 100644 --- a/tmdb/tmdb_api.py +++ b/tmdb/tmdb_api.py @@ -38,12 +38,11 @@ class TmdbApi(): def get_apikey(self): return { 'api_key' : self.api_key } - def request(self, request_uri, payload : dict = {}): - url = self.base_url + request_uri - _payload = self.get_apikey() - _payload['language'] = self.language - _payload.update(payload) - result = requests.get(url, params=_payload) + def request(self, request_uri, params : dict = {}): + url = self.base_url + request_uri.lstrip('/') + params.update(self.get_apikey()) + params.update({ 'language' : self.language }) + result = requests.get(url, params=params) self.valid = True return result @@ -60,15 +59,11 @@ class TmdbApi(): class Movie(TmdbApi): def __init__(self): super().__init__() - self.load_parameters() - pass def search_title(self, title): - url = self.base_url+ 'search/movie' - payload = self.get_apikey() - payload['language'] = self.language + payload = {} payload['query'] = title - result = requests.get(url, params=payload) + result = self.request('search/movie', params=payload) json = result.json() if json['total_results'] > 0: movie_id = json['results'][0]['id'] @@ -93,4 +88,37 @@ class Movie(TmdbApi): def get_cast(self, amount): return self.cast[:amount] + + +class TvShow(TmdbApi): + def __init__(self): + super().__init__() + def search_title(self, title): + payload = {} + payload['query'] = title + result = self.request('/search/tv', params=payload) + json = result.json() + if json['total_results'] > 0: + movie_id = json['results'][0]['id'] + self.query_details(movie_id) + return movie_id + + def query_details(self, id): + data = self.request('tv/' + str(id)).json() + self.title = data['name'] + self.id = data['id'] + self.poster_url = self.base_url_poster + data['poster_path'] + self.overview = data['overview'] + self.web_url = 'https://www.themoviedb.org/tv/' + str(self.id) + self.vote_average = str(data['vote_average']) + self.query_cast() + + def query_cast(self): + data = self.request('tv/'+str(self.id)+'/credits').json() + self.cast = [] + for actor in data['cast']: + self.cast.append(actor['name']) + + def get_cast(self, amount): + return self.cast[:amount] \ No newline at end of file